From a63ad3305651a6f43b588132213b285210a7f0ba Mon Sep 17 00:00:00 2001 From: caixilong Date: Thu, 28 Aug 2025 16:40:41 +0800 Subject: [PATCH 01/74] add heap test cases --- .../host/mem/shmem_host_heap_test.cpp | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index fda5acda..f26cbbbe 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -192,6 +192,32 @@ TEST_F(ShareMemoryManagerTest, calloc_large_memory_failed) local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, calloc_multiply_overflow_size_t_max) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + const size_t nmemb = static_cast(~0ULL); + const size_t each = 2; + + void *p = shmem_calloc(nmemb, each); + EXPECT_EQ(nullptr, p); + + void *ok = shmem_malloc(4096UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, align_zero) { const int process_count = test_gnpu_num; @@ -281,6 +307,167 @@ TEST_F(ShareMemoryManagerTest, align_not_two_power_failed) local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, stress_malloc_calloc_align_no_leak) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + constexpr int rounds = 500; + std::vector ptrs; + ptrs.reserve(rounds * 3); + + for (int i = 0; i < rounds; ++i) { + void *p1 = shmem_malloc(1024UL + (i % 7) * 128UL); + EXPECT_NE(nullptr, p1); + ptrs.push_back(p1); + + void *p2 = shmem_calloc(32, 16 + (i % 5)); + EXPECT_NE(nullptr, p2); + ptrs.push_back(p2); + + void *p3 = shmem_align(64, 1536UL + (i % 3) * 64UL); + EXPECT_NE(nullptr, p3); + ptrs.push_back(p3); + + if ((i % 4) == 0) { + shmem_free(p1); + ptrs[ptrs.size()-3] = nullptr; + } + if ((i % 6) == 0) { + shmem_free(p2); + ptrs[ptrs.size()-2] = nullptr; + } + } + + for (void *p : ptrs) { + if (p) shmem_free(p); + } + + void *big = shmem_malloc(heap_memory_size / 2); + EXPECT_NE(nullptr, big); + shmem_free(big); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, calls_before_init_and_after_finalize) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(4, 256)); + EXPECT_EQ(nullptr, shmem_align(64, 4096UL)); + shmem_free(nullptr); + + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + void *ok = shmem_malloc(2048UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + + test_finalize(stream, device_id); + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(2, 512)); + EXPECT_EQ(nullptr, shmem_align(32, 1024UL)); + shmem_free(nullptr); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_nullptr_is_noop) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + shmem_free(nullptr); + + void *p = shmem_malloc(8192UL); + EXPECT_NE(nullptr, p); + shmem_free(p); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, double_free_should_not_corrupt_heap) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 64UL * 1024UL; + void *p = shmem_malloc(sz); + ASSERT_NE(nullptr, p); + + shmem_free(p); + shmem_free(p); + + void *q = shmem_malloc(sz); + EXPECT_NE(nullptr, q); + shmem_free(q); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_middle_pointer_should_not_work_and_not_corrupt) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 128UL * 1024UL; + uint8_t *base = static_cast(shmem_malloc(sz)); + ASSERT_NE(nullptr, base); + + void *middle = base + 64; + + shmem_free(middle); + + shmem_free(base); + + void *again = shmem_malloc(sz); + EXPECT_NE(nullptr, again); + shmem_free(again); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, free_merge) { const int process_count = test_gnpu_num; -- Gitee From f845f913cdc9f25b1821d46849b80a2d2013c065 Mon Sep 17 00:00:00 2001 From: caixilong Date: Thu, 28 Aug 2025 19:26:44 +0800 Subject: [PATCH 02/74] fix align test bug --- tests/unittest/host/mem/shmem_host_heap_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index f26cbbbe..9df0d7d8 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -230,7 +230,6 @@ TEST_F(ShareMemoryManagerTest, align_zero) const size_t alignment = 16; auto ptr = shmem_align(alignment, 0UL); EXPECT_EQ(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -249,7 +248,7 @@ TEST_F(ShareMemoryManagerTest, align_one_piece_success) const size_t size = 128UL; auto ptr = shmem_align(alignment, size); EXPECT_NE(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -267,6 +266,7 @@ TEST_F(ShareMemoryManagerTest, align_full_space_success) const size_t alignment = 16; auto ptr = shmem_align(alignment, heap_memory_size); EXPECT_NE(nullptr, ptr); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); -- Gitee From ae2abc66ff7bfe10581f5a3f878ce2c7f7626248 Mon Sep 17 00:00:00 2001 From: caixilong Date: Mon, 1 Sep 2025 10:11:35 +0800 Subject: [PATCH 03/74] fix ci --- .gitmodules | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.gitmodules b/.gitmodules index 1064fb85..1dfda128 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,12 @@ [submodule "3rdparty/memfabric_hybrid"] path = 3rdparty/memfabric_hybrid url = https://gitee.com/ascend/memfabric_hybrid.git + branch = br_release_shmem_1.0 +[submodule "3rdparty/catlass"] + path = 3rdparty/catlass + url = https://gitee.com/ascend/catlass.git +[submodule "3rdparty/googletest"] + path = 3rdparty/googletest + url = https://gitee.com/mirrors/googletest.git + branch = v1.14.x + shallow = true # depth=1 -- Gitee From 18b5294c3aeb76a80e104a722fb8dd7957cf513e Mon Sep 17 00:00:00 2001 From: caixilong Date: Tue, 2 Sep 2025 09:47:23 +0800 Subject: [PATCH 04/74] change googletest clone link --- scripts/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/build.sh b/scripts/build.sh index bdcf7452..cb8ca06b 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -104,7 +104,7 @@ function fn_build_googletest() return 0 fi cd $THIRD_PARTY_DIR - [[ ! -d "googletest" ]] && git clone --branch v1.14.0 --depth 1 https://github.com/google/googletest.git + [[ ! -d "googletest" ]] && git clone --branch v1.14.0 --depth 1 https://gitee.com/mirrors/googletest.git cd googletest rm -rf build && mkdir build && cd build -- Gitee From 998d6f131fc55ed655ae61bf22e7a9fe273777a5 Mon Sep 17 00:00:00 2001 From: caixilong Date: Thu, 28 Aug 2025 16:40:41 +0800 Subject: [PATCH 05/74] add heap test cases --- .../host/mem/shmem_host_heap_test.cpp | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index fda5acda..f26cbbbe 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -192,6 +192,32 @@ TEST_F(ShareMemoryManagerTest, calloc_large_memory_failed) local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, calloc_multiply_overflow_size_t_max) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + const size_t nmemb = static_cast(~0ULL); + const size_t each = 2; + + void *p = shmem_calloc(nmemb, each); + EXPECT_EQ(nullptr, p); + + void *ok = shmem_malloc(4096UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, align_zero) { const int process_count = test_gnpu_num; @@ -281,6 +307,167 @@ TEST_F(ShareMemoryManagerTest, align_not_two_power_failed) local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, stress_malloc_calloc_align_no_leak) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + constexpr int rounds = 500; + std::vector ptrs; + ptrs.reserve(rounds * 3); + + for (int i = 0; i < rounds; ++i) { + void *p1 = shmem_malloc(1024UL + (i % 7) * 128UL); + EXPECT_NE(nullptr, p1); + ptrs.push_back(p1); + + void *p2 = shmem_calloc(32, 16 + (i % 5)); + EXPECT_NE(nullptr, p2); + ptrs.push_back(p2); + + void *p3 = shmem_align(64, 1536UL + (i % 3) * 64UL); + EXPECT_NE(nullptr, p3); + ptrs.push_back(p3); + + if ((i % 4) == 0) { + shmem_free(p1); + ptrs[ptrs.size()-3] = nullptr; + } + if ((i % 6) == 0) { + shmem_free(p2); + ptrs[ptrs.size()-2] = nullptr; + } + } + + for (void *p : ptrs) { + if (p) shmem_free(p); + } + + void *big = shmem_malloc(heap_memory_size / 2); + EXPECT_NE(nullptr, big); + shmem_free(big); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, calls_before_init_and_after_finalize) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(4, 256)); + EXPECT_EQ(nullptr, shmem_align(64, 4096UL)); + shmem_free(nullptr); + + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + void *ok = shmem_malloc(2048UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + + test_finalize(stream, device_id); + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(2, 512)); + EXPECT_EQ(nullptr, shmem_align(32, 1024UL)); + shmem_free(nullptr); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_nullptr_is_noop) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + shmem_free(nullptr); + + void *p = shmem_malloc(8192UL); + EXPECT_NE(nullptr, p); + shmem_free(p); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, double_free_should_not_corrupt_heap) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 64UL * 1024UL; + void *p = shmem_malloc(sz); + ASSERT_NE(nullptr, p); + + shmem_free(p); + shmem_free(p); + + void *q = shmem_malloc(sz); + EXPECT_NE(nullptr, q); + shmem_free(q); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_middle_pointer_should_not_work_and_not_corrupt) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 128UL * 1024UL; + uint8_t *base = static_cast(shmem_malloc(sz)); + ASSERT_NE(nullptr, base); + + void *middle = base + 64; + + shmem_free(middle); + + shmem_free(base); + + void *again = shmem_malloc(sz); + EXPECT_NE(nullptr, again); + shmem_free(again); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, free_merge) { const int process_count = test_gnpu_num; -- Gitee From e82586c7886f56237ec6a07b476cad9a82daf335 Mon Sep 17 00:00:00 2001 From: caixilong Date: Thu, 28 Aug 2025 19:26:44 +0800 Subject: [PATCH 06/74] fix align test bug --- tests/unittest/host/mem/shmem_host_heap_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index f26cbbbe..9df0d7d8 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -230,7 +230,6 @@ TEST_F(ShareMemoryManagerTest, align_zero) const size_t alignment = 16; auto ptr = shmem_align(alignment, 0UL); EXPECT_EQ(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -249,7 +248,7 @@ TEST_F(ShareMemoryManagerTest, align_one_piece_success) const size_t size = 128UL; auto ptr = shmem_align(alignment, size); EXPECT_NE(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -267,6 +266,7 @@ TEST_F(ShareMemoryManagerTest, align_full_space_success) const size_t alignment = 16; auto ptr = shmem_align(alignment, heap_memory_size); EXPECT_NE(nullptr, ptr); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); -- Gitee From 37c00f6cc6b860e58960ad9c1034965718249d2c Mon Sep 17 00:00:00 2001 From: caixilong Date: Mon, 1 Sep 2025 10:11:35 +0800 Subject: [PATCH 07/74] fix ci --- .gitmodules | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.gitmodules b/.gitmodules index 1064fb85..1dfda128 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,12 @@ [submodule "3rdparty/memfabric_hybrid"] path = 3rdparty/memfabric_hybrid url = https://gitee.com/ascend/memfabric_hybrid.git + branch = br_release_shmem_1.0 +[submodule "3rdparty/catlass"] + path = 3rdparty/catlass + url = https://gitee.com/ascend/catlass.git +[submodule "3rdparty/googletest"] + path = 3rdparty/googletest + url = https://gitee.com/mirrors/googletest.git + branch = v1.14.x + shallow = true # depth=1 -- Gitee From c4071cc941b04552c0e114053a6ad03c6587cd54 Mon Sep 17 00:00:00 2001 From: caixilong Date: Tue, 2 Sep 2025 09:47:23 +0800 Subject: [PATCH 08/74] change googletest clone link --- scripts/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/build.sh b/scripts/build.sh index bdcf7452..cb8ca06b 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -104,7 +104,7 @@ function fn_build_googletest() return 0 fi cd $THIRD_PARTY_DIR - [[ ! -d "googletest" ]] && git clone --branch v1.14.0 --depth 1 https://github.com/google/googletest.git + [[ ! -d "googletest" ]] && git clone --branch v1.14.0 --depth 1 https://gitee.com/mirrors/googletest.git cd googletest rm -rf build && mkdir build && cd build -- Gitee From d71aa8df50394d24eec6f4a8bb8cf193b7aee153 Mon Sep 17 00:00:00 2001 From: caixilong Date: Thu, 28 Aug 2025 16:40:41 +0800 Subject: [PATCH 09/74] add heap test cases --- .../host/mem/shmem_host_heap_test.cpp | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index fda5acda..f26cbbbe 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -192,6 +192,32 @@ TEST_F(ShareMemoryManagerTest, calloc_large_memory_failed) local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, calloc_multiply_overflow_size_t_max) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + const size_t nmemb = static_cast(~0ULL); + const size_t each = 2; + + void *p = shmem_calloc(nmemb, each); + EXPECT_EQ(nullptr, p); + + void *ok = shmem_malloc(4096UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, align_zero) { const int process_count = test_gnpu_num; @@ -281,6 +307,167 @@ TEST_F(ShareMemoryManagerTest, align_not_two_power_failed) local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, stress_malloc_calloc_align_no_leak) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + constexpr int rounds = 500; + std::vector ptrs; + ptrs.reserve(rounds * 3); + + for (int i = 0; i < rounds; ++i) { + void *p1 = shmem_malloc(1024UL + (i % 7) * 128UL); + EXPECT_NE(nullptr, p1); + ptrs.push_back(p1); + + void *p2 = shmem_calloc(32, 16 + (i % 5)); + EXPECT_NE(nullptr, p2); + ptrs.push_back(p2); + + void *p3 = shmem_align(64, 1536UL + (i % 3) * 64UL); + EXPECT_NE(nullptr, p3); + ptrs.push_back(p3); + + if ((i % 4) == 0) { + shmem_free(p1); + ptrs[ptrs.size()-3] = nullptr; + } + if ((i % 6) == 0) { + shmem_free(p2); + ptrs[ptrs.size()-2] = nullptr; + } + } + + for (void *p : ptrs) { + if (p) shmem_free(p); + } + + void *big = shmem_malloc(heap_memory_size / 2); + EXPECT_NE(nullptr, big); + shmem_free(big); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, calls_before_init_and_after_finalize) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(4, 256)); + EXPECT_EQ(nullptr, shmem_align(64, 4096UL)); + shmem_free(nullptr); + + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + void *ok = shmem_malloc(2048UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + + test_finalize(stream, device_id); + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(2, 512)); + EXPECT_EQ(nullptr, shmem_align(32, 1024UL)); + shmem_free(nullptr); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_nullptr_is_noop) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + shmem_free(nullptr); + + void *p = shmem_malloc(8192UL); + EXPECT_NE(nullptr, p); + shmem_free(p); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, double_free_should_not_corrupt_heap) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 64UL * 1024UL; + void *p = shmem_malloc(sz); + ASSERT_NE(nullptr, p); + + shmem_free(p); + shmem_free(p); + + void *q = shmem_malloc(sz); + EXPECT_NE(nullptr, q); + shmem_free(q); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_middle_pointer_should_not_work_and_not_corrupt) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 128UL * 1024UL; + uint8_t *base = static_cast(shmem_malloc(sz)); + ASSERT_NE(nullptr, base); + + void *middle = base + 64; + + shmem_free(middle); + + shmem_free(base); + + void *again = shmem_malloc(sz); + EXPECT_NE(nullptr, again); + shmem_free(again); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, free_merge) { const int process_count = test_gnpu_num; -- Gitee From 9889bd7c86b51639a6751085e63a8d0e5c95f88d Mon Sep 17 00:00:00 2001 From: caixilong Date: Thu, 28 Aug 2025 19:26:44 +0800 Subject: [PATCH 10/74] fix align test bug --- tests/unittest/host/mem/shmem_host_heap_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index f26cbbbe..9df0d7d8 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -230,7 +230,6 @@ TEST_F(ShareMemoryManagerTest, align_zero) const size_t alignment = 16; auto ptr = shmem_align(alignment, 0UL); EXPECT_EQ(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -249,7 +248,7 @@ TEST_F(ShareMemoryManagerTest, align_one_piece_success) const size_t size = 128UL; auto ptr = shmem_align(alignment, size); EXPECT_NE(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -267,6 +266,7 @@ TEST_F(ShareMemoryManagerTest, align_full_space_success) const size_t alignment = 16; auto ptr = shmem_align(alignment, heap_memory_size); EXPECT_NE(nullptr, ptr); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); -- Gitee From 73df50a0854b3132c740cb942bca78c3c3e3fbec Mon Sep 17 00:00:00 2001 From: caixilong Date: Mon, 1 Sep 2025 10:11:35 +0800 Subject: [PATCH 11/74] fix ci --- .gitmodules | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.gitmodules b/.gitmodules index 1064fb85..1dfda128 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,12 @@ [submodule "3rdparty/memfabric_hybrid"] path = 3rdparty/memfabric_hybrid url = https://gitee.com/ascend/memfabric_hybrid.git + branch = br_release_shmem_1.0 +[submodule "3rdparty/catlass"] + path = 3rdparty/catlass + url = https://gitee.com/ascend/catlass.git +[submodule "3rdparty/googletest"] + path = 3rdparty/googletest + url = https://gitee.com/mirrors/googletest.git + branch = v1.14.x + shallow = true # depth=1 -- Gitee From 3840e6ebf51cd9ce8f30256824310659fc36a4cc Mon Sep 17 00:00:00 2001 From: caixilong Date: Tue, 2 Sep 2025 09:47:23 +0800 Subject: [PATCH 12/74] change googletest clone link --- scripts/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/build.sh b/scripts/build.sh index d1fa804b..9ce9d8ee 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -118,7 +118,7 @@ function fn_build_googletest() return 0 fi cd $THIRD_PARTY_DIR - [[ ! -d "googletest" ]] && git clone --branch v1.14.0 --depth 1 https://github.com/google/googletest.git + [[ ! -d "googletest" ]] && git clone --branch v1.14.0 --depth 1 https://gitee.com/mirrors/googletest.git cd googletest rm -rf build && mkdir build && cd build -- Gitee From 909154bc4d3890016b5ec7a1c538dffa7a4fe88d Mon Sep 17 00:00:00 2001 From: caixilong Date: Fri, 5 Sep 2025 11:28:05 +0800 Subject: [PATCH 13/74] delete googletest and catlass --- .gitmodules | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/.gitmodules b/.gitmodules index 1dfda128..3876ebf2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,3 @@ [submodule "3rdparty/memfabric_hybrid"] path = 3rdparty/memfabric_hybrid - url = https://gitee.com/ascend/memfabric_hybrid.git - branch = br_release_shmem_1.0 -[submodule "3rdparty/catlass"] - path = 3rdparty/catlass - url = https://gitee.com/ascend/catlass.git -[submodule "3rdparty/googletest"] - path = 3rdparty/googletest - url = https://gitee.com/mirrors/googletest.git - branch = v1.14.x - shallow = true # depth=1 + url = https://gitee.com/ascend/memfabric_hybrid.git \ No newline at end of file -- Gitee From b2a77a34e19332eedc259dc2ea511f5ffd8469d9 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Mon, 1 Sep 2025 15:27:14 +0800 Subject: [PATCH 14/74] RDMA support draft. --- examples/CMakeLists.txt | 1 + examples/rdma_perftest/CMakeLists.txt | 9 + examples/rdma_perftest/README.md | 28 ++ examples/rdma_perftest/main.cpp | 246 +++++++++++++++++ .../rdma_perftest/rdma_perftest_kernel.cpp | 261 ++++++++++++++++++ include/device/shmem_device_rma.h | 170 ++++++++---- include/host/shmem_host_def.h | 8 + include/internal/host_device/shmemi_types.h | 2 +- src/host/init/shmem_init.cpp | 11 +- .../device/mem/rdma_mem/rdma_mem_kernel.cpp | 138 +++++++++ tests/unittest/host/main_test.cpp | 1 + .../host/mem/rdma_mem/rdma_mem_host_test.cpp | 120 ++++++++ 12 files changed, 927 insertions(+), 68 deletions(-) create mode 100644 examples/rdma_perftest/CMakeLists.txt create mode 100644 examples/rdma_perftest/README.md create mode 100644 examples/rdma_perftest/main.cpp create mode 100644 examples/rdma_perftest/rdma_perftest_kernel.cpp create mode 100644 tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp create mode 100644 tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d086b962..3ffdb7bd 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -50,6 +50,7 @@ endfunction() foreach(EXAMPLE allgather matmul_allreduce + rdma_perftest ) add_subdirectory(${EXAMPLE}) endforeach() \ No newline at end of file diff --git a/examples/rdma_perftest/CMakeLists.txt b/examples/rdma_perftest/CMakeLists.txt new file mode 100644 index 00000000..6691d5c6 --- /dev/null +++ b/examples/rdma_perftest/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +shmem_add_collective_example(rdma_perftest) \ No newline at end of file diff --git a/examples/rdma_perftest/README.md b/examples/rdma_perftest/README.md new file mode 100644 index 00000000..1b1d37fc --- /dev/null +++ b/examples/rdma_perftest/README.md @@ -0,0 +1,28 @@ +使用方式: +1.在shmem/目录编译: +```bash +bash scripts/build.sh +``` +2.在shmem/目录运行: +```bash +export PROJECT_ROOT= +export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH +./build/bin/rdma_perftest 2 0 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 # rank 0 +./build/bin/rdma_perftest 2 1 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 # rank 1 +``` + +3.命令行参数说明 + ./rdma_perftest + +- n_ranks: 全局Rank数量,只支持2个Rank。 +- rank_id: 当前进程的Rank号。 +- ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。 +- g_npus: 当前卡上启动的NPU数量。 +- f_rank: 当前卡上使用的第一个Rank号。 +- f_npu: 当前卡上使用的第一个NPU卡号。 +- test_type: 测试类型。 + - highlevel_put_pingpong_latency:测试Put高阶接口的pingpong时延。 + - postsend_cost: 测试postsend接口耗时。 + - highlevel_put_bw: 测试Put高阶接口的带宽。 + - rdma_mte_bw: 测试并行下发MTE和RDMA时的带宽。 +- msg_len: 测试传输的数据量大小,单位为字节(Byte)。 \ No newline at end of file diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp new file mode 100644 index 00000000..36796019 --- /dev/null +++ b/examples/rdma_perftest/main.cpp @@ -0,0 +1,246 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "shmem_api.h" + +int g_npus = 8; +const char *ipport; +int f_rank = 0; +int f_npu = 0; +const char *test_type; + +extern void rdma_highlevel_put_pingpong_latency_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); +extern void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); +extern void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); +extern void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); + +int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length) +{ + uint32_t iteration = 1; + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; + shmem_set_conf_store_tls(false, nullptr, 0); + status = shmem_init_attr(attributes); + + uint64_t fftsConfig = shmemx_get_ffts_config(); + uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 6); + + int64_t *xHost; + size_t totalSize = message_length * n_ranks; + + aclrtMallocHost((void **)(&xHost), totalSize); + for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) { + xHost[i] = rank_id + 10; + } + aclrtMemcpy(gva + rank_id * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(gva + n_ranks * message_length + 32 * (rank_id + 1), 32, xHost, 32, ACL_MEMCPY_HOST_TO_DEVICE); + + for (uint32_t i = 0; i < iteration; i++) { + rdma_highlevel_put_pingpong_latency_do(1, stream, fftsConfig, gva, message_length); + } + aclrtSynchronizeStream(stream); + if (rank_id == 0) { + aclrtMemcpy(xHost, sizeof(int64_t), gva + message_length * n_ranks, sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST); + std::cout << "RDMA highlevel put pingpong latency test. Message length = " << message_length << " Byte; latency = " << xHost[0] / 50.0 << " us." << std::endl; + } + + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int test_shmem_rdma_postsend_cost(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length) +{ + uint32_t iteration = 1; + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; + shmem_set_conf_store_tls(false, nullptr, 0); + status = shmem_init_attr(attributes); + + uint64_t fftsConfig = shmemx_get_ffts_config(); + uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 6); + + int64_t *xHost; + size_t totalSize = message_length * n_ranks; + + aclrtMallocHost((void **)(&xHost), totalSize); + for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) { + xHost[i] = rank_id + 10; + } + aclrtMemcpy(gva + rank_id * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE); + + for (uint32_t i = 0; i < iteration; i++) { + rdma_postsend_cost_do(1, stream, fftsConfig, gva, message_length); + } + aclrtSynchronizeStream(stream); + if (rank_id == 0) { + aclrtMemcpy(xHost, sizeof(uint32_t), gva + message_length * n_ranks, sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST); + std::cout << "RDMA postsend cost test. Message length = " << message_length << " Byte; postsend cost = " << xHost[0] / (50.0 * 500) << " us." << std::endl; + } + + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length) +{ + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; + shmem_set_conf_store_tls(false, nullptr, 0); + status = shmem_init_attr(attributes); + + uint64_t fftsConfig = shmemx_get_ffts_config(); + uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 6); + + int64_t *xHost; + size_t totalSize = message_length * n_ranks; + + aclrtMallocHost((void **)(&xHost), totalSize); + for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) { + xHost[i] = rank_id + 10; + } + aclrtMemcpy(gva + rank_id * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE); + + rdma_highlevel_put_bw_do(1, stream, fftsConfig, gva, message_length); + aclrtSynchronizeStream(stream); + if (rank_id == 0) { + aclrtMemcpy(xHost, sizeof(int64_t), gva + message_length * n_ranks, sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST); + std::cout << "RDMA high level put bandwidth test. Message length = " << message_length << " Byte; time = " << xHost[0] / (50.0) << " us." << std::endl; + } + + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length) +{ + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; + shmem_set_conf_store_tls(false, nullptr, 0); + status = shmem_init_attr(attributes); + shmem_mte_set_ub_params(0, 128 * 1024, 0); + + uint64_t fftsConfig = shmemx_get_ffts_config(); + uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 32); + + int64_t *xHost; + size_t totalSize = message_length * n_ranks; + + aclrtMallocHost((void **)(&xHost), totalSize); + for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) { + xHost[i] = rank_id + 10; + } + aclrtMemcpy(gva + rank_id * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(gva + (rank_id + n_ranks) * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE); + + rdma_mte_put_bw_do(1, stream, fftsConfig, gva, message_length); + aclrtSynchronizeStream(stream); + if (rank_id == 0) { + aclrtMemcpy(xHost, 64, gva + message_length * n_ranks * 2, 64, ACL_MEMCPY_DEVICE_TO_HOST); + std::cout << "RDMA rdma mte test. Message length = " << message_length << " Byte; RDMA time = " << xHost[0] / (50.0) << " us." << std::endl; + std::cout << "RDMA rdma mte test. Message length = " << message_length << " Byte; MTE time = " << xHost[6] / (50.0) << " us." << std::endl; + } + + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int main(int argc, char *argv[]) +{ + if (argc != 9) { + std::cout << "[ERROR] Paramater number mismatch." << std::endl; + std::cout << "[USAGE] ./rdma_perftest . See README for more details." << std::endl; + } + int status = 0; + int n_ranks = atoi(argv[1]); + if (n_ranks != 2) { + std::cout << "[ERROR] Error number of ranks! Only support 2 ranks!" << std::endl; + } + int rank_id = atoi(argv[2]); + if (rank_id >= 2) { + std::cout << "[ERROR] Error rank ID! Only support 2 ranks!" << std::endl; + } + ipport = argv[3]; + g_npus = atoi(argv[4]); + f_rank = atoi(argv[5]); + f_npu = atoi(argv[6]); + test_type = argv[7]; + int msg_len = atoi(argv[8]); + uint64_t local_mem_size = 1024UL * 1024UL * 64; + if (std::string(test_type) == "highlevel_put_pingpong_latency") { + test_shmem_rdma_highlevel_put_pingpong_latency(rank_id, n_ranks, local_mem_size, msg_len); + } else if (std::string(test_type) == "postsend_cost") { + test_shmem_rdma_postsend_cost(rank_id, n_ranks, local_mem_size, msg_len); + } else if (std::string(test_type) == "highlevel_put_bw") { + test_shmem_rdma_highlevel_put_bw(rank_id, n_ranks, local_mem_size, msg_len); + }else if (std::string(test_type) == "rdma_mte_bw") { + test_shmem_rdma_mte_put_bw(rank_id, n_ranks, local_mem_size, msg_len); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + + return 0; +} \ No newline at end of file diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp new file mode 100644 index 00000000..862a96a8 --- /dev/null +++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp @@ -0,0 +1,261 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "kernel_operator.h" +#include "acl/acl.h" +#include "shmem_api.h" + +constexpr uint32_t MAGIC_VAL = 10; +constexpr uint32_t WARMUP_MESSAGE_LENGTH = 32; + +/** + * @brief RDMA Quiet function. This synchronous function ensures all previous RDMA WQEs are completed (data has arrived at the destination NIC). + * + * @param remoteRankId [in] destination rank ID + * @param ubLocal64 [in] temporary UB local tensor of uint64_t used as workspace + * @param ubLocal32 [in] temporary UB local tensor of uint32_t used as workspace + */ + +SHMEM_DEVICE void smem_shm_roce_quiet(uint32_t remoteRankId, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { + __gm__ HybmDeviceMeta* metaPtr = (__gm__ HybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ AIVRDMAInfo* RDMAInfo = (__gm__ AIVRDMAInfo*)(metaPtr->qpInfoAddress); + uint32_t qpNum = RDMAInfo->qpNum; + for (uint32_t qpIdx = 0; qpIdx < qpNum; qpIdx++) { + __gm__ WQCtx* qpCtxEntry = (__gm__ WQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(WQCtx)); + auto curHardwareHeadAddr = qpCtxEntry->headAddr; + cacheWriteThrough((__gm__ uint8_t*)curHardwareHeadAddr, 8); + uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); + smem_shm_roce_poll_cq(remoteRankId, qpIdx, curHead - 1, ubLocal64, ubLocal32); + } +} + +extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64_t fftsConfig, GM_ADDR gva, int message_length) { + shmemx_set_ffts_config(fftsConfig); + if (AscendC::GetSubBlockIdx() != 0) { + return; + } + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE); + AscendC::LocalTensor ubLocalRead = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + uint32_t peer; + + // Warm up + GM_ADDR warm_addr = gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (rank + 1); + if (rank == 0) { + peer = 1; + shmem_put_uint8_mem_nbi(warm_addr, warm_addr, WARMUP_MESSAGE_LENGTH, peer); + while (*(__gm__ uint32_t*)(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1)) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); + AscendC::GetSystemCycle(); + } + } else { + peer = 0; + while (*(__gm__ uint32_t*)(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1)) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + shmem_put_uint8_mem_nbi(warm_addr, warm_addr, WARMUP_MESSAGE_LENGTH, peer); + } + AscendC::PipeBarrier(); + + // Actual test + GM_ADDR src_addr = gva + rank * message_length; + if (rank == 0) { + peer = 1; + int64_t start = AscendC::GetSystemCycle(); + shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); + while (*(__gm__ uint32_t*)(gva + message_length * 2 - 8) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + message_length * 2 - 8, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + int64_t end = AscendC::GetSystemCycle(); + *(__gm__ int64_t*)(gva + message_length * 2) = end - start; + } else { + peer = 0; + while (*(__gm__ uint32_t*)(gva + message_length * 1 - 8) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + message_length * 1 - 8, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); + } +} + +void rdma_highlevel_put_pingpong_latency_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { + rdma_highlevel_put_pingpong_latency<<<1, nullptr, stream>>>(fftsConfig, gva, message_length); +} + +extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM_ADDR gva, int message_length) { + shmemx_set_ffts_config(fftsConfig); + if (AscendC::GetSubBlockIdx() != 0) { + return; + } + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + uint32_t peer; + + // Actual test + GM_ADDR src_addr = gva + rank * message_length; + + if (rank == 0) { + peer = 1; + GM_ADDR dest_addr = (GM_ADDR)(shmem_ptr(src_addr, peer)); + int64_t start = AscendC::GetSystemCycle(); + for (uint32_t i = 0; i < 500; i++) { + smem_shm_roce_write(src_addr, dest_addr, peer, 0, message_length, ubLocal64, ubLocal32); + } + AscendC::PipeBarrier(); + int64_t end = AscendC::GetSystemCycle(); + *(__gm__ int64_t*)(gva + message_length * 2) = end - start; + } +} + +void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { + rdma_postsend_cost<<<1, nullptr, stream>>>(fftsConfig, gva, message_length); +} + +extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length) { + shmemx_set_ffts_config(fftsConfig); + if (AscendC::GetSubBlockIdx() != 0) { + return; + } + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + uint32_t peer; + + // Actual test + GM_ADDR src_addr = gva + rank * message_length; + if (rank == 0) { + peer = 1; + int64_t start = AscendC::GetSystemCycle(); + for (int i = 0; i < 10000; i++) { + shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); + } + smem_shm_roce_quiet(peer, ubLocal64, ubLocal32); + shmem_put_uint8_mem_nbi(gva + rank_size * message_length + 8, src_addr, sizeof(uint32_t), peer); + while (*(__gm__ uint32_t*)(gva + message_length * rank_size + 16) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + message_length * rank_size + 16, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + int64_t end = AscendC::GetSystemCycle(); + *(__gm__ int64_t*)(gva + message_length * rank_size) = end - start; + } else { + peer = 0; + while (*(__gm__ uint32_t*)(gva + rank_size * message_length + 8) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + rank_size * message_length + 8, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + shmem_put_uint8_mem_nbi(gva + message_length * rank_size + 16, src_addr, sizeof(uint32_t), peer); + } +} + +void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { + rdma_highlevel_put_bw<<<1, nullptr, stream>>>(fftsConfig, gva, message_length); +} + +extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length) { + shmemx_set_ffts_config(fftsConfig); + AscendC::LocalTensor ubLocal32; + ubLocal32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ubLocal32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); + ubLocal32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ubLocal64; + ubLocal64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ubLocal64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE); + ubLocal64.address_.dataLen = UB_ALIGN_SIZE; + + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + uint32_t peer; + + // Core 0, RDMA + if (AscendC::GetBlockIdx() == 0) { + GM_ADDR src_addr = gva + rank * message_length; + if (rank == 0) { + peer = 1; + int64_t start = AscendC::GetSystemCycle(); + for (int i = 0; i < 10000; i++) { + smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(src_addr, peer), peer, 0, message_length, ubLocal64, ubLocal32); + } + smem_shm_roce_quiet(peer, ubLocal64, ubLocal32); + smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + message_length * rank_size * 2 + 16, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + int64_t end = AscendC::GetSystemCycle(); + *(__gm__ int64_t*)(gva + message_length * rank_size * 2) = end - start; + } else { + peer = 0; + while (*(__gm__ int64_t*)(gva + rank_size * message_length * 2 + 8) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + rank_size * message_length * 2 + 8, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 16, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + } + } else { // core 1, MTE + GM_ADDR src_addr = gva + (rank + rank_size) * message_length; + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + /* CopyUB Config Set */ + uint64_t copy_ub = device_state->mte_config.shmem_ub; + uint32_t copy_ub_size = device_state->mte_config.ub_size; + AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; + if (rank == 0) { + peer = 1; + int64_t start = AscendC::GetSystemCycle(); + for (int i = 0; i < 10000; i++) { + shmem_mte_put_mem_nbi(src_addr, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, message_length, peer, copy_event_id); + } + AscendC::PipeBarrier(); + shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 24, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); + while (*(__gm__ uint32_t*)(gva + message_length * rank_size * 2 + 32) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + message_length * rank_size * 2 + 32, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + int64_t end = AscendC::GetSystemCycle(); + *(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 48) = end - start; + } else { + peer = 0; + while (*(__gm__ uint32_t*)(gva + rank_size * message_length * 2 + 24) != peer + MAGIC_VAL) { + cacheWriteThrough(gva + rank_size * message_length * 2 + 24, 8); + AscendC::GetSystemCycle(); + } + AscendC::PipeBarrier(); + shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 32, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); + } + } +} + +void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { + rdma_mte_put_bw<<<2, nullptr, stream>>>(fftsConfig, gva, message_length); +} \ No newline at end of file diff --git a/include/device/shmem_device_rma.h b/include/device/shmem_device_rma.h index c73b7339..c8805c5c 100644 --- a/include/device/shmem_device_rma.h +++ b/include/device/shmem_device_rma.h @@ -16,6 +16,10 @@ #include "shmem_device_team.h" #include "internal/device/sync/shmemi_device_p2p.h" #include "shmem_device_sync.h" +#include "host/shmem_host_def.h" + +constexpr uint64_t SHMEM_INTERNAL_UB_BUF_START_ADDR = 188 * 1024; +constexpr uint32_t UB_ALIGN_SIZE = 32; /** * @brief Standard RMA Types and Names @@ -241,17 +245,33 @@ SHMEM_DEVICE void shmem_getmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t */ \ SHMEM_DEVICE void shmem_get_##NAME##_mem_nbi(__gm__ TYPE *dst, __gm__ TYPE *src, uint32_t elem_size, int32_t pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ - /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ - /* CopyUB Config Set */ \ - uint64_t copy_ub = device_state->mte_config.shmem_ub; \ - uint32_t copy_ub_size = device_state->mte_config.ub_size; \ - AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ - shmem_mte_get_mem_nbi(dst, src, reinterpret_cast<__ubuf__ TYPE *>(copy_ub), copy_ub_size, elem_size, pe, \ - copy_event_id); \ + if (device_state->topo_list[pe] & SHMEM_TRANSPORT_MTE) { \ + /* MTE */ \ + /* CopyUB Config Set */ \ + uint64_t copy_ub = device_state->mte_config.shmem_ub; \ + uint32_t copy_ub_size = device_state->mte_config.ub_size; \ + AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ + shmem_mte_get_mem_nbi(dst, src, reinterpret_cast<__ubuf__ TYPE *>(copy_ub), copy_ub_size, elem_size, pe, \ + copy_event_id); \ + } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ + /* RoCE */ \ + auto ptr = shmem_ptr(src, pe); \ + if (ptr == nullptr) return; \ + /* Create LocalTensor */ \ + AscendC::LocalTensor ub_tensor_32; \ + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); \ + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; \ + AscendC::LocalTensor ub_tensor_64; \ + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + + UB_ALIGN_SIZE); \ + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ + smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(TYPE), \ + ub_tensor_64, ub_tensor_32); \ + } \ } SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_NBI); @@ -297,20 +317,36 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_DETAILED_NBI); SHMEM_DEVICE void shmem_get_##NAME##_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, \ uint32_t elem_size, int pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ - /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ - /* CopyUB Config Set */ \ - uint64_t copy_ub = device_state->mte_config.shmem_ub; \ - /* Create LocalTensor */ \ - AscendC::LocalTensor ub_tensor; \ - ub_tensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); \ - ub_tensor.address_.bufferAddr = reinterpret_cast(copy_ub); \ - ub_tensor.address_.dataLen = device_state->mte_config.ub_size; \ - AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ - shmem_mte_get_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ + if (device_state->topo_list[pe] & SHMEM_TRANSPORT_MTE) { \ + /* MTE */ \ + /* CopyUB Config Set */ \ + uint64_t copy_ub = device_state->mte_config.shmem_ub; \ + /* Create LocalTensor */ \ + AscendC::LocalTensor ub_tensor; \ + ub_tensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); \ + ub_tensor.address_.bufferAddr = reinterpret_cast(copy_ub); \ + ub_tensor.address_.dataLen = device_state->mte_config.ub_size; \ + AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ + shmem_mte_get_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ + } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ + /* RoCE */ \ + auto ptr = shmem_ptr((__gm__ void *)src.GetPhyAddr(), pe); \ + if (ptr == nullptr) return; \ + /* Create LocalTensor */ \ + AscendC::LocalTensor ub_tensor_32; \ + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); \ + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; \ + AscendC::LocalTensor ub_tensor_64; \ + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + + UB_ALIGN_SIZE); \ + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ + smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(dst.GetPhyAddr()), pe, 0, \ + elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \ + } \ } SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_NBI); @@ -357,17 +393,33 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_DETAILED_NBI); */ \ SHMEM_DEVICE void shmem_put_##NAME##_mem_nbi(__gm__ TYPE *dst, __gm__ TYPE *src, uint32_t elem_size, int32_t pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ - /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ - /* CopyUB Config Set */ \ - uint64_t copy_ub = device_state->mte_config.shmem_ub; \ - uint32_t copy_ub_size = device_state->mte_config.ub_size; \ - AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ - shmem_mte_put_mem_nbi(dst, src, reinterpret_cast<__ubuf__ TYPE *>(copy_ub), copy_ub_size, elem_size, pe, \ - copy_event_id); \ + if (device_state->topo_list[pe] & SHMEM_TRANSPORT_MTE) { \ + /* MTE */ \ + /* CopyUB Config Set */ \ + uint64_t copy_ub = device_state->mte_config.shmem_ub; \ + uint32_t copy_ub_size = device_state->mte_config.ub_size; \ + AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ + shmem_mte_put_mem_nbi(dst, src, reinterpret_cast<__ubuf__ TYPE *>(copy_ub), copy_ub_size, elem_size, pe, \ + copy_event_id); \ + } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ + /* RoCE */ \ + auto ptr = shmem_ptr(dst, pe); \ + if (ptr == nullptr) return; \ + /* Create LocalTensor */ \ + AscendC::LocalTensor ub_tensor_32; \ + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); \ + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; \ + AscendC::LocalTensor ub_tensor_64; \ + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + + UB_ALIGN_SIZE); \ + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ + smem_shm_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(TYPE), \ + ub_tensor_64, ub_tensor_32); \ + } \ } SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_NBI); @@ -412,20 +464,36 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_DETAILED_NBI); SHMEM_DEVICE void shmem_put_##NAME##_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, \ uint32_t elem_size, int pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ - /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ - /* CopyUB Config Set */ \ - uint64_t copy_ub = device_state->mte_config.shmem_ub; \ - /* Create LocalTensor */ \ - AscendC::LocalTensor ub_tensor; \ - ub_tensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); \ - ub_tensor.address_.bufferAddr = reinterpret_cast(copy_ub); \ - ub_tensor.address_.dataLen = device_state->mte_config.ub_size; \ - AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ - shmem_mte_put_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ + if (device_state->topo_list[pe] & SHMEM_TRANSPORT_MTE) { \ + /* MTE */ \ + /* CopyUB Config Set */ \ + uint64_t copy_ub = device_state->mte_config.shmem_ub; \ + /* Create LocalTensor */ \ + AscendC::LocalTensor ub_tensor; \ + ub_tensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); \ + ub_tensor.address_.bufferAddr = reinterpret_cast(copy_ub); \ + ub_tensor.address_.dataLen = device_state->mte_config.ub_size; \ + AscendC::TEventID copy_event_id = (AscendC::TEventID)device_state->mte_config.event_id; \ + shmem_mte_put_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ + } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ + /* RoCE */ \ + auto ptr = shmem_ptr((__gm__ void *)dst.GetPhyAddr(), pe); \ + if (ptr == nullptr) return; \ + /* Create LocalTensor */ \ + AscendC::LocalTensor ub_tensor_32; \ + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); \ + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; \ + AscendC::LocalTensor ub_tensor_64; \ + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); \ + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + + UB_ALIGN_SIZE); \ + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ + smem_shm_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, \ + elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \ + } \ } SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_TENSOR_NBI); @@ -473,8 +541,6 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_TENSOR_DETAILED_NBI); */ \ SHMEM_DEVICE void shmem_get_##NAME##_mem_nbi(__ubuf__ TYPE *dst, __gm__ TYPE *src, uint32_t elem_size, int pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ @@ -497,8 +563,6 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_UB_NBI); SHMEM_DEVICE void shmem_get_##NAME##_mem_nbi(AscendC::LocalTensor dst, AscendC::GlobalTensor src, \ uint32_t elem_size, int pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ @@ -521,8 +585,6 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_UB_TENSOR_NBI); SHMEM_DEVICE void shmem_get_##NAME##_mem_nbi(__ubuf__ TYPE *dst, __gm__ TYPE *src, \ const non_contiguous_copy_param ©_params, int pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ @@ -545,8 +607,6 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_UB_DETAILED_NBI); SHMEM_DEVICE void shmem_get_##NAME##_mem_nbi(AscendC::LocalTensor dst, AscendC::GlobalTensor src, \ const non_contiguous_copy_param ©_params, int pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ @@ -566,8 +626,6 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_UB_TENSOR_DETAILED_NBI); */ SHMEM_DEVICE void shmem_putmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t elem_size, int32_t pe) { - /* ROCE */ - /* RDMA */ /* MTE */ /* Global State Get */ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); @@ -590,8 +648,6 @@ SHMEM_DEVICE void shmem_putmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t */ \ SHMEM_DEVICE void shmem_put_##NAME##_mem_nbi(__gm__ TYPE *dst, __ubuf__ TYPE *src, uint32_t elem_size, int32_t pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ @@ -613,8 +669,6 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_UB_NBI); SHMEM_DEVICE void shmem_put_##NAME##_mem_nbi(AscendC::GlobalTensor dst, AscendC::LocalTensor src, \ uint32_t elem_size, int32_t pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ @@ -637,8 +691,6 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_UB_TENSOR_NBI); SHMEM_DEVICE void shmem_put_##NAME##_mem_nbi(__gm__ TYPE *dst, __ubuf__ TYPE *src, \ const non_contiguous_copy_param ©_params, int32_t pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ @@ -661,8 +713,6 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_UB_DETAILED_NBI); SHMEM_DEVICE void shmem_put_##NAME##_mem_nbi(AscendC::GlobalTensor dst, AscendC::LocalTensor src, \ const non_contiguous_copy_param ©_params, int32_t pe) \ { \ - /* ROCE */ \ - /* RDMA */ \ /* MTE */ \ /* Global State Get */ \ __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); \ diff --git a/include/host/shmem_host_def.h b/include/host/shmem_host_def.h index 93808665..837ddd39 100644 --- a/include/host/shmem_host_def.h +++ b/include/host/shmem_host_def.h @@ -93,6 +93,14 @@ enum shmem_init_status_t { SHMEM_STATUS_INVALID = INT_MAX, ///< Invalid status code. }; +/** + * @brief Different transports supported by SHMEM library. +*/ +enum shmem_transport_t : uint8_t { + SHMEM_TRANSPORT_MTE = 1 << 0, ///< MTE Transport. + SHMEM_TRANSPORT_ROCE = 1 << 1, ///< RDMA Transport (RoCE). +}; + /**@} */ // end of group_enums /** diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index e07bb880..b17f47f9 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -73,7 +73,7 @@ typedef struct { void *heap_base; void *p2p_heap_base[SHMEM_MAX_RANKS]; void *sdma_heap_base[SHMEM_MAX_RANKS]; - void *roce_heap_base[SHMEM_MAX_RANKS]; + uint8_t topo_list[SHMEM_MAX_RANKS]; size_t heap_size; shmemi_team_t *team_pools[SHMEM_MAX_TEAMS]; diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index fe15a604..d0d73e5f 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -38,7 +38,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; NULL, /* heap_base */ \ {NULL}, /* p2p_heap_base */ \ {NULL}, /* sdma_heap_base */ \ - {NULL}, /* roce_heap_base */ \ + {}, /* topo_list */ \ SIZE_MAX, /* heap_size */ \ {NULL}, /* team_pools */ \ 0, /* sync_pool */ \ @@ -124,10 +124,9 @@ int32_t shmemi_heap_init(shmem_init_attr_t *attributes) uint32_t reach_info = 0; for (int32_t i = 0; i < g_state.npes; i++) { status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); + g_state.p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); if (reach_info & SMEMS_DATA_OP_MTE) { - g_state.p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); - } else { - g_state.p2p_heap_base[i] = NULL; + g_state.topo_list[i] |= SHMEM_TRANSPORT_MTE; } if (reach_info & SMEMS_DATA_OP_SDMA) { g_state.sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); @@ -135,9 +134,7 @@ int32_t shmemi_heap_init(shmem_init_attr_t *attributes) g_state.sdma_heap_base[i] = NULL; } if (reach_info & SMEMS_DATA_OP_RDMA) { - g_state.roce_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); - } else { - g_state.roce_heap_base[i] = NULL; + g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; } } if (shm::g_ipport != nullptr) { diff --git a/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp new file mode 100644 index 00000000..a4fa00c7 --- /dev/null +++ b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp @@ -0,0 +1,138 @@ +#include "kernel_operator.h" + +#include "shmem_api.h" +constexpr uint64_t MESSAGE_SIZE = 64; + +extern "C" __global__ __aicore__ void RDMAPollCQTest(GM_ADDR gva, uint64_t config) +{ + shmemx_set_ffts_config(config); + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + GM_ADDR src_addr; + + for (int64_t peer = 0; peer < rank_size; peer++) { + if (peer == rank) { + continue; + } + src_addr = gva + rank * MESSAGE_SIZE; + smem_shm_roce_pollcq_test(src_addr, (GM_ADDR)(shmem_ptr(src_addr, peer)), peer, 0, MESSAGE_SIZE, ubLocal64, ubLocal32, gva + 2048); + } +} + +void test_rdma_poll_cq_do(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config) +{ + RDMAPollCQTest<<>>(gva, config); +} + +extern "C" __global__ __aicore__ void RDMAGetTestLowLevel(GM_ADDR gva, uint64_t config) +{ + shmemx_set_ffts_config(config); + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + GM_ADDR dest_addr; + + for (int64_t peer = 0; peer < rank_size; peer++) { + if (peer == rank) { + continue; + } + dest_addr = gva + peer * MESSAGE_SIZE; + smem_shm_roce_read((GM_ADDR)(shmem_ptr(dest_addr, peer)), dest_addr, peer, 0, MESSAGE_SIZE, ubLocal64, ubLocal32); + } +} + +void test_rdma_get_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config) +{ + RDMAGetTestLowLevel<<>>(gva, config); +} + +extern "C" __global__ __aicore__ void RDMAPutTestLowLevel(GM_ADDR gva, uint64_t config) +{ + shmemx_set_ffts_config(config); + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + GM_ADDR src_addr; + + for (int64_t peer = 0; peer < rank_size; peer++) { + if (peer == rank) { + continue; + } + src_addr = gva + rank * MESSAGE_SIZE; + smem_shm_roce_write(src_addr, (GM_ADDR)(shmem_ptr(src_addr, peer)), peer, 0, MESSAGE_SIZE, ubLocal64, ubLocal32); + } +} + +void test_rdma_put_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config) +{ + RDMAPutTestLowLevel<<>>(gva, config); +} + +extern "C" __global__ __aicore__ void RDMAGetTestHighLevel(GM_ADDR gva, uint64_t config) +{ + shmemx_set_ffts_config(config); + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + GM_ADDR dest_addr; + + for (int64_t peer = 0; peer < rank_size; peer++) { + if (peer == rank) { + continue; + } + dest_addr = gva + peer * MESSAGE_SIZE; + shmem_get_uint8_mem_nbi(dest_addr, dest_addr, MESSAGE_SIZE, peer); + } +} + +void test_rdma_get_high_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config) +{ + RDMAGetTestHighLevel<<>>(gva, config); +} + +extern "C" __global__ __aicore__ void RDMAPutTestHighLevel(GM_ADDR gva, uint64_t config) +{ + shmemx_set_ffts_config(config); + int64_t rank = smem_shm_get_global_rank(); + int64_t rank_size = smem_shm_get_global_rank_size(); + GM_ADDR src_addr; + + for (int64_t peer = 0; peer < rank_size; peer++) { + if (peer == rank) { + continue; + } + src_addr = gva + rank * MESSAGE_SIZE; + shmem_put_uint8_mem_nbi(src_addr, src_addr, MESSAGE_SIZE, peer); + } +} + +void test_rdma_put_high_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config) +{ + RDMAPutTestHighLevel<<>>(gva, config); +} + +extern "C" __global__ __aicore__ void shmem_rdma_get_qpinfo_test(GM_ADDR gva, uint32_t rankId, uint64_t config) +{ + shmemx_set_ffts_config(config); + smem_shm_roce_qpinfo_test(gva, rankId, 0); +} + +void shmem_rdma_get_qpinfo_test_do(void* stream, uint8_t* gva, uint32_t rankId, uint64_t config) +{ + shmem_rdma_get_qpinfo_test<<<1, nullptr, stream>>>(gva, rankId, config); +} \ No newline at end of file diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index a440b281..aed20fe2 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -38,6 +38,7 @@ void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *s shmem_init_attr_t* attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; status = shmem_init_attr(attributes); EXPECT_EQ(status, 0); *st = stream; diff --git a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp new file mode 100644 index 00000000..1d06a5df --- /dev/null +++ b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp @@ -0,0 +1,120 @@ +#include +#include +#include +#include + +#include "acl/acl.h" +#include "shmemi_host_common.h" + +extern int test_gnpu_num; +extern int test_first_npu; +extern void test_mutil_task(std::function func, uint64_t local_mem_size, int processCount); +extern void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *st); +extern void test_finalize(aclrtStream stream, int device_id); + +extern void test_rdma_put_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); +extern void test_rdma_get_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); +extern void test_rdma_put_high_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); +extern void test_rdma_get_high_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); +extern void shmem_rdma_get_qpinfo_test_do(void* stream, uint8_t* gva, uint32_t rankId, uint64_t config); +extern void test_rdma_poll_cq_do(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); + +static void test_rdma_poll_cq(aclrtStream stream, uint8_t *gva, uint32_t rank_id, uint64_t heap_size) +{ + size_t messageSize = 128; + uint64_t *xHost; + size_t totalSize = 120; + + ASSERT_EQ(aclrtMallocHost((void **)(&xHost), totalSize), 0); + for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { + xHost[i] = rank_id + 10; + } + ASSERT_EQ(aclrtMemcpy(gva + (rank_id + 1) * messageSize, messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); + + uint32_t block_dim = 1; + test_rdma_poll_cq_do(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + ASSERT_EQ(aclrtSynchronizeStream(stream), 0); + sleep(2); + + std::string p_name = "[Process " + std::to_string(rank_id) + "] "; + std::cout << p_name; + ASSERT_EQ(aclrtMemcpy(xHost, totalSize, gva + 2048, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); + for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { + printf("PollCQ index = %d, value = %lu\n", i, xHost[i]); + } +} + +static void test_rdma_put_get(aclrtStream stream, uint8_t *gva, uint32_t rank_id, uint32_t rank_size) +{ + size_t messageSize = 64; + uint32_t *xHost; + size_t totalSize = messageSize * rank_size; + + ASSERT_EQ(aclrtMallocHost((void **)(&xHost), totalSize), 0); + for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { + xHost[i] = rank_id + 10; + } + ASSERT_EQ(aclrtMemcpy(gva + rank_id * messageSize, messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); + + uint32_t block_dim = 1; + // test_rdma_put_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + // test_rdma_get_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + test_rdma_put_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + // test_rdma_get_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + ASSERT_EQ(aclrtSynchronizeStream(stream), 0); + sleep(2); + + std::string p_name = "[Process " + std::to_string(rank_id) + "] "; + std::cout << p_name; + ASSERT_EQ(aclrtMemcpy(xHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); + for (uint32_t i = 0; i < rank_size; i++) { + ASSERT_EQ(xHost[i * messageSize / sizeof(uint32_t)], i + 10); + } +} + +static void test_rdma_get_info(aclrtStream stream, uint8_t *gva, uint32_t rankId, uint32_t rankSize) { + uint64_t *xHost; + size_t totalSize = 120; + ASSERT_EQ(aclrtMallocHost((void **)(&xHost), totalSize), 0); + memset(xHost, 0xEE, totalSize); + ASSERT_EQ(aclrtMemcpy(gva, totalSize, xHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); + for (uint32_t curRank = 0; curRank < rankSize; curRank++) { + if (curRank == rankId) { + continue; + } + shmem_rdma_get_qpinfo_test_do(stream, gva, curRank, shmemx_get_ffts_config()); + ASSERT_EQ(aclrtSynchronizeStream(stream), 0); + sleep(1); + + ASSERT_EQ(aclrtMemcpy(xHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); + for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { + printf("GetQPInfo srcRank = %d, destRank = %d, index = %d, value = %lu\n", rankId, curRank, i, xHost[i]); + } + } + + ASSERT_EQ(aclrtFreeHost(xHost), 0); +} + +void test_shmem_rdma_mem(int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + ASSERT_NE(stream, nullptr); + + void* ptr = shmem_malloc(1024); + // test_rdma_poll_cq(stream, (uint8_t *)ptr, rank_id, n_ranks); + test_rdma_put_get(stream, (uint8_t *)ptr, rank_id, n_ranks); + // test_rdma_get_info(stream, (uint8_t *)ptr, rank_id, n_ranks); + std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; + test_finalize(stream, device_id); + if (::testing::Test::HasFailure()){ + exit(1); + } +} + +TEST(TestMemApi, TestShmemRDMAMem) +{ + const int processCount = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 64; + test_mutil_task(test_shmem_rdma_mem, local_mem_size, processCount); +} \ No newline at end of file -- Gitee From 163135213ab77f967ddef6820b0865a9a2dd5c68 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 2 Sep 2025 20:35:11 +0800 Subject: [PATCH 15/74] Optimize bandwidth test with warmup. --- examples/CMakeLists.txt | 1 + examples/rdma_perftest/README.md | 2 +- examples/rdma_perftest/main.cpp | 86 +++++++++++-------- .../rdma_perftest/rdma_perftest_kernel.cpp | 40 ++------- 4 files changed, 63 insertions(+), 66 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3ffdb7bd..94920765 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -43,6 +43,7 @@ function(shmem_add_collective_example NAME) ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} ${PROJECT_SOURCE_DIR}/examples/utils + ${PROJECT_SOURCE_DIR}/src/host ) target_link_libraries(${NAME} PRIVATE shmem ${NAME}_kernel ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) endfunction() diff --git a/examples/rdma_perftest/README.md b/examples/rdma_perftest/README.md index 1b1d37fc..1c73d65e 100644 --- a/examples/rdma_perftest/README.md +++ b/examples/rdma_perftest/README.md @@ -16,7 +16,7 @@ export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfab - n_ranks: 全局Rank数量,只支持2个Rank。 - rank_id: 当前进程的Rank号。 -- ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。 +- ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。如果执行跨机测试,需要讲IP设为rank0所在Host的IP。 - g_npus: 当前卡上启动的NPU数量。 - f_rank: 当前卡上使用的第一个Rank号。 - f_npu: 当前卡上使用的第一个NPU卡号。 diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp index 36796019..62c39bea 100644 --- a/examples/rdma_perftest/main.cpp +++ b/examples/rdma_perftest/main.cpp @@ -18,6 +18,7 @@ #include "acl/acl.h" #include "shmem_api.h" +#include "shmemi_host_common.h" int g_npus = 8; const char *ipport; @@ -28,7 +29,7 @@ const char *test_type; extern void rdma_highlevel_put_pingpong_latency_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); extern void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); extern void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); -extern void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length); +extern void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length, int64_t iter); int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length) { @@ -69,10 +70,11 @@ int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uin std::cout << "RDMA highlevel put pingpong latency test. Message length = " << message_length << " Byte; latency = " << xHost[0] / 50.0 << " us." << std::endl; } - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); + aclrtFreeHost(xHost); + shmem_finalize(); + aclrtDestroyStream(stream); + aclrtResetDevice(device_id); + aclFinalize(); return 0; } @@ -110,14 +112,15 @@ int test_shmem_rdma_postsend_cost(int rank_id, int n_ranks, uint64_t local_mem_s } aclrtSynchronizeStream(stream); if (rank_id == 0) { - aclrtMemcpy(xHost, sizeof(uint32_t), gva + message_length * n_ranks, sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST); + aclrtMemcpy(xHost, sizeof(int64_t), gva + message_length * n_ranks, sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST); std::cout << "RDMA postsend cost test. Message length = " << message_length << " Byte; postsend cost = " << xHost[0] / (50.0 * 500) << " us." << std::endl; } - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); + aclrtFreeHost(xHost); + shmem_finalize(); + aclrtDestroyStream(stream); + aclrtResetDevice(device_id); + aclFinalize(); return 0; } @@ -156,10 +159,11 @@ int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_me std::cout << "RDMA high level put bandwidth test. Message length = " << message_length << " Byte; time = " << xHost[0] / (50.0) << " us." << std::endl; } - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); + aclrtFreeHost(xHost); + shmem_finalize(); + aclrtDestroyStream(stream); + aclrtResetDevice(device_id); + aclFinalize(); return 0; } @@ -182,29 +186,43 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 32); - - int64_t *xHost; - size_t totalSize = message_length * n_ranks; - - aclrtMallocHost((void **)(&xHost), totalSize); - for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) { - xHost[i] = rank_id + 10; + int64_t *inHost; + int64_t *outHost; + size_t totalSize = message_length * n_ranks * 3; + aclrtMallocHost((void **)(&inHost), totalSize); + aclrtMallocHost((void **)(&outHost), totalSize); + memset(inHost, 0, totalSize); + double rdmaTotalTime = 0.0; + double mteTotalTime = 0.0; + + for (int iter = 0; iter < 20; iter++) { + for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) { + inHost[i + rank_id * message_length / sizeof(int64_t)] = rank_id + 10 + iter; + } + for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) { + inHost[i + (rank_id + n_ranks) * message_length / sizeof(int64_t)] = rank_id + 10 + iter; + } + aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE); + shm::shmemi_control_barrier_all(); + rdma_mte_put_bw_do(1, stream, fftsConfig, gva, message_length, iter); + aclrtSynchronizeStream(stream); + if (rank_id == 0 && iter >= 10) { + aclrtMemcpy(outHost, 64, gva + message_length * n_ranks * 2, 64, ACL_MEMCPY_DEVICE_TO_HOST); + rdmaTotalTime += outHost[0] / 50.0; + mteTotalTime += outHost[6] / 50.0; + } } - aclrtMemcpy(gva + rank_id * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(gva + (rank_id + n_ranks) * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE); - - rdma_mte_put_bw_do(1, stream, fftsConfig, gva, message_length); - aclrtSynchronizeStream(stream); if (rank_id == 0) { - aclrtMemcpy(xHost, 64, gva + message_length * n_ranks * 2, 64, ACL_MEMCPY_DEVICE_TO_HOST); - std::cout << "RDMA rdma mte test. Message length = " << message_length << " Byte; RDMA time = " << xHost[0] / (50.0) << " us." << std::endl; - std::cout << "RDMA rdma mte test. Message length = " << message_length << " Byte; MTE time = " << xHost[6] / (50.0) << " us." << std::endl; + std::cout << "RDMA rdma mte test. Message length = " << message_length << " Byte; average RDMA time = " << rdmaTotalTime / 10.0 << " us." << std::endl; + std::cout << "RDMA rdma mte test. Message length = " << message_length << " Byte; average MTE time = " << mteTotalTime / 10.0 << " us." << std::endl; } - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); + aclrtFreeHost(inHost); + aclrtFreeHost(outHost); + shmem_finalize(); + aclrtDestroyStream(stream); + aclrtResetDevice(device_id); + aclFinalize(); return 0; } @@ -236,7 +254,7 @@ int main(int argc, char *argv[]) test_shmem_rdma_postsend_cost(rank_id, n_ranks, local_mem_size, msg_len); } else if (std::string(test_type) == "highlevel_put_bw") { test_shmem_rdma_highlevel_put_bw(rank_id, n_ranks, local_mem_size, msg_len); - }else if (std::string(test_type) == "rdma_mte_bw") { + } else if (std::string(test_type) == "rdma_mte_bw") { test_shmem_rdma_mte_put_bw(rank_id, n_ranks, local_mem_size, msg_len); } diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp index 862a96a8..acca5f98 100644 --- a/examples/rdma_perftest/rdma_perftest_kernel.cpp +++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp @@ -14,28 +14,6 @@ constexpr uint32_t MAGIC_VAL = 10; constexpr uint32_t WARMUP_MESSAGE_LENGTH = 32; -/** - * @brief RDMA Quiet function. This synchronous function ensures all previous RDMA WQEs are completed (data has arrived at the destination NIC). - * - * @param remoteRankId [in] destination rank ID - * @param ubLocal64 [in] temporary UB local tensor of uint64_t used as workspace - * @param ubLocal32 [in] temporary UB local tensor of uint32_t used as workspace - */ - -SHMEM_DEVICE void smem_shm_roce_quiet(uint32_t remoteRankId, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ HybmDeviceMeta* metaPtr = (__gm__ HybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ AIVRDMAInfo* RDMAInfo = (__gm__ AIVRDMAInfo*)(metaPtr->qpInfoAddress); - uint32_t qpNum = RDMAInfo->qpNum; - for (uint32_t qpIdx = 0; qpIdx < qpNum; qpIdx++) { - __gm__ WQCtx* qpCtxEntry = (__gm__ WQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(WQCtx)); - auto curHardwareHeadAddr = qpCtxEntry->headAddr; - cacheWriteThrough((__gm__ uint8_t*)curHardwareHeadAddr, 8); - uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); - smem_shm_roce_poll_cq(remoteRankId, qpIdx, curHead - 1, ubLocal64, ubLocal32); - } -} - extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64_t fftsConfig, GM_ADDR gva, int message_length) { shmemx_set_ffts_config(fftsConfig); if (AscendC::GetSubBlockIdx() != 0) { @@ -156,7 +134,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, for (int i = 0; i < 10000; i++) { shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); } - smem_shm_roce_quiet(peer, ubLocal64, ubLocal32); + smem_shm_roce_quiet(peer, 0, ubLocal64, ubLocal32); shmem_put_uint8_mem_nbi(gva + rank_size * message_length + 8, src_addr, sizeof(uint32_t), peer); while (*(__gm__ uint32_t*)(gva + message_length * rank_size + 16) != peer + MAGIC_VAL) { cacheWriteThrough(gva + message_length * rank_size + 16, 8); @@ -180,7 +158,7 @@ void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsCon rdma_highlevel_put_bw<<<1, nullptr, stream>>>(fftsConfig, gva, message_length); } -extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length) { +extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length, int64_t iter) { shmemx_set_ffts_config(fftsConfig); AscendC::LocalTensor ubLocal32; ubLocal32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); @@ -204,9 +182,9 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD for (int i = 0; i < 10000; i++) { smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(src_addr, peer), peer, 0, message_length, ubLocal64, ubLocal32); } - smem_shm_roce_quiet(peer, ubLocal64, ubLocal32); + smem_shm_roce_quiet(peer, 0, ubLocal64, ubLocal32); smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); - while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL) { + while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL + iter) { cacheWriteThrough(gva + message_length * rank_size * 2 + 16, 8); AscendC::GetSystemCycle(); } @@ -215,7 +193,7 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD *(__gm__ int64_t*)(gva + message_length * rank_size * 2) = end - start; } else { peer = 0; - while (*(__gm__ int64_t*)(gva + rank_size * message_length * 2 + 8) != peer + MAGIC_VAL) { + while (*(__gm__ int64_t*)(gva + rank_size * message_length * 2 + 8) != peer + MAGIC_VAL + iter) { cacheWriteThrough(gva + rank_size * message_length * 2 + 8, 8); AscendC::GetSystemCycle(); } @@ -237,7 +215,7 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD } AscendC::PipeBarrier(); shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 24, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); - while (*(__gm__ uint32_t*)(gva + message_length * rank_size * 2 + 32) != peer + MAGIC_VAL) { + while (*(__gm__ uint32_t*)(gva + message_length * rank_size * 2 + 32) != peer + MAGIC_VAL + iter) { cacheWriteThrough(gva + message_length * rank_size * 2 + 32, 8); AscendC::GetSystemCycle(); } @@ -246,7 +224,7 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD *(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 48) = end - start; } else { peer = 0; - while (*(__gm__ uint32_t*)(gva + rank_size * message_length * 2 + 24) != peer + MAGIC_VAL) { + while (*(__gm__ uint32_t*)(gva + rank_size * message_length * 2 + 24) != peer + MAGIC_VAL + iter) { cacheWriteThrough(gva + rank_size * message_length * 2 + 24, 8); AscendC::GetSystemCycle(); } @@ -256,6 +234,6 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD } } -void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { - rdma_mte_put_bw<<<2, nullptr, stream>>>(fftsConfig, gva, message_length); +void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length, int64_t iter) { + rdma_mte_put_bw<<<2, nullptr, stream>>>(fftsConfig, gva, message_length, iter); } \ No newline at end of file -- Gitee From 7d5ea6112c71272285e670b2311b584bf05f0bfc Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 2 Sep 2025 20:42:25 +0800 Subject: [PATCH 16/74] Add low level RDMA API. --- .../low_level/shmem_device_low_level_rma.h | 100 +++++++++++++++++ include/device/shmem_device_rma.h | 3 - .../device/mem/rdma_mem/rdma_mem_kernel.cpp | 48 +-------- .../host/mem/rdma_mem/rdma_mem_host_test.cpp | 101 +++++++----------- 4 files changed, 144 insertions(+), 108 deletions(-) diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index 79148f65..c43cb668 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -14,6 +14,9 @@ #include "internal/device/shmemi_device_common.h" #include "device/shmem_device_team.h" +constexpr uint64_t SHMEM_INTERNAL_UB_BUF_START_ADDR = 188 * 1024; +constexpr uint32_t UB_ALIGN_SIZE = 32; + /** * @brief Translate an local symmetric address to remote symmetric address on the specified PE. * @@ -78,6 +81,31 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T } } +/** + * @brief Asynchronous interface. Copy contiguous data on symmetric memory from the specified PE to address on the local device. + * + * @param dst [in] Pointer on local device of the destination data. + * @param src [in] Pointer on Symmetric memory of the source data. + * @param buf [in] Pointer on local UB, available space larger than 64 Bytes. + * @param elem_size [in] Number of elements in the destination and source arrays. + * @param pe [in] PE number of the remote PE. + */ +template +SHMEM_DEVICE void shmem_roce_get_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ T* buf, uint32_t elem_size, int pe) +{ + auto ptr = shmem_ptr(src, pe); + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf) + UB_ALIGN_SIZE; + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); +} + + /** * @brief Asynchronous interface. Provide a high-performance way to copy non-contiguous data * on symmetric memory from the specified PE to address on the local device. @@ -167,6 +195,30 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G } } +/** + * @brief Asynchronous interface. Copy contiguous data on symmetric memory from the specified PE to address on the local PE. + * + * @param dst [in] GlobalTensor on local device of the destination data. + * @param src [in] GlobalTensor on Symmetric memory of the source data. + * @param buf [in] LocalTensor on local UB, available space larger than 64 Bytes. + * @param elem_size [in] Number of elements in the destination and source arrays. + * @param pe [in] PE number of the remote PE. + */ +template +SHMEM_DEVICE void shmem_roce_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, AscendC::LocalTensor buf, uint32_t elem_size, int pe) +{ + auto ptr = shmem_ptr((__gm__ void *)src.GetPhyAddr(), pe); + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()) + UB_ALIGN_SIZE; + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst.GetPhyAddr(), pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); +} + /** * @brief Asynchronous interface. Provide a high-performance way to copy non-contiguous data * on symmetric memory from the specified PE to address on the local device. @@ -247,6 +299,30 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T } } +/** + * @brief Asynchronous interface. Copy contiguous data on local PE to symmetric address on the specified PE. + * + * @param dst [in] Pointer on Symmetric memory of the destination data. + * @param src [in] Pointer on local device of the source data. + * @param buf [in] Pointer on local UB, available space larger than 64 Bytes. + * @param elem_size [in] Number of elements in the destination and source arrays. + * @param pe [in] PE number of the remote PE. + */ +template +SHMEM_DEVICE void shmem_roce_put_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ T* buf, uint32_t elem_size, int pe) +{ + auto ptr = shmem_ptr(dst, pe); + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf) + UB_ALIGN_SIZE; + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + smem_shm_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); +} + /** * @brief Asynchronous interface. Provide a high-performance way to copy non-contiguous data * on local PE to symmetric address on the specified PE. @@ -336,6 +412,30 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G } } +/** + * @brief Asynchronous interface. Copy contiguous data on local PE to symmetric address on the specified PE. + * + * @param dst [in] GlobalTensor on Symmetric memory of the destination data. + * @param src [in] GlobalTensor on local device of the source data. + * @param buf [in] Pointer on local UB, available space larger than 64 Bytes. + * @param elem_size [in] Number of elements in the destination and source arrays. + * @param pe [in] PE number of the remote PE. + */ +template +SHMEM_DEVICE void shmem_roce_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, AscendC::LocalTensor buf, uint32_t elem_size, int pe, AscendC::TEventID EVENT_ID) +{ + auto ptr = shmem_ptr((__gm__ void *)dst.GetPhyAddr(), pe); + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()) + UB_ALIGN_SIZE; + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + smem_shm_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); +} + /** * @brief Asynchronous interface. Provide a high-performance way to copy non-contiguous data * on local PE to symmetric address on the specified PE. diff --git a/include/device/shmem_device_rma.h b/include/device/shmem_device_rma.h index c8805c5c..2c847f79 100644 --- a/include/device/shmem_device_rma.h +++ b/include/device/shmem_device_rma.h @@ -18,9 +18,6 @@ #include "shmem_device_sync.h" #include "host/shmem_host_def.h" -constexpr uint64_t SHMEM_INTERNAL_UB_BUF_START_ADDR = 188 * 1024; -constexpr uint32_t UB_ALIGN_SIZE = 32; - /** * @brief Standard RMA Types and Names * diff --git a/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp index a4fa00c7..64ab741c 100644 --- a/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp +++ b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp @@ -3,41 +3,13 @@ #include "shmem_api.h" constexpr uint64_t MESSAGE_SIZE = 64; -extern "C" __global__ __aicore__ void RDMAPollCQTest(GM_ADDR gva, uint64_t config) -{ - shmemx_set_ffts_config(config); - AscendC::TPipe pipe; - AscendC::TBuf buf; - pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); - AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); - GM_ADDR src_addr; - - for (int64_t peer = 0; peer < rank_size; peer++) { - if (peer == rank) { - continue; - } - src_addr = gva + rank * MESSAGE_SIZE; - smem_shm_roce_pollcq_test(src_addr, (GM_ADDR)(shmem_ptr(src_addr, peer)), peer, 0, MESSAGE_SIZE, ubLocal64, ubLocal32, gva + 2048); - } -} - -void test_rdma_poll_cq_do(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config) -{ - RDMAPollCQTest<<>>(gva, config); -} - extern "C" __global__ __aicore__ void RDMAGetTestLowLevel(GM_ADDR gva, uint64_t config) { shmemx_set_ffts_config(config); AscendC::TPipe pipe; AscendC::TBuf buf; pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); - AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + AscendC::LocalTensor ubLocal = buf.GetWithOffset(UB_ALIGN_SIZE * 2, 0); int64_t rank = smem_shm_get_global_rank(); int64_t rank_size = smem_shm_get_global_rank_size(); @@ -48,7 +20,7 @@ extern "C" __global__ __aicore__ void RDMAGetTestLowLevel(GM_ADDR gva, uint64_t continue; } dest_addr = gva + peer * MESSAGE_SIZE; - smem_shm_roce_read((GM_ADDR)(shmem_ptr(dest_addr, peer)), dest_addr, peer, 0, MESSAGE_SIZE, ubLocal64, ubLocal32); + shmem_roce_get_mem_nbi(dest_addr, dest_addr, (__ubuf__ uint8_t*)ubLocal.GetPhyAddr(), MESSAGE_SIZE, peer); } } @@ -63,8 +35,7 @@ extern "C" __global__ __aicore__ void RDMAPutTestLowLevel(GM_ADDR gva, uint64_t AscendC::TPipe pipe; AscendC::TBuf buf; pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); - AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + AscendC::LocalTensor ubLocal = buf.GetWithOffset(UB_ALIGN_SIZE * 2, 0); int64_t rank = smem_shm_get_global_rank(); int64_t rank_size = smem_shm_get_global_rank_size(); @@ -75,7 +46,7 @@ extern "C" __global__ __aicore__ void RDMAPutTestLowLevel(GM_ADDR gva, uint64_t continue; } src_addr = gva + rank * MESSAGE_SIZE; - smem_shm_roce_write(src_addr, (GM_ADDR)(shmem_ptr(src_addr, peer)), peer, 0, MESSAGE_SIZE, ubLocal64, ubLocal32); + shmem_roce_put_mem_nbi(src_addr, src_addr, (__ubuf__ uint8_t*)ubLocal.GetPhyAddr(), MESSAGE_SIZE, peer); } } @@ -124,15 +95,4 @@ extern "C" __global__ __aicore__ void RDMAPutTestHighLevel(GM_ADDR gva, uint64_t void test_rdma_put_high_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config) { RDMAPutTestHighLevel<<>>(gva, config); -} - -extern "C" __global__ __aicore__ void shmem_rdma_get_qpinfo_test(GM_ADDR gva, uint32_t rankId, uint64_t config) -{ - shmemx_set_ffts_config(config); - smem_shm_roce_qpinfo_test(gva, rankId, 0); -} - -void shmem_rdma_get_qpinfo_test_do(void* stream, uint8_t* gva, uint32_t rankId, uint64_t config) -{ - shmem_rdma_get_qpinfo_test<<<1, nullptr, stream>>>(gva, rankId, config); } \ No newline at end of file diff --git a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp index 1d06a5df..3cec2bf9 100644 --- a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp +++ b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp @@ -16,83 +16,64 @@ extern void test_rdma_put_low_level(uint32_t block_dim, void* stream, uint8_t* g extern void test_rdma_get_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); extern void test_rdma_put_high_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); extern void test_rdma_get_high_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); -extern void shmem_rdma_get_qpinfo_test_do(void* stream, uint8_t* gva, uint32_t rankId, uint64_t config); -extern void test_rdma_poll_cq_do(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); -static void test_rdma_poll_cq(aclrtStream stream, uint8_t *gva, uint32_t rank_id, uint64_t heap_size) +static void test_rdma_put_get(aclrtStream stream, uint8_t *gva, uint32_t rank_id, uint32_t rank_size) { - size_t messageSize = 128; - uint64_t *xHost; - size_t totalSize = 120; + size_t messageSize = 64; + uint32_t *inHost; + uint32_t *outHost; + size_t totalSize = messageSize * rank_size; + uint32_t block_dim = 1; - ASSERT_EQ(aclrtMallocHost((void **)(&xHost), totalSize), 0); + ASSERT_EQ(aclrtMallocHost((void **)(&inHost), totalSize), 0); + ASSERT_EQ(aclrtMallocHost((void **)(&outHost), totalSize), 0); + memset(inHost, 0, totalSize); for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { - xHost[i] = rank_id + 10; + inHost[i + rank_id * messageSize / sizeof(uint32_t)] = rank_id + 10; } - ASSERT_EQ(aclrtMemcpy(gva + (rank_id + 1) * messageSize, messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - uint32_t block_dim = 1; - test_rdma_poll_cq_do(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); + shm::shmemi_control_barrier_all(); + test_rdma_put_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - sleep(2); - - std::string p_name = "[Process " + std::to_string(rank_id) + "] "; - std::cout << p_name; - ASSERT_EQ(aclrtMemcpy(xHost, totalSize, gva + 2048, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); - for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { - printf("PollCQ index = %d, value = %lu\n", i, xHost[i]); + shm::shmemi_control_barrier_all(); + ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); + for (uint32_t i = 0; i < rank_size; i++) { + ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + 10); } -} -static void test_rdma_put_get(aclrtStream stream, uint8_t *gva, uint32_t rank_id, uint32_t rank_size) -{ - size_t messageSize = 64; - uint32_t *xHost; - size_t totalSize = messageSize * rank_size; - - ASSERT_EQ(aclrtMallocHost((void **)(&xHost), totalSize), 0); - for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { - xHost[i] = rank_id + 10; + ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); + shm::shmemi_control_barrier_all(); + test_rdma_get_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + ASSERT_EQ(aclrtSynchronizeStream(stream), 0); + shm::shmemi_control_barrier_all(); + ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); + for (uint32_t i = 0; i < rank_size; i++) { + ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + 10); } - ASSERT_EQ(aclrtMemcpy(gva + rank_id * messageSize, messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - uint32_t block_dim = 1; - // test_rdma_put_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); - // test_rdma_get_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); + shm::shmemi_control_barrier_all(); test_rdma_put_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); - // test_rdma_get_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - sleep(2); - - std::string p_name = "[Process " + std::to_string(rank_id) + "] "; - std::cout << p_name; - ASSERT_EQ(aclrtMemcpy(xHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); + shm::shmemi_control_barrier_all(); + ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { - ASSERT_EQ(xHost[i * messageSize / sizeof(uint32_t)], i + 10); + ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + 10); } -} - -static void test_rdma_get_info(aclrtStream stream, uint8_t *gva, uint32_t rankId, uint32_t rankSize) { - uint64_t *xHost; - size_t totalSize = 120; - ASSERT_EQ(aclrtMallocHost((void **)(&xHost), totalSize), 0); - memset(xHost, 0xEE, totalSize); - ASSERT_EQ(aclrtMemcpy(gva, totalSize, xHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - for (uint32_t curRank = 0; curRank < rankSize; curRank++) { - if (curRank == rankId) { - continue; - } - shmem_rdma_get_qpinfo_test_do(stream, gva, curRank, shmemx_get_ffts_config()); - ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - sleep(1); - ASSERT_EQ(aclrtMemcpy(xHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); - for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { - printf("GetQPInfo srcRank = %d, destRank = %d, index = %d, value = %lu\n", rankId, curRank, i, xHost[i]); - } + ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); + shm::shmemi_control_barrier_all(); + test_rdma_get_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); + ASSERT_EQ(aclrtSynchronizeStream(stream), 0); + shm::shmemi_control_barrier_all(); + ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); + for (uint32_t i = 0; i < rank_size; i++) { + ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + 10); } - ASSERT_EQ(aclrtFreeHost(xHost), 0); + ASSERT_EQ(aclrtFreeHost(inHost), 0); + ASSERT_EQ(aclrtFreeHost(outHost), 0); } void test_shmem_rdma_mem(int rank_id, int n_ranks, uint64_t local_mem_size) { @@ -102,9 +83,7 @@ void test_shmem_rdma_mem(int rank_id, int n_ranks, uint64_t local_mem_size) { ASSERT_NE(stream, nullptr); void* ptr = shmem_malloc(1024); - // test_rdma_poll_cq(stream, (uint8_t *)ptr, rank_id, n_ranks); test_rdma_put_get(stream, (uint8_t *)ptr, rank_id, n_ranks); - // test_rdma_get_info(stream, (uint8_t *)ptr, rank_id, n_ranks); std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; test_finalize(stream, device_id); if (::testing::Test::HasFailure()){ -- Gitee From 3cb64a43c36ed4e6a3138e8635b5592c29308c9b Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 2 Sep 2025 20:56:58 +0800 Subject: [PATCH 17/74] Add RDMA device code. --- .../rdma_perftest/rdma_perftest_kernel.cpp | 12 +- .../low_level/shmem_device_low_level_rma.h | 9 +- .../low_level/shmem_device_low_level_roce.h | 506 ++++++++++++++++++ include/device/shmem_device_rma.h | 9 +- 4 files changed, 522 insertions(+), 14 deletions(-) create mode 100644 include/device/low_level/shmem_device_low_level_roce.h diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp index acca5f98..91e91056 100644 --- a/examples/rdma_perftest/rdma_perftest_kernel.cpp +++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp @@ -99,7 +99,7 @@ extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM GM_ADDR dest_addr = (GM_ADDR)(shmem_ptr(src_addr, peer)); int64_t start = AscendC::GetSystemCycle(); for (uint32_t i = 0; i < 500; i++) { - smem_shm_roce_write(src_addr, dest_addr, peer, 0, message_length, ubLocal64, ubLocal32); + shmemi_roce_write(src_addr, dest_addr, peer, 0, message_length, ubLocal64, ubLocal32); } AscendC::PipeBarrier(); int64_t end = AscendC::GetSystemCycle(); @@ -134,7 +134,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, for (int i = 0; i < 10000; i++) { shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); } - smem_shm_roce_quiet(peer, 0, ubLocal64, ubLocal32); + shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32); shmem_put_uint8_mem_nbi(gva + rank_size * message_length + 8, src_addr, sizeof(uint32_t), peer); while (*(__gm__ uint32_t*)(gva + message_length * rank_size + 16) != peer + MAGIC_VAL) { cacheWriteThrough(gva + message_length * rank_size + 16, 8); @@ -180,10 +180,10 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD peer = 1; int64_t start = AscendC::GetSystemCycle(); for (int i = 0; i < 10000; i++) { - smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(src_addr, peer), peer, 0, message_length, ubLocal64, ubLocal32); + shmemi_roce_write(src_addr, (GM_ADDR)shmem_ptr(src_addr, peer), peer, 0, message_length, ubLocal64, ubLocal32); } - smem_shm_roce_quiet(peer, 0, ubLocal64, ubLocal32); - smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32); + shmemi_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL + iter) { cacheWriteThrough(gva + message_length * rank_size * 2 + 16, 8); AscendC::GetSystemCycle(); @@ -198,7 +198,7 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); - smem_shm_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 16, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 16, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); } } else { // core 1, MTE GM_ADDR src_addr = gva + (rank + rank_size) * message_length; diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index c43cb668..e843f609 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -13,6 +13,7 @@ #include "kernel_operator.h" #include "internal/device/shmemi_device_common.h" #include "device/shmem_device_team.h" +#include "shmem_device_low_level_roce.h" constexpr uint64_t SHMEM_INTERNAL_UB_BUF_START_ADDR = 188 * 1024; constexpr uint32_t UB_ALIGN_SIZE = 32; @@ -102,7 +103,7 @@ SHMEM_DEVICE void shmem_roce_get_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } @@ -216,7 +217,7 @@ SHMEM_DEVICE void shmem_roce_get_mem_nbi(AscendC::GlobalTensor dst, AscendC:: ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst.GetPhyAddr(), pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst.GetPhyAddr(), pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } /** @@ -320,7 +321,7 @@ SHMEM_DEVICE void shmem_roce_put_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - smem_shm_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } /** @@ -433,7 +434,7 @@ SHMEM_DEVICE void shmem_roce_put_mem_nbi(AscendC::GlobalTensor dst, AscendC:: ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - smem_shm_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } /** diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h new file mode 100644 index 00000000..54ad7f19 --- /dev/null +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -0,0 +1,506 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEM_DEVICE_LOW_LEVEL_ROCE_H +#define SHMEM_DEVICE_LOW_LEVEL_ROCE_H + +#include "kernel_operator.h" +#include "internal/device/shmemi_device_common.h" + +constexpr uint64_t SHMEM_DATA_CACHE_LINE_SIZE = 64; +constexpr uint32_t SHMEM_NUM_CQE_PER_POLL_CQ = 100; + +SHMEM_DEVICE void cacheInvalid(__gm__ uint8_t* sourceAddr, uint64_t length) { + __gm__ uint8_t* start = (__gm__ uint8_t*)((uint64_t)sourceAddr / SHMEM_DATA_CACHE_LINE_SIZE * SHMEM_DATA_CACHE_LINE_SIZE); + __gm__ uint8_t* end = + (__gm__ uint8_t*)( + ((uint64_t)sourceAddr + length) / SHMEM_DATA_CACHE_LINE_SIZE * SHMEM_DATA_CACHE_LINE_SIZE + ); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(start); + for (uint64_t i = 0; i <= end - start; i+= SHMEM_DATA_CACHE_LINE_SIZE) { + AscendC::DataCacheCleanAndInvalid(global[i]); + } +} + +enum class SHMEMAIVOPCODE : uint32_t { + OP_SEND = 0, + OP_SEND_WITH_INV, + OP_SEND_WITH_IMM, + OP_RDMA_WRITE, + OP_RDMA_WRITE_WITH_IMM, + OP_RDMA_READ +}; + +struct SHMEMAIVRDMAInfo { + uint32_t qpNum; // number of QP per connection + uint64_t sqPtr; // pointer to send queue address array of size [PE_NUM][qpNum] + uint64_t rqPtr; // pointer to receive queue address array of size [PE_NUM][qpNum] + uint64_t scqPtr; // pointer to send completion queue address array of size [PE_NUM][qpNum] + uint64_t rcqPtr; // pointer to receive completion queue address array of size [PE_NUM][qpNum] + uint64_t memPtr; // pointer to memory region array of size [MAX_PE_NUM] +}; + +struct SHMEMmemInfo { + uint64_t size; // size of the memory region + uint64_t addr; // start address of the memory region + uint32_t lkey; // local key of the memory region + uint32_t rkey; // remote key of the memory region +}; + +enum class SHMEMDBMode : int32_t { INVALID_DB = -1, HW_DB = 0, SW_DB }; + +struct SHMEMWQCtx { + uint32_t wqn; // work queue number + uint64_t bufAddr; // start address of ring buffer + uint32_t wqeSize; // size of each WQE + uint32_t depth; // depth of ring buffer + uint64_t headAddr; // work queue head (Producer Index) address + uint64_t tailAddr; // work queue tail (Consumer Index) address + SHMEMDBMode dbMode; + uint64_t dbAddr; // doorbell address + uint32_t sl; // service level +}; + +struct SHMEMCQCtx { + uint32_t cqn; // completion queue number + uint64_t bufAddr; // start address of ring buffer + uint32_t cqeSize; // size of each CQE + uint32_t depth; // depth of ring buffer + uint64_t headAddr; // work queue head (Producer Index) address + uint64_t tailAddr; // work queue tail (Consumer Index) address + SHMEMDBMode dbMode; + uint64_t dbAddr; // doorbell address +}; + +struct SHMEMwqeCtx { + uint32_t byte4; + uint32_t msgLen; + uint32_t immtdata; + uint32_t byte16; + uint32_t byte20; + uint32_t rkey; + uint64_t va; +}; + +struct SHMEMsegCtx { + uint32_t len; + uint32_t lkey; + uint64_t addr; +}; + +struct SHMEMcqeCtx { + uint32_t byte4; + uint32_t immtdata; + uint32_t byte12; + uint32_t byte16; + uint32_t byteCnt; + uint32_t smac; + uint32_t byte28; + uint32_t byte32; +}; + +struct SHMEMHybmDeviceMeta { + uint32_t entityId; + uint32_t rankId; + uint32_t rankSize; + uint32_t extraContextSize; + uint64_t symmetricSize; + uint64_t qpInfoAddress; + uint64_t reserved[12]; // total 128B, equal HYBM_DEVICE_PRE_META_SIZE +}; + +/** + * @brief RDMA Poll Completion Queue (CQ) function. Return status: 0 means success, non-zero means error. + * + * @param remoteRankId [in] destination rank ID + * @param qpIdx [in] QP index in multi-QP scenario (default 0 for single QP) + * @param idx [in] expect completion queue consumer index after polling + * @param ubLocal64 [in] temporary UB local tensor of uint64_t used as workspace + * @param ubLocal32 [in] temporary UB local tensor of uint32_t used as workspace + */ + +SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, uint32_t idx, + AscendC::LocalTensor ubLocal64, + AscendC::LocalTensor ubLocal32) +{ + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + uint32_t qpNum = RDMAInfo->qpNum; + __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); + auto cqBaseAddr = cqCtxEntry->bufAddr; + auto cqeSize = cqCtxEntry->cqeSize; + auto depth = cqCtxEntry->depth; + auto curHardwareTailAddr = cqCtxEntry->tailAddr; + cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + uint32_t curTail = *(__gm__ uint32_t*)(curHardwareTailAddr); + + AscendC::DataCopyExtParams copyParamsTail{1, 1 * sizeof(uint32_t), 0, 0, 0}; + while (curTail != idx) { + __gm__ SHMEMcqeCtx* cqeAddr = (__gm__ SHMEMcqeCtx*)(cqBaseAddr + cqeSize * (curTail & (depth - 1))); + uint32_t cqeByte4 = *(__gm__ uint32_t*)cqeAddr; + while (((cqeByte4 & (1 << 7)) != 0) == ((curTail & depth) != 0)) { + int64_t tmp = AscendC::GetSystemCycle(); + cacheInvalid((__gm__ uint8_t*)cqeAddr, 32); + cqeByte4 = *(__gm__ uint32_t*)cqeAddr; + } + curTail++; + uint32_t wqn = cqeAddr->byte16 & 0xFFFFFF; + + // Check CQE status + uint32_t status = (cqeAddr->byte4 >> 8) & 0xFF; + if (status) { + return status; + } + } + + // Update CQ tail + ubLocal32.SetValue(0, (uint32_t)curTail); + AscendC::GlobalTensor TailGlobalTensor; + TailGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)curHardwareTailAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(TailGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + + // Ring CQ Doorbell + auto cqDBAddr = cqCtxEntry->dbAddr; + if (cqCtxEntry->dbMode == SHMEMDBMode::SW_DB) { + ubLocal32.SetValue(0, (uint32_t)(curTail & 0xFFFFFF)); + AscendC::GlobalTensor CQDBGlobalTensor; + CQDBGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)cqDBAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(CQDBGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + cacheInvalid((__gm__ uint8_t*)cqDBAddr, 8); + } else if (cqCtxEntry->dbMode == SHMEMDBMode::HW_DB) { + uint64_t doorBellInfo = 0; + doorBellInfo |= cqCtxEntry->cqn; // [0:23] DB_TAG = qp_num + doorBellInfo |= 3 << 24; // [24:27] DB_CMD = HNS_ROCE_V2_CQ_DB_PTR(3) + doorBellInfo |= (uint64_t)(curTail & 0xFFFFFF) << 32; // [32:55] DB_CQ_CI = cq.tail + doorBellInfo |= (uint64_t)1 << 56; // [56:56] DB_CQ_CMD_SN = 1 + ubLocal64.SetValue(0, doorBellInfo); + AscendC::GlobalTensor DBGlobalTensor; + DBGlobalTensor.SetGlobalBuffer((__gm__ uint64_t*)cqDBAddr); + AscendC::DataCopyExtParams copyParams{1, 1 * sizeof(uint64_t), 0, 0, 0}; + AscendC::PipeBarrier(); + AscendC::DataCopyPad(DBGlobalTensor, ubLocal64, copyParams); + AscendC::PipeBarrier(); + } + + // Update WQ tail + __gm__ SHMEMWQCtx* wqCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); + auto curWQTailAddr = wqCtxEntry->tailAddr; + cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); + uint32_t curWQTail = *(__gm__ uint32_t*)(curWQTailAddr); + ubLocal32.SetValue(0, curTail); + AscendC::GlobalTensor WQTailGlobalTensor; + WQTailGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)curWQTailAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(WQTailGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); + return 0; +} + +/** + * @brief AIV direct RDMA helper function for post send, prepare WQE and ring doorbell. + * + * @param remoteAddr [in] address in remote HBM + * @param localAddr [in] address in lcoal HBM + * @param destRankId [in] destination rank ID + * @param qpIdx [in] QP index in multi-QP scenario (default 0 for single QP) + * @param opcode [in] rdma opcode in SHMEMAIVOPCODE enum class + * @param messageLen [in] message length in Bytes + * @param ubLocal64 [in] temporary UB local tensor of uint64_t used as workspace + * @param ubLocal32 [in] temporary UB local tensor of uint32_t used as workspace + */ + +SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8_t* localAddr, + uint32_t destRankId, uint32_t qpIdx, + SHMEMAIVOPCODE opcode, uint64_t messageLen, + AscendC::LocalTensor ubLocal64, + AscendC::LocalTensor ubLocal32) +{ + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + uint32_t qpNum = RDMAInfo->qpNum; + __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); + auto SHMEMmemInfoTable = RDMAInfo->memPtr; + auto sqBaseAddr = qpCtxEntry->bufAddr; + auto wqeSize = qpCtxEntry->wqeSize; + auto curHardwareHeadAddr = qpCtxEntry->headAddr; + cacheInvalid((__gm__ uint8_t*)curHardwareHeadAddr, 8); + uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); + auto curHardwareTailAddr = qpCtxEntry->tailAddr; + auto depth = qpCtxEntry->depth; + auto shift = 13; + AscendC::PipeBarrier(); + + // Poll CQ if send queue is full + cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + if ((curHead + 10) % depth == (*(__gm__ uint32_t*)(curHardwareTailAddr)) % depth) { + shmemi_roce_poll_cq(destRankId, qpIdx, *(__gm__ uint32_t*)(curHardwareTailAddr) + + SHMEM_NUM_CQE_PER_POLL_CQ, ubLocal64, ubLocal32); + } + + // Write WQE to HBM + __gm__ uint8_t* wqeAddr = (__gm__ uint8_t*)(sqBaseAddr + wqeSize * (curHead % depth)); + uint64_t ownBit = (curHead >> shift) & 0x1; + uint32_t byte4 = (uint32_t)opcode & 0x1F; // [0:4] opcode + byte4 |= ((~ownBit) << 7) & (1 << 7); // [7] owner_bit + byte4 |= 1 << 8; // [8] IBV_SEND_SINGNALED + *(__gm__ uint32_t*)(wqeAddr) = byte4; // control set by local parameter, see above lines + *(__gm__ uint32_t*)(wqeAddr + 4) = messageLen; // message size in bytes + *(__gm__ uint32_t*)(wqeAddr + 8) = 0; // immtdata is always 0 till we provide poll CQ flow in AIV + *(__gm__ uint32_t*)(wqeAddr + 12) = 1 << 24; // [120:127] num_sge = 1 + *(__gm__ uint32_t*)(wqeAddr + 16) = 0; // [128:151] start_sge_index = 0 + __gm__ SHMEMmemInfo* remoteMemInfo = (__gm__ SHMEMmemInfo*)(SHMEMmemInfoTable + sizeof(SHMEMmemInfo) * destRankId); + *(__gm__ uint32_t*)(wqeAddr + 20) = remoteMemInfo->rkey; // rkey + *(__gm__ uint64_t*)(wqeAddr + 24) = (uint64_t)remoteAddr; // remote VA + + // Write SGE to HBM + __gm__ uint8_t* sgeAddr = wqeAddr + sizeof(SHMEMwqeCtx); + *(__gm__ uint32_t*)(sgeAddr) = messageLen; // message size in bytes + __gm__ SHMEMmemInfo* localMemInfo = (__gm__ SHMEMmemInfo*)(SHMEMmemInfoTable + sizeof(SHMEMmemInfo) * shmemi_get_my_pe()); + *(__gm__ uint32_t*)(sgeAddr + 4) = localMemInfo->lkey; // lkey + *(__gm__ uint64_t*)(sgeAddr + 8) = (uint64_t)localAddr; // local VA + + // WQE & SGE cache flush + cacheInvalid(wqeAddr, sizeof(SHMEMwqeCtx) + sizeof(SHMEMsegCtx)); + AscendC::PipeBarrier(); + curHead++; + + uint64_t doorBellInfo = 0; + doorBellInfo |= qpCtxEntry->wqn; // [0:23] DB_TAG = qp_num + doorBellInfo |= 0 << 24; // [24:27] DB_CMD = HNS_ROCE_V2_SQ_DB(0) + doorBellInfo |= ((uint64_t)curHead % 65536) << 32; // [32:47] DB_PI = sq.head + doorBellInfo |= (uint64_t)(qpCtxEntry->sl) << 48; // [48:50] DB_SL = qp.sl + + __gm__ uint64_t* doorBellAddr = (__gm__ uint64_t*)(qpCtxEntry->dbAddr); + AscendC::PipeBarrier(); + + ubLocal64.SetValue(0, doorBellInfo); + AscendC::GlobalTensor DBGlobalTensor; + DBGlobalTensor.SetGlobalBuffer(doorBellAddr); + AscendC::DataCopyExtParams copyParams{1, 1 * sizeof(uint64_t), 0, 0, 0}; + AscendC::PipeBarrier(); + AscendC::DataCopyPad(DBGlobalTensor, ubLocal64, copyParams); + AscendC::PipeBarrier(); + + ubLocal32.SetValue(0, (uint32_t)curHead); + AscendC::GlobalTensor HeadGlobalTensor; + HeadGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)curHardwareHeadAddr); + AscendC::DataCopyExtParams copyParamsHead{1, 1 * sizeof(uint32_t), 0, 0, 0}; + AscendC::PipeBarrier(); + AscendC::DataCopyPad(HeadGlobalTensor, ubLocal32, copyParamsHead); + AscendC::PipeBarrier(); +} + +/** + * @brief Asynchronous RDMA Write function. + * + * @param srcDmaAddr [in] source address in local HBM + * @param destDmaAddr [in] destination address in remote HBM + * @param destRankId [in] destination rank ID + * @param qpIdx [in] QP index in multi-QP scenario (default 0 for single QP) + * @param messageLen [in] message length in Bytes + * @param ubLocal64 [in] temporary UB local tensor of uint64_t used as workspace + * @param ubLocal32 [in] temporary UB local tensor of uint32_t used as workspace + */ + +template +SHMEM_DEVICE void shmemi_roce_write(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, uint32_t destRankId, + uint32_t qpIdx, uint64_t messageLen, + AscendC::LocalTensor ubLocal64, + AscendC::LocalTensor ubLocal32) +{ + shmemi_rdma_post_send(destDmaAddr, srcDmaAddr, destRankId, qpIdx, SHMEMAIVOPCODE::OP_RDMA_WRITE, + messageLen, ubLocal64, ubLocal32); +} + +/** + * @brief Asynchronous RDMA READ function. + * + * @param srcDmaAddr [in] source address in remote HBM + * @param destDmaAddr [in] destination address in local HBM + * @param srcRankId [in] destination rank ID + * @param qpIdx [in] QP index in multi-QP scenario (default 0 for single QP) + * @param messageLen [in] message length in Bytes + * @param ubLocal64 [in] temporary UB local tensor of uint64_t used as workspace + * @param ubLocal32 [in] temporary UB local tensor of uint32_t used as workspace + */ + +template +SHMEM_DEVICE void shmemi_roce_read(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, uint32_t srcRankId, + uint32_t qpIdx, uint64_t messageLen, + AscendC::LocalTensor ubLocal64, + AscendC::LocalTensor ubLocal32) +{ + shmemi_rdma_post_send(srcDmaAddr, destDmaAddr, srcRankId, qpIdx, SHMEMAIVOPCODE::OP_RDMA_READ, + messageLen, ubLocal64, ubLocal32); +} + +/** + * @brief RDMA Quiet function. This synchronous function ensures all previous RDMA WQEs are completed (data has arrived at the destination NIC). + * + * @param remoteRankId [in] destination rank ID + * @param qpIdx [in] QP index in multi-QP scenario (default 0 for single QP) + * @param ubLocal64 [in] temporary UB local tensor of uint64_t used as workspace + * @param ubLocal32 [in] temporary UB local tensor of uint32_t used as workspace + */ + +SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, + AscendC::LocalTensor ubLocal64, + AscendC::LocalTensor ubLocal32) +{ + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + uint32_t qpNum = RDMAInfo->qpNum; + __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); + auto curHardwareHeadAddr = qpCtxEntry->headAddr; + cacheInvalid((__gm__ uint8_t*)curHardwareHeadAddr, 8); + uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); + shmemi_roce_poll_cq(remoteRankId, qpIdx, curHead, ubLocal64, ubLocal32); +} + +SHMEM_DEVICE void shmemi_roce_qpinfo_test(__gm__ uint8_t* gva, uint32_t destRankId, uint32_t qpIdx) +{ + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + *(__gm__ uint64_t*)(gva) = (uint64_t)RDMAInfo; + uint32_t qpNum = RDMAInfo->qpNum; + *(__gm__ uint64_t*)(gva + 8) = (uint64_t)qpNum; + __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); + *(__gm__ uint64_t*)(gva + 16) = (uint64_t)qpCtxEntry; + auto SHMEMmemInfoTable = RDMAInfo->memPtr; + *(__gm__ uint64_t*)(gva + 24) = (uint64_t)SHMEMmemInfoTable; + auto sqBaseAddr = qpCtxEntry->bufAddr; + *(__gm__ uint64_t*)(gva + 32) = (uint64_t)sqBaseAddr; + auto wqeSize = qpCtxEntry->wqeSize; + *(__gm__ uint64_t*)(gva + 40) = (uint64_t)wqeSize; + auto curHardwareHeadAddr = qpCtxEntry->headAddr; + *(__gm__ uint64_t*)(gva + 48) = (uint64_t)curHardwareHeadAddr; + cacheInvalid((__gm__ uint8_t*)curHardwareHeadAddr, 8); + uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); + *(__gm__ uint64_t*)(gva + 56) = (uint64_t)curHead; + auto curHardwareTailAddr = qpCtxEntry->tailAddr; + *(__gm__ uint64_t*)(gva + 64) = (uint64_t)curHardwareTailAddr; + auto depth = qpCtxEntry->depth; + *(__gm__ uint64_t*)(gva + 72) = (uint64_t)depth; + *(__gm__ uint64_t*)(gva + 80) = (uint64_t)(qpCtxEntry->sl); + auto shift = 15; + AscendC::PipeBarrier(); + + // Write WQE to HBM + __gm__ uint8_t* wqeAddr = (__gm__ uint8_t*)(sqBaseAddr + wqeSize * (curHead % depth)); + __gm__ SHMEMmemInfo* remoteMemInfo = (__gm__ SHMEMmemInfo*)(SHMEMmemInfoTable + sizeof(SHMEMmemInfo) * destRankId); + *(__gm__ uint64_t*)(gva + 88) = (uint64_t)(remoteMemInfo->rkey); + + // Write SGE to HBM + __gm__ SHMEMmemInfo* localMemInfo = (__gm__ SHMEMmemInfo*)(SHMEMmemInfoTable + sizeof(SHMEMmemInfo) * shmemi_get_my_pe()); + *(__gm__ uint64_t*)(gva + 96) = (uint64_t)(localMemInfo->lkey);; // lkey + + __gm__ uint64_t* doorBellAddr = (__gm__ uint64_t*)(qpCtxEntry->dbAddr); + *(__gm__ uint64_t*)(gva + 104) = (uint64_t)doorBellAddr; + *(__gm__ uint64_t*)(gva + 112) = (uint64_t)gva; + AscendC::PipeBarrier(); +} + +template +SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, uint32_t destRankId, + uint32_t qpIdx, uint64_t messageLen, + AscendC::LocalTensor ubLocal64, + AscendC::LocalTensor ubLocal32, __gm__ uint8_t* gva) +{ + shmemi_rdma_post_send(destDmaAddr, srcDmaAddr, destRankId, qpIdx, SHMEMAIVOPCODE::OP_RDMA_WRITE, + messageLen, ubLocal64, ubLocal32); + uint32_t idx = 1; + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + uint32_t qpNum = RDMAInfo->qpNum; + __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); + *(__gm__ uint64_t*)(gva) = (uint64_t)cqCtxEntry; + auto cqBaseAddr = cqCtxEntry->bufAddr; + auto cqeSize = cqCtxEntry->cqeSize; + auto depth = cqCtxEntry->depth; + *(__gm__ uint64_t*)(gva + 8) = (uint64_t)cqBaseAddr; + *(__gm__ uint64_t*)(gva + 16) = (uint64_t)cqeSize; + *(__gm__ uint64_t*)(gva + 24) = (uint64_t)depth; + auto curHardwareTailAddr = cqCtxEntry->tailAddr; + *(__gm__ uint64_t*)(gva + 32) = (uint64_t)curHardwareTailAddr; + cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + uint32_t curTail = *(__gm__ uint32_t*)(curHardwareTailAddr); + *(__gm__ uint64_t*)(gva + 40) = (uint64_t)curTail; + + AscendC::DataCopyExtParams copyParamsTail{1, 1 * sizeof(uint32_t), 0, 0, 0}; + + __gm__ SHMEMcqeCtx* cqeAddr = (__gm__ SHMEMcqeCtx*)(cqBaseAddr + cqeSize * (curTail & (depth - 1))); + uint32_t cqeByte4 = *(__gm__ uint32_t*)cqeAddr; + while (!(cqeByte4 & (1 << 7))) { + int64_t tmp = AscendC::GetSystemCycle(); + cacheInvalid((__gm__ uint8_t*)cqeAddr, 32); + cqeByte4 = *(__gm__ uint32_t*)cqeAddr; + } + *(__gm__ uint64_t*)(gva + 56) = (uint64_t)(cqeAddr->byte4); + *(__gm__ uint64_t*)(gva + 64) = (uint64_t)(cqeAddr->immtdata); + *(__gm__ uint64_t*)(gva + 72) = (uint64_t)(cqeAddr->byte12); + *(__gm__ uint64_t*)(gva + 80) = (uint64_t)(cqeAddr->byte16); + *(__gm__ uint64_t*)(gva + 88) = (uint64_t)(cqeAddr->byteCnt); + *(__gm__ uint64_t*)(gva + 96) = (uint64_t)(cqeAddr->smac); + curTail++; + // Process each CQE, and update WQ tail + uint32_t wqn = cqeAddr->byte16 & 0xFFFFFF; + __gm__ SHMEMWQCtx* wqCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); + *(__gm__ uint64_t*)(gva + 104) = (uint64_t)(wqCtxEntry->wqn == wqn); + auto curWQTailAddr = wqCtxEntry->tailAddr; + cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); + uint32_t curWQTail = *(__gm__ uint32_t*)(curWQTailAddr); + ubLocal32.SetValue(0, curWQTail + 1); + AscendC::GlobalTensor WQTailGlobalTensor; + WQTailGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)curWQTailAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(WQTailGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); + + // Check CQE status + uint32_t status = (cqeAddr->byte4 >> 8) & 0xFF; + *(__gm__ uint64_t*)(gva + 112) = status; + if (status) { + return; + } + + // Update tail + ubLocal32.SetValue(0, (uint32_t)curTail); + AscendC::GlobalTensor TailGlobalTensor; + TailGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)curHardwareTailAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(TailGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + + // Ring CQ Doorbell + auto cqDBAddr = cqCtxEntry->dbAddr; + ubLocal32.SetValue(0, (uint32_t)(curTail & 0xFFFFFF)); + AscendC::GlobalTensor CQDBGlobalTensor; + CQDBGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)cqDBAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(CQDBGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + cacheInvalid((__gm__ uint8_t*)cqDBAddr, 8); +} + +#endif // SHMEM_DEVICE_LOW_LEVEL_ROCE_H \ No newline at end of file diff --git a/include/device/shmem_device_rma.h b/include/device/shmem_device_rma.h index 2c847f79..e57b5ef6 100644 --- a/include/device/shmem_device_rma.h +++ b/include/device/shmem_device_rma.h @@ -13,6 +13,7 @@ #include "kernel_operator.h" #include "internal/device/shmemi_device_common.h" #include "low_level/shmem_device_low_level_rma.h" +#include "low_level/shmem_device_low_level_roce.h" #include "shmem_device_team.h" #include "internal/device/sync/shmemi_device_p2p.h" #include "shmem_device_sync.h" @@ -266,7 +267,7 @@ SHMEM_DEVICE void shmem_getmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(TYPE), \ + shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(TYPE), \ ub_tensor_64, ub_tensor_32); \ } \ } @@ -341,7 +342,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_DETAILED_NBI); ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - smem_shm_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(dst.GetPhyAddr()), pe, 0, \ + shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(dst.GetPhyAddr()), pe, 0, \ elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \ } \ } @@ -414,7 +415,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_DETAILED_NBI); ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - smem_shm_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(TYPE), \ + shmemi_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(TYPE), \ ub_tensor_64, ub_tensor_32); \ } \ } @@ -488,7 +489,7 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_DETAILED_NBI); ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - smem_shm_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, \ + shmemi_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, \ elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \ } \ } -- Gitee From ace6d45d96790774e808590fe304a130f6f1780c Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 2 Sep 2025 21:00:27 +0800 Subject: [PATCH 18/74] Add RDMA allgather demo. --- examples/CMakeLists.txt | 1 + examples/rdma_demo/CMakeLists.txt | 9 +++ examples/rdma_demo/README.md | 22 ++++++ examples/rdma_demo/main.cpp | 101 ++++++++++++++++++++++++ examples/rdma_demo/rdma_demo_kernel.cpp | 43 ++++++++++ 5 files changed, 176 insertions(+) create mode 100644 examples/rdma_demo/CMakeLists.txt create mode 100644 examples/rdma_demo/README.md create mode 100644 examples/rdma_demo/main.cpp create mode 100644 examples/rdma_demo/rdma_demo_kernel.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 94920765..3262e525 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -52,6 +52,7 @@ foreach(EXAMPLE allgather matmul_allreduce rdma_perftest + rdma_demo ) add_subdirectory(${EXAMPLE}) endforeach() \ No newline at end of file diff --git a/examples/rdma_demo/CMakeLists.txt b/examples/rdma_demo/CMakeLists.txt new file mode 100644 index 00000000..69346172 --- /dev/null +++ b/examples/rdma_demo/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +shmem_add_collective_example(rdma_demo) \ No newline at end of file diff --git a/examples/rdma_demo/README.md b/examples/rdma_demo/README.md new file mode 100644 index 00000000..410ef0e1 --- /dev/null +++ b/examples/rdma_demo/README.md @@ -0,0 +1,22 @@ +使用方式: +1.在shmem/目录编译: +```bash +bash scripts/build.sh +``` +2.在shmem/目录运行: +```bash +export PROJECT_ROOT= +export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH +./build/bin/rdma_demo 2 0 tcp://127.0.0.1:8765 2 0 0 # rank 0 +./build/bin/rdma_demo 2 1 tcp://127.0.0.1:8765 2 0 0 # rank 1 +``` + +3.命令行参数说明 + ./rdma_demo + +- n_ranks: 全局Rank数量,只支持2个Rank。 +- rank_id: 当前进程的Rank号。 +- ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。如果执行跨机测试,需要讲IP设为rank0所在Host的IP。 +- g_npus: 当前卡上启动的NPU数量。 +- f_rank: 当前卡上使用的第一个Rank号。 +- f_npu: 当前卡上使用的第一个NPU卡号。 \ No newline at end of file diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp new file mode 100644 index 00000000..bf1ff56e --- /dev/null +++ b/examples/rdma_demo/main.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include +#include +#include + +#include "acl/acl.h" +#include "shmem_api.h" +#include "shmemi_host_common.h" + +int g_npus = 8; +const char *ipport; +int f_rank = 0; +int f_npu = 0; +extern void allgather_demo(uint32_t block_dim, void* stream, uint8_t* gva, int message_length); + +int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + // 初始化ACL和SHMEM + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; + shmem_set_conf_store_tls(false, nullptr, 0); + status = shmem_init_attr(attributes); + + uint8_t *ptr = (uint8_t*)shmem_malloc(1024); + + // 初始化数据 + uint32_t trans_size = 16; + std::vector input(trans_size, 0); + for (int i = 0; i < trans_size; i++) { + input[i] = (rank_id + 10); + } + + status = aclrtMemcpy(ptr + shmem_my_pe() * trans_size * sizeof(int32_t), trans_size * sizeof(int32_t), + input.data(), trans_size * sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE); + + // AllGather + allgather_demo(1, stream, (uint8_t *)ptr, trans_size * sizeof(int32_t)); + status = aclrtSynchronizeStream(stream); + shm::shmemi_control_barrier_all(); + + // 结果校验打印 + int32_t *y_host; + size_t input_size = n_ranks * trans_size * sizeof(int32_t); + status = aclrtMallocHost(reinterpret_cast(&y_host), input_size); + status = aclrtMemcpy(y_host, input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST); + + for (int i = 0; i < n_ranks; i++) { + if (y_host[trans_size * i] != 10 + i) { + std::cout << y_host[trans_size * i] << " != " << 10 + i << std::endl; + std::exit(EXIT_FAILURE); + } + } + std::cout << "rank: " << rank_id << " ["; + for (int j = 0; j < trans_size * n_ranks; j++) { + std::cout << y_host[j] << ", "; + } + std::cout << "]" << std::endl; + // 去初始化 + status = aclrtFreeHost(y_host); + shmem_free(ptr); + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int main(int argc, char *argv[]) +{ + int status = 0; + int n_ranks = atoi(argv[1]); + int rank_id = atoi(argv[2]); + ipport = argv[3]; + g_npus = atoi(argv[4]); + f_rank = atoi(argv[5]); + f_npu = atoi(argv[6]); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + + return 0; +} \ No newline at end of file diff --git a/examples/rdma_demo/rdma_demo_kernel.cpp b/examples/rdma_demo/rdma_demo_kernel.cpp new file mode 100644 index 00000000..92448e33 --- /dev/null +++ b/examples/rdma_demo/rdma_demo_kernel.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef _RDMA_DEMO_KERNEL_ +#define _RDMA_DEMO_KERNEL_ + +#include "kernel_operator.h" +#include "shmem_api.h" + +// all_gather简易实现 +extern "C" __global__ __aicore__ void device_all_gather_test(GM_ADDR gva, int message_length) +{ + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + // 需要用户指定一个长度大于等于64字节的LocalTensor用于RDMA任务下发 + AscendC::LocalTensor ubLocal = buf.GetWithOffset(UB_ALIGN_SIZE * 2, 0); + int64_t my_rank = shmem_my_pe(); + int64_t pe_size = shmem_n_pes(); + AscendC::PipeBarrier(); + // All Gather + for (int i = 0; i < pe_size; i++) { + if (i == my_rank) { + continue; + } + shmem_roce_put_mem_nbi(gva + message_length * my_rank, gva + message_length * my_rank, + (__ubuf__ uint8_t*)ubLocal.GetPhyAddr(), message_length, i); + } +} + +void allgather_demo(uint32_t block_dim, void* stream, uint8_t* gva, int elements) +{ + device_all_gather_test<<>>(gva, elements); +} + +#endif // _RDMA_DEMO_KERNEL_ \ No newline at end of file -- Gitee From e0e6517b6707d136b9350b8ae40da22bf18a0393 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Wed, 3 Sep 2025 19:20:14 +0800 Subject: [PATCH 19/74] Change RDMA lowlevel interface. --- examples/rdma_perftest/rdma_perftest_kernel.cpp | 8 ++++---- include/device/low_level/shmem_device_low_level_rma.h | 8 ++++---- include/device/low_level/shmem_device_low_level_roce.h | 8 ++++---- include/device/shmem_device_rma.h | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp index 91e91056..733ba7dd 100644 --- a/examples/rdma_perftest/rdma_perftest_kernel.cpp +++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp @@ -99,7 +99,7 @@ extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM GM_ADDR dest_addr = (GM_ADDR)(shmem_ptr(src_addr, peer)); int64_t start = AscendC::GetSystemCycle(); for (uint32_t i = 0; i < 500; i++) { - shmemi_roce_write(src_addr, dest_addr, peer, 0, message_length, ubLocal64, ubLocal32); + shmemi_roce_write(dest_addr, src_addr, peer, 0, message_length, ubLocal64, ubLocal32); } AscendC::PipeBarrier(); int64_t end = AscendC::GetSystemCycle(); @@ -180,10 +180,10 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD peer = 1; int64_t start = AscendC::GetSystemCycle(); for (int i = 0; i < 10000; i++) { - shmemi_roce_write(src_addr, (GM_ADDR)shmem_ptr(src_addr, peer), peer, 0, message_length, ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_ptr(src_addr, peer), src_addr, peer, 0, message_length, ubLocal64, ubLocal32); } shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32); - shmemi_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL + iter) { cacheWriteThrough(gva + message_length * rank_size * 2 + 16, 8); AscendC::GetSystemCycle(); @@ -198,7 +198,7 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); - shmemi_roce_write(src_addr, (GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 16, peer), peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 16, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); } } else { // core 1, MTE GM_ADDR src_addr = gva + (rank + rank_size) * message_length; diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index e843f609..e4a674f6 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -103,7 +103,7 @@ SHMEM_DEVICE void shmem_roce_get_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_read((__gm__ uint8_t*)dst, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } @@ -217,7 +217,7 @@ SHMEM_DEVICE void shmem_roce_get_mem_nbi(AscendC::GlobalTensor dst, AscendC:: ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst.GetPhyAddr(), pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_read((__gm__ uint8_t*)dst.GetPhyAddr(), (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } /** @@ -321,7 +321,7 @@ SHMEM_DEVICE void shmem_roce_put_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - shmemi_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)src, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } /** @@ -434,7 +434,7 @@ SHMEM_DEVICE void shmem_roce_put_mem_nbi(AscendC::GlobalTensor dst, AscendC:: ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_64.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()) + UB_ALIGN_SIZE; ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - shmemi_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); + shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(src.GetPhyAddr()), pe, 0, elem_size * sizeof(T), ub_tensor_64, ub_tensor_32); } /** diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h index 54ad7f19..4c752252 100644 --- a/include/device/low_level/shmem_device_low_level_roce.h +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -309,8 +309,8 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 /** * @brief Asynchronous RDMA Write function. * - * @param srcDmaAddr [in] source address in local HBM * @param destDmaAddr [in] destination address in remote HBM + * @param srcDmaAddr [in] source address in local HBM * @param destRankId [in] destination rank ID * @param qpIdx [in] QP index in multi-QP scenario (default 0 for single QP) * @param messageLen [in] message length in Bytes @@ -319,7 +319,7 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 */ template -SHMEM_DEVICE void shmemi_roce_write(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, uint32_t destRankId, +SHMEM_DEVICE void shmemi_roce_write(__gm__ T* destDmaAddr, __gm__ T* srcDmaAddr, uint32_t destRankId, uint32_t qpIdx, uint64_t messageLen, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) @@ -331,8 +331,8 @@ SHMEM_DEVICE void shmemi_roce_write(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, /** * @brief Asynchronous RDMA READ function. * - * @param srcDmaAddr [in] source address in remote HBM * @param destDmaAddr [in] destination address in local HBM + * @param srcDmaAddr [in] source address in remote HBM * @param srcRankId [in] destination rank ID * @param qpIdx [in] QP index in multi-QP scenario (default 0 for single QP) * @param messageLen [in] message length in Bytes @@ -341,7 +341,7 @@ SHMEM_DEVICE void shmemi_roce_write(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, */ template -SHMEM_DEVICE void shmemi_roce_read(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, uint32_t srcRankId, +SHMEM_DEVICE void shmemi_roce_read(__gm__ T* destDmaAddr, __gm__ T* srcDmaAddr, uint32_t srcRankId, uint32_t qpIdx, uint64_t messageLen, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) diff --git a/include/device/shmem_device_rma.h b/include/device/shmem_device_rma.h index e57b5ef6..cbab2a2d 100644 --- a/include/device/shmem_device_rma.h +++ b/include/device/shmem_device_rma.h @@ -267,7 +267,7 @@ SHMEM_DEVICE void shmem_getmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)dst, pe, 0, elem_size * sizeof(TYPE), \ + shmemi_roce_read((__gm__ uint8_t*)dst, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(TYPE), \ ub_tensor_64, ub_tensor_32); \ } \ } @@ -342,7 +342,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_DETAILED_NBI); ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - shmemi_roce_read((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(dst.GetPhyAddr()), pe, 0, \ + shmemi_roce_read((__gm__ uint8_t*)(dst.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, \ elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \ } \ } @@ -415,7 +415,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_DETAILED_NBI); ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - shmemi_roce_write((__gm__ uint8_t*)src, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(TYPE), \ + shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)src, pe, 0, elem_size * sizeof(TYPE), \ ub_tensor_64, ub_tensor_32); \ } \ } @@ -489,7 +489,7 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_DETAILED_NBI); ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR \ + UB_ALIGN_SIZE); \ ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \ - shmemi_roce_write((__gm__ uint8_t*)(src.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, \ + shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(src.GetPhyAddr()), pe, 0, \ elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \ } \ } -- Gitee From d587af60209ef7df79d9cce64172116681519188 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Mon, 8 Sep 2025 21:02:23 +0800 Subject: [PATCH 20/74] Clean code. --- .../low_level/shmem_device_low_level_roce.h | 50 ++++++------------- include/internal/device/shmemi_device_arch.h | 17 +++++++ tests/unittest/host/main_test.cpp | 24 +++++++++ .../host/mem/rdma_mem/rdma_mem_host_test.cpp | 4 +- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h index 4c752252..d8997dff 100644 --- a/include/device/low_level/shmem_device_low_level_roce.h +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -13,23 +13,8 @@ #include "kernel_operator.h" #include "internal/device/shmemi_device_common.h" -constexpr uint64_t SHMEM_DATA_CACHE_LINE_SIZE = 64; constexpr uint32_t SHMEM_NUM_CQE_PER_POLL_CQ = 100; -SHMEM_DEVICE void cacheInvalid(__gm__ uint8_t* sourceAddr, uint64_t length) { - __gm__ uint8_t* start = (__gm__ uint8_t*)((uint64_t)sourceAddr / SHMEM_DATA_CACHE_LINE_SIZE * SHMEM_DATA_CACHE_LINE_SIZE); - __gm__ uint8_t* end = - (__gm__ uint8_t*)( - ((uint64_t)sourceAddr + length) / SHMEM_DATA_CACHE_LINE_SIZE * SHMEM_DATA_CACHE_LINE_SIZE - ); - AscendC::GlobalTensor global; - global.SetGlobalBuffer(start); - for (uint64_t i = 0; i <= end - start; i+= SHMEM_DATA_CACHE_LINE_SIZE) { - AscendC::DataCacheCleanAndInvalid(global[i]); - } -} - enum class SHMEMAIVOPCODE : uint32_t { OP_SEND = 0, OP_SEND_WITH_INV, @@ -140,7 +125,7 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, auto cqeSize = cqCtxEntry->cqeSize; auto depth = cqCtxEntry->depth; auto curHardwareTailAddr = cqCtxEntry->tailAddr; - cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curHardwareTailAddr, 8); uint32_t curTail = *(__gm__ uint32_t*)(curHardwareTailAddr); AscendC::DataCopyExtParams copyParamsTail{1, 1 * sizeof(uint32_t), 0, 0, 0}; @@ -148,12 +133,12 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, __gm__ SHMEMcqeCtx* cqeAddr = (__gm__ SHMEMcqeCtx*)(cqBaseAddr + cqeSize * (curTail & (depth - 1))); uint32_t cqeByte4 = *(__gm__ uint32_t*)cqeAddr; while (((cqeByte4 & (1 << 7)) != 0) == ((curTail & depth) != 0)) { - int64_t tmp = AscendC::GetSystemCycle(); - cacheInvalid((__gm__ uint8_t*)cqeAddr, 32); + int64_t tmp = AscendC::GetSystemCycle(); // reserved for timeout check + dcci_cachelines((__gm__ uint8_t*)cqeAddr, 32); cqeByte4 = *(__gm__ uint32_t*)cqeAddr; } curTail++; - uint32_t wqn = cqeAddr->byte16 & 0xFFFFFF; + uint32_t wqn = cqeAddr->byte16 & 0xFFFFFF; // reserved for multi WQ share the same CQ // Check CQE status uint32_t status = (cqeAddr->byte4 >> 8) & 0xFF; @@ -169,7 +154,6 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, AscendC::PipeBarrier(); AscendC::DataCopyPad(TailGlobalTensor, ubLocal32, copyParamsTail); AscendC::PipeBarrier(); - cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); // Ring CQ Doorbell auto cqDBAddr = cqCtxEntry->dbAddr; @@ -180,7 +164,6 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, AscendC::PipeBarrier(); AscendC::DataCopyPad(CQDBGlobalTensor, ubLocal32, copyParamsTail); AscendC::PipeBarrier(); - cacheInvalid((__gm__ uint8_t*)cqDBAddr, 8); } else if (cqCtxEntry->dbMode == SHMEMDBMode::HW_DB) { uint64_t doorBellInfo = 0; doorBellInfo |= cqCtxEntry->cqn; // [0:23] DB_TAG = qp_num @@ -199,7 +182,7 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, // Update WQ tail __gm__ SHMEMWQCtx* wqCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); auto curWQTailAddr = wqCtxEntry->tailAddr; - cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curWQTailAddr, 8); uint32_t curWQTail = *(__gm__ uint32_t*)(curWQTailAddr); ubLocal32.SetValue(0, curTail); AscendC::GlobalTensor WQTailGlobalTensor; @@ -207,7 +190,6 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, AscendC::PipeBarrier(); AscendC::DataCopyPad(WQTailGlobalTensor, ubLocal32, copyParamsTail); AscendC::PipeBarrier(); - cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); return 0; } @@ -239,7 +221,7 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 auto sqBaseAddr = qpCtxEntry->bufAddr; auto wqeSize = qpCtxEntry->wqeSize; auto curHardwareHeadAddr = qpCtxEntry->headAddr; - cacheInvalid((__gm__ uint8_t*)curHardwareHeadAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curHardwareHeadAddr, 8); uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); auto curHardwareTailAddr = qpCtxEntry->tailAddr; auto depth = qpCtxEntry->depth; @@ -247,7 +229,7 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 AscendC::PipeBarrier(); // Poll CQ if send queue is full - cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curHardwareTailAddr, 8); if ((curHead + 10) % depth == (*(__gm__ uint32_t*)(curHardwareTailAddr)) % depth) { shmemi_roce_poll_cq(destRankId, qpIdx, *(__gm__ uint32_t*)(curHardwareTailAddr) + SHMEM_NUM_CQE_PER_POLL_CQ, ubLocal64, ubLocal32); @@ -276,7 +258,7 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 *(__gm__ uint64_t*)(sgeAddr + 8) = (uint64_t)localAddr; // local VA // WQE & SGE cache flush - cacheInvalid(wqeAddr, sizeof(SHMEMwqeCtx) + sizeof(SHMEMsegCtx)); + dcci_cachelines(wqeAddr, sizeof(SHMEMwqeCtx) + sizeof(SHMEMsegCtx)); AscendC::PipeBarrier(); curHead++; @@ -369,7 +351,7 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); auto curHardwareHeadAddr = qpCtxEntry->headAddr; - cacheInvalid((__gm__ uint8_t*)curHardwareHeadAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curHardwareHeadAddr, 8); uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); shmemi_roce_poll_cq(remoteRankId, qpIdx, curHead, ubLocal64, ubLocal32); } @@ -392,7 +374,7 @@ SHMEM_DEVICE void shmemi_roce_qpinfo_test(__gm__ uint8_t* gva, uint32_t destRank *(__gm__ uint64_t*)(gva + 40) = (uint64_t)wqeSize; auto curHardwareHeadAddr = qpCtxEntry->headAddr; *(__gm__ uint64_t*)(gva + 48) = (uint64_t)curHardwareHeadAddr; - cacheInvalid((__gm__ uint8_t*)curHardwareHeadAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curHardwareHeadAddr, 8); uint32_t curHead = *(__gm__ uint32_t*)(curHardwareHeadAddr); *(__gm__ uint64_t*)(gva + 56) = (uint64_t)curHead; auto curHardwareTailAddr = qpCtxEntry->tailAddr; @@ -441,7 +423,7 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm *(__gm__ uint64_t*)(gva + 24) = (uint64_t)depth; auto curHardwareTailAddr = cqCtxEntry->tailAddr; *(__gm__ uint64_t*)(gva + 32) = (uint64_t)curHardwareTailAddr; - cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curHardwareTailAddr, 8); uint32_t curTail = *(__gm__ uint32_t*)(curHardwareTailAddr); *(__gm__ uint64_t*)(gva + 40) = (uint64_t)curTail; @@ -451,7 +433,7 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm uint32_t cqeByte4 = *(__gm__ uint32_t*)cqeAddr; while (!(cqeByte4 & (1 << 7))) { int64_t tmp = AscendC::GetSystemCycle(); - cacheInvalid((__gm__ uint8_t*)cqeAddr, 32); + dcci_cachelines((__gm__ uint8_t*)cqeAddr, 32); cqeByte4 = *(__gm__ uint32_t*)cqeAddr; } *(__gm__ uint64_t*)(gva + 56) = (uint64_t)(cqeAddr->byte4); @@ -466,7 +448,7 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm __gm__ SHMEMWQCtx* wqCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); *(__gm__ uint64_t*)(gva + 104) = (uint64_t)(wqCtxEntry->wqn == wqn); auto curWQTailAddr = wqCtxEntry->tailAddr; - cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curWQTailAddr, 8); uint32_t curWQTail = *(__gm__ uint32_t*)(curWQTailAddr); ubLocal32.SetValue(0, curWQTail + 1); AscendC::GlobalTensor WQTailGlobalTensor; @@ -474,7 +456,7 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm AscendC::PipeBarrier(); AscendC::DataCopyPad(WQTailGlobalTensor, ubLocal32, copyParamsTail); AscendC::PipeBarrier(); - cacheInvalid((__gm__ uint8_t*)curWQTailAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curWQTailAddr, 8); // Check CQE status uint32_t status = (cqeAddr->byte4 >> 8) & 0xFF; @@ -490,7 +472,7 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm AscendC::PipeBarrier(); AscendC::DataCopyPad(TailGlobalTensor, ubLocal32, copyParamsTail); AscendC::PipeBarrier(); - cacheInvalid((__gm__ uint8_t*)curHardwareTailAddr, 8); + dcci_cachelines((__gm__ uint8_t*)curHardwareTailAddr, 8); // Ring CQ Doorbell auto cqDBAddr = cqCtxEntry->dbAddr; @@ -500,7 +482,7 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm AscendC::PipeBarrier(); AscendC::DataCopyPad(CQDBGlobalTensor, ubLocal32, copyParamsTail); AscendC::PipeBarrier(); - cacheInvalid((__gm__ uint8_t*)cqDBAddr, 8); + dcci_cachelines((__gm__ uint8_t*)cqDBAddr, 8); } #endif // SHMEM_DEVICE_LOW_LEVEL_ROCE_H \ No newline at end of file diff --git a/include/internal/device/shmemi_device_arch.h b/include/internal/device/shmemi_device_arch.h index 66e38aba..07de181e 100644 --- a/include/internal/device/shmemi_device_arch.h +++ b/include/internal/device/shmemi_device_arch.h @@ -11,6 +11,7 @@ #define SHMEMI_DEVICE_ARCH_H #include "device/shmem_device_def.h" +constexpr uint64_t SHMEM_DATA_CACHE_LINE_SIZE = 64; SHMEM_DEVICE void dcci_cacheline(__gm__ uint8_t * addr) { using namespace AscendC; @@ -23,6 +24,22 @@ SHMEM_DEVICE void dcci_cacheline(__gm__ uint8_t * addr) { __asm__ __volatile__(""); } +SHMEM_DEVICE void dcci_cachelines(__gm__ uint8_t* addr, uint64_t length) { + __gm__ uint8_t* start = (__gm__ uint8_t*)((uint64_t)addr / SHMEM_DATA_CACHE_LINE_SIZE * SHMEM_DATA_CACHE_LINE_SIZE); + __gm__ uint8_t* end = + (__gm__ uint8_t*)( + ((uint64_t)addr + length) / SHMEM_DATA_CACHE_LINE_SIZE * SHMEM_DATA_CACHE_LINE_SIZE + ); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(start); + for (uint64_t i = 0; i <= end - start; i+= SHMEM_DATA_CACHE_LINE_SIZE) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global[i]); + __asm__ __volatile__(""); + } +} + SHMEM_DEVICE void dcci_entire_cache() { using namespace AscendC; GlobalTensor global; diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index aed20fe2..54e08442 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -36,6 +36,30 @@ void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *s EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); + shmem_init_attr_t* attributes; + shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); + status = shmem_init_attr(attributes); + EXPECT_EQ(status, 0); + *st = stream; +} + +void test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *st) +{ + *st = nullptr; + int status = 0; + if (n_ranks != (n_ranks & (~(n_ranks - 1)))) { + std::cout << "[TEST] input rank_size: "<< n_ranks << " is not the power of 2" << std::endl; + status = -1; + } + EXPECT_EQ(status, 0); + EXPECT_EQ(aclInit(nullptr), 0); + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + aclrtStream stream = nullptr; + EXPECT_EQ(status = aclrtCreateStream(&stream), 0); + + EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); + shmem_init_attr_t* attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; diff --git a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp index 3cec2bf9..31c82d7a 100644 --- a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp +++ b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp @@ -9,7 +9,7 @@ extern int test_gnpu_num; extern int test_first_npu; extern void test_mutil_task(std::function func, uint64_t local_mem_size, int processCount); -extern void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *st); +extern void test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *st); extern void test_finalize(aclrtStream stream, int device_id); extern void test_rdma_put_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uint64_t config); @@ -79,7 +79,7 @@ static void test_rdma_put_get(aclrtStream stream, uint8_t *gva, uint32_t rank_id void test_shmem_rdma_mem(int rank_id, int n_ranks, uint64_t local_mem_size) { int32_t device_id = rank_id % test_gnpu_num + test_first_npu; aclrtStream stream; - test_init(rank_id, n_ranks, local_mem_size, &stream); + test_rdma_init(rank_id, n_ranks, local_mem_size, &stream); ASSERT_NE(stream, nullptr); void* ptr = shmem_malloc(1024); -- Gitee From 8e71b874606ac292044b06d1a6851c87bd843a96 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 9 Sep 2025 11:33:01 +0800 Subject: [PATCH 21/74] Fix example compile error. --- .../matmul_allreduce/epilogue/block/epilogue_allreduce.hpp | 4 ++-- examples/matmul_allreduce/kernel/matmul_epilogue_comm.hpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/matmul_allreduce/epilogue/block/epilogue_allreduce.hpp b/examples/matmul_allreduce/epilogue/block/epilogue_allreduce.hpp index bb34111d..acb7905d 100644 --- a/examples/matmul_allreduce/epilogue/block/epilogue_allreduce.hpp +++ b/examples/matmul_allreduce/epilogue/block/epilogue_allreduce.hpp @@ -165,7 +165,7 @@ public: auto offsetOut = blockOffset + rankBlockOffset; auto residueProcessShape = actualCommSubBlockShape % params.processShape; - auto processCount = CeilDiv(actualCommSubBlockShape, params.processShape); + MatrixCoord processCount = CeilDiv(actualCommSubBlockShape, params.processShape); uint32_t processLoop = processCount.row() * processCount.column(); // [ReduceScatter] 1. Alloc TmpUB @@ -240,7 +240,7 @@ public: auto offsetOut = outputBlockOffset + rankBlockOffset; auto residueProcessShape = actualCommSubBlockShape % params.processShape; - auto processCount = CeilDiv(actualCommSubBlockShape, params.processShape); + MatrixCoord processCount = CeilDiv(actualCommSubBlockShape, params.processShape); uint32_t processLoop = processCount.row() * processCount.column(); diff --git a/examples/matmul_allreduce/kernel/matmul_epilogue_comm.hpp b/examples/matmul_allreduce/kernel/matmul_epilogue_comm.hpp index bbcc7d4d..fb3b3e71 100644 --- a/examples/matmul_allreduce/kernel/matmul_epilogue_comm.hpp +++ b/examples/matmul_allreduce/kernel/matmul_epilogue_comm.hpp @@ -174,7 +174,7 @@ public: // Split core loop to comm loop tile MatrixCoord coreLoops{params.epilogueParams.gemmSwizzle.GetCoreLoops(), 1}; MatrixCoord commBlockCount{loopNumPerComm, 1}; - auto commLoops = CeilDiv(coreLoops, commBlockCount); + MatrixCoord commLoops = CeilDiv(coreLoops, commBlockCount); auto residueCommBlockCount = coreLoops % commBlockCount; MatrixCoord blockShape{params.blockShape.m(), params.blockShape.n()}; -- Gitee From b020b5e3d64374e52bd974cabf79fd03b10efe58 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 9 Sep 2025 14:28:27 +0800 Subject: [PATCH 22/74] Change mf commit. --- 3rdparty/memfabric_hybrid | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/memfabric_hybrid b/3rdparty/memfabric_hybrid index 2af27076..e3101ced 160000 --- a/3rdparty/memfabric_hybrid +++ b/3rdparty/memfabric_hybrid @@ -1 +1 @@ -Subproject commit 2af2707690163d405f1ddf99094508a865f4c7a3 +Subproject commit e3101ced4fdf7dccbf680ae875bb85d8ed2f9eae -- Gitee From 9ff2efa3f54509cca92b0a33e85cfd7c402344da Mon Sep 17 00:00:00 2001 From: yinqiran Date: Wed, 17 Sep 2025 08:04:30 +0000 Subject: [PATCH 23/74] update OWNERS. --- OWNERS | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/OWNERS b/OWNERS index e90f9031..0b91d036 100644 --- a/OWNERS +++ b/OWNERS @@ -4,6 +4,8 @@ approvers: - nino233 - baoxiaom - victorwaang +- oioring +- gujianxiao reviewers: - git_ray @@ -19,4 +21,5 @@ reviewers: - huangxiaolan - Vector - lenokia -- victorwaang \ No newline at end of file +- victorwaang +- gujianxiao \ No newline at end of file -- Gitee From 59c72bcdb70e52426965e767d0c6c6002ca9395f Mon Sep 17 00:00:00 2001 From: xiebin Date: Wed, 17 Sep 2025 09:06:11 +0000 Subject: [PATCH 24/74] !366 refactor framework * fix compile issues * dynamic load bootstrap modules * refactor framework --- include/host/shmem_host_def.h | 8 + include/internal/host_device/shmemi_types.h | 9 - src/CMakeLists.txt | 2 + src/host/bootstrap/shmemi_bootstrap.cpp | 40 +++ src/host/bootstrap/shmemi_bootstrap.h | 10 + src/host/common/shmemi_host_types.h | 69 ++++ src/host/init/shmem_init_default.cpp | 325 ++++++++++++++++++ .../{shmem_init.cpp => shmem_init_mf.cpp} | 0 src/host/shmemi_host_common.h | 3 + src/host/transport/shmemi_transport.cpp | 40 +++ src/host/transport/shmemi_transport.h | 12 + .../bootstrap/shmemi_bootstrap_mpi.cpp | 31 ++ .../bootstrap/shmemi_bootstrap_uid.cpp | 33 ++ src/modules/transport/shmemi_mte.cpp | 32 ++ src/modules/transport/shmemi_rdma.cpp | 31 ++ src/transport/CMakeLists.txt | 7 - src/transport/adaptor/CMakeLists.txt | 7 - src/transport/adaptor/hccs/CMakeLists.txt | 7 - src/transport/adaptor/mte/CMakeLists.txt | 7 - src/transport/include/.gitkeep | 0 20 files changed, 636 insertions(+), 37 deletions(-) create mode 100644 src/host/bootstrap/shmemi_bootstrap.cpp create mode 100644 src/host/bootstrap/shmemi_bootstrap.h create mode 100644 src/host/common/shmemi_host_types.h create mode 100644 src/host/init/shmem_init_default.cpp rename src/host/init/{shmem_init.cpp => shmem_init_mf.cpp} (100%) create mode 100644 src/host/transport/shmemi_transport.cpp create mode 100644 src/host/transport/shmemi_transport.h create mode 100644 src/modules/bootstrap/shmemi_bootstrap_mpi.cpp create mode 100644 src/modules/bootstrap/shmemi_bootstrap_uid.cpp create mode 100644 src/modules/transport/shmemi_mte.cpp create mode 100644 src/modules/transport/shmemi_rdma.cpp delete mode 100644 src/transport/CMakeLists.txt delete mode 100644 src/transport/adaptor/CMakeLists.txt delete mode 100644 src/transport/adaptor/hccs/CMakeLists.txt delete mode 100644 src/transport/adaptor/mte/CMakeLists.txt delete mode 100644 src/transport/include/.gitkeep diff --git a/include/host/shmem_host_def.h b/include/host/shmem_host_def.h index 837ddd39..c5cbbb15 100644 --- a/include/host/shmem_host_def.h +++ b/include/host/shmem_host_def.h @@ -83,6 +83,14 @@ enum shmem_error_code_t : int { SHMEM_NOT_INITED = -5, ///< This is a problem caused by an uninitialization. }; +/** + * @brief init flags +*/ +enum shmemx_bootstrap_t : int { + SHMEMX_INIT_WITH_UNIQUEID = 1, + SHMEMX_INIT_WITH_MPI = 1 << 1, +}; + /** * @brief The state of the SHMEM library initialization. */ diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index b17f47f9..2df1ac71 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -92,15 +92,6 @@ typedef struct { shmemi_mte_config_t mte_config; } shmemi_device_host_state_t; -// host only state -typedef struct { - // typedef void *aclrtStream; as in https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/appdevgapi/aclcppdevg_03_1355.html - void *default_stream; - // using TEventID = int8_t; as in https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/800alpha003/apiref/ascendcopapi/atlasascendc_api_07_0181.html - int8_t default_event_id; - uint32_t default_block_num; -} shmemi_host_state_t; - #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aa9fe9d0..0a8e2b81 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,6 +21,8 @@ target_include_directories(shmem_device file(GLOB_RECURSE SHMEM_HOST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/host/*.cpp) list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "python_wrapper") +list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "modules") +list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmem_init_default.cpp") add_library(shmem_host OBJECT ${SHMEM_HOST_FILES}) target_compile_options(shmem_host PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(shmem_host diff --git a/src/host/bootstrap/shmemi_bootstrap.cpp b/src/host/bootstrap/shmemi_bootstrap.cpp new file mode 100644 index 00000000..c8183854 --- /dev/null +++ b/src/host/bootstrap/shmemi_bootstrap.cpp @@ -0,0 +1,40 @@ +#include "shmemi_host_common.h" +#include "dlfcn.h" + +#define BOOTSTRAP_MODULE_MPI "shmem_bootstrap_mpi.so" +#define BOOTSTRAP_MODULE_UID "shmem_bootstrap_uid.so" + +#define BOOTSTRAP_PLUGIN_INIT_FUNC "shmemi_bootstrap_plugin_init" + +shmemi_bootstrap_handle_t g_boot_handle; + +static void *plugin_hdl = nullptr; +static char *plugin_name = nullptr; + +// for UID +int32_t shmemi_bootstrap_pre_init() { + +} + +int32_t shmemi_bootstrap_init(int flags) { + if (flags & SHMEMX_INIT_WITH_MPI) { + plugin_name = BOOTSTRAP_MODULE_MPI; + } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { + plugin_name = BOOTSTRAP_MODULE_UID; + } else { + // error log + return -1; + } + + plugin_hdl = dlopen(plugin_name, RTLD_NOW); + int32_t (*plugin_init)(void *, shmemi_bootstrap_handle_t *); + *((void **)&plugin_init) = dlsym(plugin_hdl, BOOTSTRAP_PLUGIN_INIT_FUNC); + + plugin_init(nullptr, &g_boot_handle); +} + +int32_t shmemi_bootstrap_finalize() { + g_boot_handle.finalize(&g_boot_handle); + + dlclose(plugin_hdl); +} diff --git a/src/host/bootstrap/shmemi_bootstrap.h b/src/host/bootstrap/shmemi_bootstrap.h new file mode 100644 index 00000000..34988a67 --- /dev/null +++ b/src/host/bootstrap/shmemi_bootstrap.h @@ -0,0 +1,10 @@ +#ifndef SHMEMI_BOOTSTRAP_H +#define SHMEMI_BOOTSTRAP_H + +int32_t shmemi_bootstrap_pre_init(); + +int32_t shmemi_bootstrap_init(); + +int32_t shmemi_bootstrap_finalize(); + +#endif \ No newline at end of file diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h new file mode 100644 index 00000000..b9f1c528 --- /dev/null +++ b/src/host/common/shmemi_host_types.h @@ -0,0 +1,69 @@ +#ifndef SHMEMI_HOST_TYPES_H +#define SHMEMI_HOST_TYPES_H + +#define SHMEM_MAX_TRANSPORT_NUM 16 + +typedef struct shmemi_bootstrap_handle { + int32_t mype, npes; + void *bootstrap_state; + + int (*finalize)(struct shmemi_bootstrap_handle *boot_handle); + int (*allgather)(void *dst, void *src, size_t size, struct shmemi_bootstrap_handle *boot_handle); + int (*barrier)(struct shmemi_bootstrap_handle *boot_handle); +} shmemi_bootstrap_handle_t; + +typedef struct shmemi_bootstrap_mpi_options { + // TBD +} shmemi_bootstrap_mpi_options_t; + +typedef struct shmemi_bootstrap_uid_options { + // TBD +} shmemi_bootstrap_uid_options_t; + +typedef struct shmemi_transport_pe_info { + int32_t mype; + uint32_t host_id; + uint32_t dev_id; +} shmemi_transport_pe_info_t; + +typedef struct shmemi_transport { + shmemi_bootstrap_handle_t *boot_handle; + + // control plane + int (*can_access_peer)(int *access, shmemi_transport_pe_info_t *peer, struct shmemi_transport *t); + int (*connect_peers)(struct shmemi_transport *t, int *selected_dev_ids, int num_selected_devs); + int (*finalize)(struct shmemi_transport *t); + + // data plane, TBD + void (*rma)(struct shmemi_transport *t, int32_t type, void *dst, void *src, size_t size, int32_t pe); + void (*amo)(struct shmemi_transport *t, int32_t type, void *dst, void *src, size_t size, int32_t pe); + void (*quiet)(struct shmemi_transport *t); + void (*fence)(struct shmemi_transport *t); +} shmemi_transport_t; + +typedef struct { + int32_t pe, npes; + + shmemi_bootstrap_mpi_options_t mpi_options; + shmemi_bootstrap_tcp_options_t tcp_options; + + // other options + bool rdma_enabled; +} shmemi_options_t; + +// host only state +typedef struct { + // typedef void *aclrtStream; as in https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/appdevgapi/aclcppdevg_03_1355.html + void *default_stream; + // using TEventID = int8_t; as in https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/800alpha003/apiref/ascendcopapi/atlasascendc_api_07_0181.html + int8_t default_event_id; + uint32_t default_block_num; + + shmemi_options_t options; + + shmemi_bootstrap_handle_t *boot_handle; + shmemi_transport_t choosen_transports[SHMEM_MAX_TRANSPORT_NUM]; + int32_t num_choosen_transport; +} shmemi_host_state_t; + +#endif // SHMEMI_HOST_TYPES_H \ No newline at end of file diff --git a/src/host/init/shmem_init_default.cpp b/src/host/init/shmem_init_default.cpp new file mode 100644 index 00000000..efbc0a96 --- /dev/null +++ b/src/host/init/shmem_init_default.cpp @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "shmemi_host_common.h" + +using namespace std; + +namespace shm { + +#define DEFAULT_MY_PE (-1) +#define DEFAULT_N_PES (-1) + +constexpr int DEFAULT_FLAG = 0; +constexpr int DEFAULT_ID = 0; +constexpr int DEFAULT_TIMEOUT = 120; +constexpr int DEFAULT_TEVENT = 0; +constexpr int DEFAULT_BLOCK_NUM = 1; + +// initializer +#define SHMEM_DEVICE_HOST_STATE_INITIALIZER \ + { \ + (1 << 16) + sizeof(shmemi_device_host_state_t), /* version */ \ + (DEFAULT_MY_PE), /* mype */ \ + (DEFAULT_N_PES), /* npes */ \ + NULL, /* heap_base */ \ + {NULL}, /* p2p_heap_base */ \ + {NULL}, /* sdma_heap_base */ \ + {}, /* topo_list */ \ + SIZE_MAX, /* heap_size */ \ + {NULL}, /* team_pools */ \ + 0, /* sync_pool */ \ + 0, /* sync_counter */ \ + 0, /* core_sync_pool */ \ + 0, /* core_sync_counter */ \ + false, /* shmem_is_shmem_initialized */ \ + false, /* shmem_is_shmem_created */ \ + {0, 16 * 1024, 0}, /* shmem_mte_config */ \ + } + +shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; +shmemi_host_state_t g_state_host = {nullptr, DEFAULT_TEVENT, DEFAULT_BLOCK_NUM}; + +shmem_init_attr_t g_attr; +static smem_shm_t g_smem_handle = nullptr; +static bool g_attr_init = false; +static char *g_ipport = nullptr; + +int32_t version_compatible() +{ + int32_t status = SHMEM_SUCCESS; + return status; +} + +int32_t shmemi_options_init() +{ + int32_t status = SHMEM_SUCCESS; + return status; +} + +int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) +{ + int32_t status = SHMEM_SUCCESS; + g_state.mype = attributes->my_rank; + g_state.npes = attributes->n_ranks; + g_state.heap_size = attributes->local_mem_size + SHMEM_EXTRA_SIZE; + + aclrtStream stream = nullptr; + SHMEM_CHECK_RET(aclrtCreateStream(&stream)); + g_state_host.default_stream = stream; + g_state_host.default_event_id = DEFAULT_TEVENT; + g_state_host.default_block_num = DEFAULT_BLOCK_NUM; + return status; +} + +// TODO: use shmem native heap init +int32_t shmemi_heap_init_v2(shmem_init_attr_t *attributes) { + // 申请Physical Mem, 申请Virtual Addr,MMAP + // 连续映射, VA_size = rank * PA_size + // 非连续映射, VA_size = PA_size + + return 0; +} + +// TODO: use shmem native barrier +int32_t shmemi_control_barrier_all() +{ + return g_boot_handle.barrier(&g_boot_handle); +} + +// TODO: use shmem native global state +int32_t update_device_state() +{ + return 0; +} + +int32_t check_attr(shmem_init_attr_t *attributes) +{ + if ((attributes->my_rank < 0) || (attributes->n_ranks <= 0)) { + SHM_LOG_ERROR("my_rank:" << attributes->my_rank << " and n_ranks: " << attributes->n_ranks + << " cannot be less 0 , n_ranks still cannot be equal 0"); + return SHMEM_INVALID_VALUE; + } else if (attributes->n_ranks > SHMEM_MAX_RANKS) { + SHM_LOG_ERROR("n_ranks: " << attributes->n_ranks << " cannot be more than " << SHMEM_MAX_RANKS); + return SHMEM_INVALID_VALUE; + } else if (attributes->my_rank >= attributes->n_ranks) { + SHM_LOG_ERROR("n_ranks:" << attributes->n_ranks << " cannot be less than my_rank:" << attributes->my_rank); + return SHMEM_INVALID_PARAM; + } else if (attributes->local_mem_size <= 0) { + SHM_LOG_ERROR("local_mem_size:" << attributes->local_mem_size << " cannot be less or equal 0"); + return SHMEM_INVALID_VALUE; + } + return SHMEM_SUCCESS; +} + +} // namespace shm + +int32_t shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, data_op_engine_type_t value) +{ + attributes->option_attr.data_op_engine_type = value; + return SHMEM_SUCCESS; +} + +int32_t shmem_set_timeout(shmem_init_attr_t *attributes, uint32_t value) +{ + attributes->option_attr.shm_init_timeout = value; + attributes->option_attr.shm_create_timeout = value; + attributes->option_attr.control_operation_timeout = value; + return SHMEM_SUCCESS; +} + +int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size, const char *ip_port, + shmem_init_attr_t **attributes) +{ + SHM_ASSERT_RETURN(local_mem_size <= SHMEM_MAX_LOCAL_SIZE, SHMEM_INVALID_VALUE); + SHM_ASSERT_RETURN(n_ranks <= SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); + SHM_ASSERT_RETURN(my_rank <= SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); + *attributes = &shm::g_attr; + if (ip_port == nullptr) { + SHM_LOG_ERROR("my_rank:" << my_rank << " ip_port is NULL!"); + return SHMEM_INVALID_PARAM; + } + // 安全警告:此处strlen依赖ip_port以null结尾,如果ip_port不是合法的C字符串,将导致越界读取 + if (ip_port == nullptr) { + SHM_LOG_ERROR("my_rank:" << my_rank << " ip_port is NULL!"); + return SHMEM_INVALID_PARAM; + } + // 安全警告:此处strlen依赖ip_port以null结尾,如果ip_port不是合法的C字符串,将导致越界读取 + size_t ip_len = strlen(ip_port); + shm::g_ipport = new (std::nothrow) char[ip_len + 1]; + if (shm::g_ipport == nullptr) { + SHM_LOG_ERROR("my_rank:" << my_rank << " failed to allocate IP port string!"); + return SHMEM_INNER_ERROR; + } + std::copy(ip_port, ip_port + ip_len + 1, shm::g_ipport); + if (shm::g_ipport == nullptr) { + SHM_LOG_ERROR("my_rank:" << my_rank << " shm::g_ipport is nullptr!"); + return SHMEM_INVALID_VALUE; + } + int attr_version = (1 << 16) + sizeof(shmem_init_attr_t); + shm::g_attr.my_rank = my_rank; + shm::g_attr.n_ranks = n_ranks; + shm::g_attr.ip_port = shm::g_ipport; + shm::g_attr.local_mem_size = local_mem_size; + shm::g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, shm::DEFAULT_TIMEOUT, + shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT}; + shm::g_attr_init = true; + return SHMEM_SUCCESS; +} + +int32_t shmem_init_status() +{ + if (!shm::g_state.is_shmem_created) + return SHMEM_STATUS_NOT_INITIALIZED; + else if (!shm::g_state.is_shmem_initialized) + return SHMEM_STATUS_SHM_CREATED; + else if (shm::g_state.is_shmem_initialized) + return SHMEM_STATUS_IS_INITIALIZED; + else + return SHMEM_STATUS_INVALID; +} + +void shmem_rank_exit(int status) +{ + SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); + exit(status); +} + +void shmem_init_attr(shmem_init_attr_t *attributes) { + // namespace to be deleted + using namespace shm; + + SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); + SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); + SHMEM_CHECK_RET(check_attr(attributes)); + SHMEM_CHECK_RET(version_compatible()); + SHMEM_CHECK_RET(shmemi_options_init()); + + SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); + + shmemi_bootstrap_init(); + + shmemi_heap_init(attributes); + SHMEM_CHECK_RET(memory_manager_initialize(g_state.heap_base, g_state.heap_size)); + + shmemi_build_transport_map(); + + shmemi_transport_init(); + + SHMEM_CHECK_RET(shmemi_team_init(g_state.mype, g_state.npes)); + SHMEM_CHECK_RET(update_device_state()); + SHMEM_CHECK_RET(shmemi_sync_init()); + + g_state.is_shmem_initialized = true; + + return g_boot_handle.barrier(&g_boot_handle); +} + +int32_t shmem_finalize_v2() +{ + SHMEM_CHECK_RET(shm::shmemi_team_finalize()); + + if (shm::g_smem_handle != nullptr) { + int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_destroy Failed"); + return SHMEM_SMEM_ERROR; + } + shm::g_smem_handle = nullptr; + } + smem_shm_uninit(0); + smem_uninit(); + return SHMEM_SUCCESS; +} + +int32_t shmem_register_decrypt_handler(const shmem_decrypt_handler handler) +{ + return smem_register_decrypt_handler(handler); +} + +int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)) +{ + shm::shm_out_logger::Instance().set_extern_log_func(func, true); + return smem_set_extern_logger(func); +} + +int32_t shmem_set_log_level(int level) +{ + // use env first, input level secondly, user may change level from env instead call func + const char *in_level = std::getenv("SHMEM_LOG_LEVEL"); + if (in_level != nullptr) { + auto tmp_level = std::string(in_level); + if (tmp_level == "DEBUG") { + level = shm::DEBUG_LEVEL; + } else if (tmp_level == "INFO") { + level = shm::INFO_LEVEL; + } else if (tmp_level == "WARN") { + level = shm::WARN_LEVEL; + } else if (tmp_level == "ERROR") { + level = shm::ERROR_LEVEL; + } else if (tmp_level == "FATAL") { + level = shm::FATAL_LEVEL; + } + } + shm::shm_out_logger::Instance().set_log_level(static_cast(level)); + return smem_set_log_level(level); +} + +int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len) +{ + return smem_set_conf_store_tls(enable, tls_info, tls_info_len); +} + +int32_t shmem_finalize() +{ + SHMEM_CHECK_RET(shm::shmemi_team_finalize()); + if (shm::g_smem_handle != nullptr) { + int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_destroy Failed"); + return SHMEM_SMEM_ERROR; + } + shm::g_smem_handle = nullptr; + } + smem_shm_uninit(0); + smem_uninit(); + return SHMEM_SUCCESS; +} + +void shmem_info_get_version(int *major, int *minor) +{ + SHM_ASSERT_RET_VOID(major != nullptr && minor != nullptr); + *major = SHMEM_MAJOR_VERSION; + *minor = SHMEM_MINOR_VERSION; +} + +void shmem_info_get_name(char *name) +{ + SHM_ASSERT_RET_VOID(name != nullptr); + std::ostringstream oss; + oss << "SHMEM v" << SHMEM_VENDOR_MAJOR_VER << "." << SHMEM_VENDOR_MINOR_VER << "." << SHMEM_VENDOR_PATCH_VER; + auto version_str = oss.str(); + size_t i; + for (i = 0; i < SHMEM_MAX_NAME_LEN - 1 && version_str[i] != '\0'; i++) { + name[i] = version_str[i]; + } + name[i] = '\0'; +} + +void shmem_global_exit(int status) +{ + smem_shm_global_exit(shm::g_smem_handle, status); +} diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init_mf.cpp similarity index 100% rename from src/host/init/shmem_init.cpp rename to src/host/init/shmem_init_mf.cpp diff --git a/src/host/shmemi_host_common.h b/src/host/shmemi_host_common.h index b46ce602..907c01b0 100644 --- a/src/host/shmemi_host_common.h +++ b/src/host/shmemi_host_common.h @@ -14,10 +14,13 @@ #include "common/shmemi_logger.h" #include "common/shmemi_functions.h" +#include "common/shmemi_host_types.h" #include "init/shmemi_init.h" #include "team/shmemi_team.h" #include "mem/shmemi_mm.h" #include "sync/shmemi_sync.h" +#include "bootstrap/shmemi_bootstrap.h" +#include "transport/shmemi_transport.h" // smem api #include diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp new file mode 100644 index 00000000..b8f1e64c --- /dev/null +++ b/src/host/transport/shmemi_transport.cpp @@ -0,0 +1,40 @@ +#include "shmemi_host_common.h" + +extern shmemi_host_state_t g_host_state; + +int32_t shmemi_transport_init() { + uint32_t num_choosen_transport = 0; + +// #ifdef SHMEM_CONTINUOUS_ADDRESS_SPACE +// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_c; +// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_c; +// #else +// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_d; +// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_d; +// #endif + + g_host_state.num_choosen_transport = num_choosen_transport; + + for (int i = 0; i < num_choosen_transport; i++) { + auto t = g_host_state.choosen_transports + i; + t->boot_handle = g_host_state.boot_handle; + } +} + +int32_t shmemi_build_transport_map() { + // fill p2p/rdma/sdma heap bases +} + +int32_t shmemi_transport_setup_connections() { + for (int i = 0; i < g_host_state.num_choosen_transport; i++) { + auto t = g_host_state.choosen_transports + i; + t->connect_peers(t, nullptr, 0); + } +} + +int32_t shmemi_transport_finalize() { + for (int i = g_host_state.num_choosen_transport - 1; i >= 0; i--) { + auto t = g_host_state.choosen_transports + i; + t->finalize(t); + } +} \ No newline at end of file diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h new file mode 100644 index 00000000..c08ecb6e --- /dev/null +++ b/src/host/transport/shmemi_transport.h @@ -0,0 +1,12 @@ +#ifndef SHMEMI_TRANSPORT_H +#define SHMEMI_TRANSPORT_H + +int32_t shmemi_transport_init(); + +int32_t shmemi_build_transport_map(); + +int32_t shmemi_transport_setup_connections(); + +int32_t shmemi_transport_finalize(); + +#endif \ No newline at end of file diff --git a/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp new file mode 100644 index 00000000..ffb8524a --- /dev/null +++ b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp @@ -0,0 +1,31 @@ +#ifdef SHMEM_BOOTSTRAP_MPI +#include "shmemi_host_common.h" +#include "mpi.h" + +typedef struct { + MPI_COMM comm; +} shmemi_bootstrap_mpi_state_t; + +static shmemi_bootstrap_mpi_state_t shmemi_bootstrap_mpi_state; + +int shmemi_bootstrap_plugin_init(void *args, shmemi_bootstrap_handle_t *boot_handle) { + // INIT + + handle->allgather = shmemi_bootstrap_mpi_allgather; + handle->barrier = shmemi_bootstrap_mpi_allgather; + handle->finalize = shmemi_bootstrap_mpi_finalize; +} + +int shmemi_bootstrap_mpi_finalize(shmemi_bootstrap_handle_t *boot_handle) { + +} + +int shmemi_bootstrap_mpi_allgather(void *dst, void *src, size_t size, shmemi_bootstrap_handle_t *boot_handle) { + +} + +int shmemi_bootstrap_mpi_barrier(shmemi_bootstrap_handle_t *boot_handle) { + +} + +#endif // SHMEM_BOOTSTRAP_MPI \ No newline at end of file diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp new file mode 100644 index 00000000..34ec3185 --- /dev/null +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -0,0 +1,33 @@ +#ifdef SHMEM_BOOTSTRAP_UID +#include "shmemi_bootstrap.h" + +typedef struct { + char ifname[64]; + int af; + int32_t ip, port; +} shmemi_bootstrap_uid_state_t; + +static shmemi_bootstrap_uid_state_t shmemi_bootstrap_uid_state; + +int shmemi_bootstrap_plugin_init(void *args, shmemi_bootstrap_handle_t *boot_handle) { + // INIT + + handle->allgather = shmemi_bootstrap_uid_allgather; + handle->barrier = shmemi_bootstrap_uid_allgather; + handle->finalize = shmemi_bootstrap_uid_finalize; + +} + +int shmemi_bootstrap_uid_finalize(shmemi_bootstrap_handle_t *boot_handle) { + +} + +int shmemi_bootstrap_uid_allgather(void *dst, void *src, size_t size, shmemi_bootstrap_handle_t *boot_handle) { + +} + +int shmemi_bootstrap_uid_barrier(shmemi_bootstrap_handle_t *boot_handle) { + +} + +#endif // SHMEM_BOOTSTRAP_UID \ No newline at end of file diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp new file mode 100644 index 00000000..492c2eaf --- /dev/null +++ b/src/modules/transport/shmemi_mte.cpp @@ -0,0 +1,32 @@ +#include "shmemi_transport.h" + +typedef struct { + +} shmemi_mted_transport_state_t; + +static shmemi_mted_transport_state_t shmemi_mted_transport_state; + +// control plane +int shmemi_mted_init(shmemi_host_state_t *state, shmemi_transport_t *t) { + +} + +int shmemi_mted_can_access_peer(int *access, shmemi_transport_pe_info_t *peer, shmemi_transport_t *t) { + // host相同——true,否则false +} + +int shmemi_mted_connect_peers(shmemi_transport_t *t, int *selected_dev_ids, int num_selected_devs) { + +} + +int shmemi_mted_finalize(shmemi_transport_t *t) { + +} + + +shmemi_transport_t shmemi_mted_transport_state = { + .init = shmemi_mted_init, + .finalize = shmemi_mted_finalize, + .can_access_peer = shmemi_mted_can_access_peer, + .connect_peers = shmemi_mted_connect_peers, +} \ No newline at end of file diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp new file mode 100644 index 00000000..b35be44a --- /dev/null +++ b/src/modules/transport/shmemi_rdma.cpp @@ -0,0 +1,31 @@ +#include "shmemi_transport.h" + +typedef struct { + +} shmemi_rdmad_transport_state_t; + +static shmemi_rdmad_transport_state_t shmemi_rdmad_transport_state; + +// control plane +int shmemi_rdmad_init(shmemi_host_state_t *state, shmemi_transport_t *t) { + +} + +int shmemi_rdmad_can_access_peer(int *access, shmemi_transport_pe_info_t *peer, shmemi_transport_t *t) { + // true +} + +int shmemi_rdmad_connect_peers(shmemi_transport_t *t, int *selected_dev_ids, int num_selected_devs) { + // 建立QP链接 —— 获取NIC ip,check_peer_access(所有),创建sockets,创建(多)qp并连接 +} + +int shmemi_rdmad_finalize(shmemi_transport_t *t) { + +} + +shmemi_transport_t shmemi_rdmad_transport_state = { + .init = shmemi_rdmad_init, + .finalize = shmemi_rdmad_finalize, + .can_access_peer = shmemi_rdmad_can_access_peer, + .connect_peers = shmemi_rdmad_connect_peers, +} \ No newline at end of file diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt deleted file mode 100644 index f882e7f5..00000000 --- a/src/transport/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/adaptor/CMakeLists.txt b/src/transport/adaptor/CMakeLists.txt deleted file mode 100644 index f882e7f5..00000000 --- a/src/transport/adaptor/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/adaptor/hccs/CMakeLists.txt b/src/transport/adaptor/hccs/CMakeLists.txt deleted file mode 100644 index f882e7f5..00000000 --- a/src/transport/adaptor/hccs/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/adaptor/mte/CMakeLists.txt b/src/transport/adaptor/mte/CMakeLists.txt deleted file mode 100644 index f882e7f5..00000000 --- a/src/transport/adaptor/mte/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/include/.gitkeep b/src/transport/include/.gitkeep deleted file mode 100644 index e69de29b..00000000 -- Gitee From 0678f7a8092e060864866dba5408d3cd26aabab5 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Thu, 18 Sep 2025 19:11:13 +0800 Subject: [PATCH 25/74] mpi bootstrap with new framwork --- examples/hellow_word/CMakeLists.txt | 106 ++++++++++++++++ examples/hellow_word/main.cpp | 30 +++++ examples/hellow_word/run.sh | 21 ++++ include/host/shmem_host_init.h | 3 +- src/CMakeLists.txt | 40 +++++- src/host/bootstrap/shmemi_bootstrap.cpp | 110 ++++++++++++++-- src/host/bootstrap/shmemi_bootstrap.h | 22 +++- src/host/common/shmemi_host_types.h | 28 ++++- src/host/init/shmem_init_default.cpp | 57 ++------- src/host/transport/shmemi_transport.cpp | 2 +- .../bootstrap/shmemi_bootstrap_mpi.cpp | 119 +++++++++++++++--- 11 files changed, 454 insertions(+), 84 deletions(-) create mode 100644 examples/hellow_word/CMakeLists.txt create mode 100644 examples/hellow_word/main.cpp create mode 100644 examples/hellow_word/run.sh diff --git a/examples/hellow_word/CMakeLists.txt b/examples/hellow_word/CMakeLists.txt new file mode 100644 index 00000000..44b020af --- /dev/null +++ b/examples/hellow_word/CMakeLists.txt @@ -0,0 +1,106 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +cmake_minimum_required(VERSION 3.18) +project(SHMEM) + +# 设置C++标准 +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# 生成位置无关代码 +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# 设置可执行文件输出目录 +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + +# 设置安装路径 +set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/install/shmem) + +# 获取CANN相关环境变量 +if(NOT DEFINED ENV{ASCEND_HOME_PATH}) + message(FATAL_ERROR "Cannot find ASCEND_HOME_PATH, please run set_env.sh.") +else() + set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}) +endif() + +option(USE_UNIT_TEST "USE_UNIT_TEST" OFF) +option(USE_EXAMPLES "USE_EXAMPLES" OFF) +message(STATUS "USE_UNIT_TEST:${USE_UNIT_TEST}") +message(STATUS "USE_EXAMPLES:${USE_EXAMPLES}") +set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) + +set(CMAKE_COMPILER g++) +# set(CMAKE_C_COMPILER ${CMAKE_COMPILER}) +set(CMAKE_CXX_COMPILER ${CMAKE_COMPILER}) + +add_compile_options( + -O2 -std=c++17 + -Wno-macro-redefined -Wno-ignored-attributes + # avoid ascendc interference + -DL2_CACHE_HINT + -DTILING_KEY_VAR +) + +set(CMAKE_CPP_COMPILE_OPTIONS + -xc++ + "SHELL:-include stdint.h" + "SHELL:-include stddef.h" +) + +include_directories( + ${ASCEND_HOME_PATH}/compiler/tikcpp + ${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw + ${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/impl + ${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/interface + ${ASCEND_HOME_PATH}/include + ${ASCEND_HOME_PATH}/include/experiment/runtime + ${ASCEND_HOME_PATH}/include/experiment/msprof + ${ASCEND_DRIVER_PATH}/kernel/inc +) + +link_directories( + ${ASCEND_HOME_PATH}/lib64 + ${ASCEND_DRIVER_PATH}/lib64/driver +) + +link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase ascend_hal pthread) + +find_package(MPI REQUIRED) + +include_directories( + ${MPI_INCLUDE_PATH} + ${ASCEND_HOME_PATH}/lib64 + ${ASCEND_DRIVER_PATH}/lib64/driver +) + +link_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/../../install/shmem/lib +) + +add_executable(helloword main.cpp) + +target_include_directories(helloword PRIVATE + ${ASCEND_DRIVER_PATH}/kernel/inc + ${ASCEND_HOME_PATH}/include/ + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../../install/memfabric_hybrid/include/smem/host/ + ${CMAKE_CURRENT_SOURCE_DIR}/../../install/memfabric_hybrid/include/smem/device/ +) + +target_link_libraries(helloword PRIVATE MPI::MPI_CXX) +target_link_libraries(helloword PRIVATE shmem) + +target_compile_options(helloword PRIVATE ${MPI_CXX_COMPILE_FLAGS}) + +set_target_properties(helloword PROPERTIES + CXX_COMPILER "g++" +) +target_compile_options(helloword PRIVATE + "$<$:-g;-Wall>" +) \ No newline at end of file diff --git a/examples/hellow_word/main.cpp b/examples/hellow_word/main.cpp new file mode 100644 index 00000000..3798c73f --- /dev/null +++ b/examples/hellow_word/main.cpp @@ -0,0 +1,30 @@ +#include +#include +#include +#include +#include "shmem_api.h" +int main(int argc, char* argv[]) +{ + // 初始化MPI环境 + MPI_Init(&argc, &argv); + + // 获取当前进程的编号(rank) + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + int status = SHMEM_SUCCESS; + aclInit(nullptr); + aclrtSetDevice(rank); + shmem_init_attr_t *attributes; + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + + + status = shmem_finalize(); + if ( status != SHMEM_SUCCESS) { + std::cout << "[ERROR] demo run failed!" << std::endl; + std::exit(status); + } + aclrtResetDevice(rank); + aclFinalize(); + MPI_Finalize(); + std::cout << "[SUCCESS] demo run success!" << std::endl; +} \ No newline at end of file diff --git a/examples/hellow_word/run.sh b/examples/hellow_word/run.sh new file mode 100644 index 00000000..1bb367ac --- /dev/null +++ b/examples/hellow_word/run.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# +BULID_DIR="build" +if [ ! -d "$BULID_DIR" ]; then + mkdir "$BULID_DIR" +fi +cd $BULID_DIR + +cmake .. +make +cd .. + +mpirun -np 4 ./build/bin/helloword diff --git a/include/host/shmem_host_init.h b/include/host/shmem_host_init.h index c2ea76fd..ff691974 100644 --- a/include/host/shmem_host_init.h +++ b/include/host/shmem_host_init.h @@ -65,10 +65,11 @@ SHMEM_HOST_API int shmem_set_timeout(shmem_init_attr_t *attributes, uint32_t val * if the self-created attr structure is incorrect, the initialization will fail. * It is recommended to build the attributes by shmem_set_attr(). * + * @param bootstrap_flags [in] bootstrap_flags for init. * @param attributes [in] Pointer to the user-defined attributes. * @return Returns 0 on success or an error code on failure */ -SHMEM_HOST_API int shmem_init_attr(shmem_init_attr_t *attributes); +SHMEM_HOST_API int shmem_init_attr(uint32_t bootstrap_flags, shmem_init_attr_t *attributes); /** * @brief Register a decrypt key password handler. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0a8e2b81..f1c0f54e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,6 +10,17 @@ if (BUILD_PYTHON STREQUAL "ON") add_subdirectory(host/python_wrapper) endif () +set(SHMEM_MPI_SUPPORT OFF) + +if(SHMEM_MPI_SUPPORT) + find_package(MPI REQUIRED) +else() + find_package(MPI) + if(MPI_FOUND) + set(SHMEM_MPI_SUPPORT ON) + endif() +endif() + file(GLOB_RECURSE SHMEM_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) add_library(shmem_device OBJECT ${SHMEM_KERNEL_FILES}) target_compile_options(shmem_device PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220) @@ -22,7 +33,8 @@ target_include_directories(shmem_device file(GLOB_RECURSE SHMEM_HOST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/host/*.cpp) list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "python_wrapper") list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "modules") -list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmem_init_default.cpp") +# list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmem_init_default.cpp") +list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmem_init_mf.cpp") add_library(shmem_host OBJECT ${SHMEM_HOST_FILES}) target_compile_options(shmem_host PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(shmem_host @@ -42,6 +54,32 @@ target_link_libraries(shmem ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_hybm_core.so ) +# MPI +if(SHMEM_MPI_SUPPORT) + separate_arguments(SHMEM_CXX_LINK_FLAGS NATIVE_COMMAND "${MPI_CXX_LINK_FLAGS}") + target_link_options(shmem INTERFACE ${SHMEM_CXX_LINK_FLAGS}) + target_compile_definitions(shmem INTERFACE ${MPI_CXX_COMPILE_DEFINITIONS}) + target_compile_options(shmem INTERFACE ${MPI_CXX_COMPILE_OPTIONS}) + + add_library( + shmem_bootstrap_mpi SHARED + ) + target_sources(shmem_bootstrap_mpi PRIVATE modules/bootstrap/shmemi_bootstrap_mpi.cpp) + target_link_libraries(shmem_bootstrap_mpi PRIVATE MPI::MPI_CXX) + target_include_directories(shmem_bootstrap_mpi + PRIVATE + ${PROJECT_SOURCE_DIR}/include/ + ${PROJECT_SOURCE_DIR}/src/host + ${PROJECT_SOURCE_DIR}/src/device + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ + ) + set_target_properties(shmem_bootstrap_mpi PROPERTIES PREFIX "") + install(TARGETS shmem_bootstrap_mpi + LIBRARY DESTINATION lib + ) + +endif() + # 安装配置 install(TARGETS shmem LIBRARY DESTINATION lib diff --git a/src/host/bootstrap/shmemi_bootstrap.cpp b/src/host/bootstrap/shmemi_bootstrap.cpp index c8183854..6d8e3dce 100644 --- a/src/host/bootstrap/shmemi_bootstrap.cpp +++ b/src/host/bootstrap/shmemi_bootstrap.cpp @@ -6,34 +6,124 @@ #define BOOTSTRAP_PLUGIN_INIT_FUNC "shmemi_bootstrap_plugin_init" +#define shmemxi_error_unlikely(x) __builtin_expect(!!(x), 0) + +#define SHMEMI_NULL_ERROR_JMP(var, status, err, ...) \ + do { \ + if (shmemxi_error_unlikely(var == NULL)) { \ + fprintf(stderr, "%s:%d: NULL value ", __FILE__, __LINE__); \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, "\n"); \ + status = err; \ + } \ + } while (0) + +#define GET_SYMBOL(lib_handle, name, var, status) \ + do { \ + void **var_ptr = (void **)&(var); \ + void *tmp = (void *)dlsym(lib_handle, name); \ + SHMEMI_NULL_ERROR_JMP( \ + tmp, status, SHMEM_INNER_ERROR, "Bootstrap failed to get symbol '%s'\n\t%s\n", name, dlerror()); \ + *var_ptr = tmp; \ + } while (0) + shmemi_bootstrap_handle_t g_boot_handle; static void *plugin_hdl = nullptr; static char *plugin_name = nullptr; +void _bootstrap_loader_fini_helper(void *plugin_hdl, char *plugin_name) +{ + if (plugin_hdl != nullptr) { + dlclose(plugin_hdl); + plugin_hdl = nullptr; + } + + if (plugin_name != nullptr) { + free(plugin_name); + plugin_name = nullptr; + } +} + +int bootstrap_loader_finalize(shmemi_bootstrap_handle_t *handle) +{ + int status = handle->finalize(handle); + + if (status != 0) + SHM_LOG_ERROR("Bootstrap plugin finalize failed for " << plugin_name); + + dlclose(plugin_hdl); + plugin_hdl = nullptr; + free(plugin_name); + plugin_name = nullptr; + + return 0; +} + +static int _bootstrap_loader_init_helper(const char *plugin, shmemi_bootstrap_handle_t *handle) +{ + dlerror(); + if (plugin_name == nullptr) { + plugin_name = strdup(plugin); + if (!plugin_name) { + SHM_LOG_ERROR("Failed to strdup plugin name, err is: " << stderr); + return SHMEM_INVALID_VALUE; + } + } + + if (plugin_hdl == nullptr) { + plugin_hdl = dlopen(plugin, RTLD_NOW); + } + dlerror(); + + if (!plugin_hdl) { + SHM_LOG_ERROR("Bootstrap unable to load " << plugin << ", err is: " << stderr); + _bootstrap_loader_fini_helper(plugin_hdl, plugin_name); + return SHMEM_INVALID_VALUE; + } + + return SHMEM_SUCCESS; +} + +int bootstrap_loader_init(const char *plugin, void *arg, shmemi_bootstrap_handle_t *handle) +{ + int status = _bootstrap_loader_init_helper(plugin, handle); + if (status != 0) { + SHM_LOG_ERROR("Bootstrap library dlopen failed for " << plugin); + _bootstrap_loader_fini_helper(plugin_hdl, plugin_name); + return SHMEM_INNER_ERROR; + } + int (*bootstrap_plugin_initops)(void *arg, shmemi_bootstrap_handle_t *handle); + GET_SYMBOL(plugin_hdl, BOOTSTRAP_PLUGIN_INIT_FUNC, bootstrap_plugin_initops, status); + status = bootstrap_plugin_initops(arg, handle); + if (status != 0) { + SHM_LOG_ERROR("Bootstrap plugin init failed for " << plugin); + _bootstrap_loader_fini_helper(plugin_hdl, plugin_name); + return SHMEM_INNER_ERROR; + } + return SHMEM_SUCCESS; +} // for UID int32_t shmemi_bootstrap_pre_init() { } -int32_t shmemi_bootstrap_init(int flags) { +int32_t shmemi_bootstrap_init(int flags, shmemi_bootstrap_attr_t *attr) { + int32_t status; if (flags & SHMEMX_INIT_WITH_MPI) { plugin_name = BOOTSTRAP_MODULE_MPI; + status = bootstrap_loader_init(plugin_name, (attr != NULL) ? attr->mpi_comm : NULL, &g_boot_handle); } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { plugin_name = BOOTSTRAP_MODULE_UID; + status = bootstrap_loader_init(plugin_name, (attr->uid_args), &g_boot_handle); } else { - // error log - return -1; + SHM_LOG_ERROR("Unknown Type for bootstrap"); + status = SHMEM_INVALID_PARAM; } - - plugin_hdl = dlopen(plugin_name, RTLD_NOW); - int32_t (*plugin_init)(void *, shmemi_bootstrap_handle_t *); - *((void **)&plugin_init) = dlsym(plugin_hdl, BOOTSTRAP_PLUGIN_INIT_FUNC); - - plugin_init(nullptr, &g_boot_handle); + return status; } -int32_t shmemi_bootstrap_finalize() { +void shmemi_bootstrap_finalize() { g_boot_handle.finalize(&g_boot_handle); dlclose(plugin_hdl); diff --git a/src/host/bootstrap/shmemi_bootstrap.h b/src/host/bootstrap/shmemi_bootstrap.h index 34988a67..a3ad3637 100644 --- a/src/host/bootstrap/shmemi_bootstrap.h +++ b/src/host/bootstrap/shmemi_bootstrap.h @@ -1,10 +1,28 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + #ifndef SHMEMI_BOOTSTRAP_H #define SHMEMI_BOOTSTRAP_H +#ifdef __cplusplus +extern "C" { +#endif int32_t shmemi_bootstrap_pre_init(); -int32_t shmemi_bootstrap_init(); +int32_t shmemi_bootstrap_init(int flags, shmemi_bootstrap_attr_t *attr); + +void shmemi_bootstrap_finalize(); -int32_t shmemi_bootstrap_finalize(); +int shmemi_bootstrap_plugin_init(void *mpi_comm, shmemi_bootstrap_handle_t *handle); +#ifdef __cplusplus +} +#endif #endif \ No newline at end of file diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index b9f1c528..c7cbc6d9 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -3,13 +3,30 @@ #define SHMEM_MAX_TRANSPORT_NUM 16 +typedef struct shmemi_bootstrap_attr { + shmemi_bootstrap_attr() : initialize_mf(0), mpi_comm(NULL), uid_args(NULL) + {} + int initialize_mf; + void *mpi_comm; + void *mete_data; + void *uid_args; +} shmemi_bootstrap_attr_t; + +typedef struct shmemi_bootstrap_init_ops { + void *cookie; + int (*get_unique_id)(void *cookit); +} shmemi_bootstrap_init_ops_t; + typedef struct shmemi_bootstrap_handle { int32_t mype, npes; void *bootstrap_state; - int (*finalize)(struct shmemi_bootstrap_handle *boot_handle); - int (*allgather)(void *dst, void *src, size_t size, struct shmemi_bootstrap_handle *boot_handle); - int (*barrier)(struct shmemi_bootstrap_handle *boot_handle); + int (*finalize)(shmemi_bootstrap_handle *boot_handle); + int (*allgather)(const void *sendbuf, void *recvbuf, int size, shmemi_bootstrap_handle *boot_handle); + int (*barrier)(shmemi_bootstrap_handle *boot_handle); + int (*alltoall)(const void *sendbuf, void *recvbuf, int size, shmemi_bootstrap_handle *boot_handle); + void (*global_exit)(int status); + shmemi_bootstrap_init_ops_t *pre_init_ops; } shmemi_bootstrap_handle_t; typedef struct shmemi_bootstrap_mpi_options { @@ -45,7 +62,7 @@ typedef struct { int32_t pe, npes; shmemi_bootstrap_mpi_options_t mpi_options; - shmemi_bootstrap_tcp_options_t tcp_options; + shmemi_bootstrap_uid_options_t uid_options; // other options bool rdma_enabled; @@ -65,5 +82,6 @@ typedef struct { shmemi_transport_t choosen_transports[SHMEM_MAX_TRANSPORT_NUM]; int32_t num_choosen_transport; } shmemi_host_state_t; - +extern shmemi_bootstrap_handle_t g_boot_handle; +extern shmemi_host_state_t g_host_state; #endif // SHMEMI_HOST_TYPES_H \ No newline at end of file diff --git a/src/host/init/shmem_init_default.cpp b/src/host/init/shmem_init_default.cpp index efbc0a96..a17fcf09 100644 --- a/src/host/init/shmem_init_default.cpp +++ b/src/host/init/shmem_init_default.cpp @@ -52,7 +52,6 @@ constexpr int DEFAULT_BLOCK_NUM = 1; shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; shmemi_host_state_t g_state_host = {nullptr, DEFAULT_TEVENT, DEFAULT_BLOCK_NUM}; - shmem_init_attr_t g_attr; static smem_shm_t g_smem_handle = nullptr; static bool g_attr_init = false; @@ -86,7 +85,7 @@ int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) } // TODO: use shmem native heap init -int32_t shmemi_heap_init_v2(shmem_init_attr_t *attributes) { +int32_t shmemi_heap_init(shmem_init_attr_t *attributes) { // 申请Physical Mem, 申请Virtual Addr,MMAP // 连续映射, VA_size = rank * PA_size // 非连续映射, VA_size = PA_size @@ -198,53 +197,22 @@ void shmem_rank_exit(int status) exit(status); } -void shmem_init_attr(shmem_init_attr_t *attributes) { +int32_t shmem_init_attr(uint32_t bootstrap_flags, shmem_init_attr_t *attributes) { // namespace to be deleted using namespace shm; - - SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); - SHMEM_CHECK_RET(check_attr(attributes)); + if (!shm::g_attr_init) { + shmemi_bootstrap_attr_t attr = {}; + SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, &attr)); + } SHMEM_CHECK_RET(version_compatible()); SHMEM_CHECK_RET(shmemi_options_init()); - SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); - - shmemi_bootstrap_init(); - - shmemi_heap_init(attributes); - SHMEM_CHECK_RET(memory_manager_initialize(g_state.heap_base, g_state.heap_size)); - - shmemi_build_transport_map(); - - shmemi_transport_init(); - - SHMEM_CHECK_RET(shmemi_team_init(g_state.mype, g_state.npes)); - SHMEM_CHECK_RET(update_device_state()); - SHMEM_CHECK_RET(shmemi_sync_init()); - g_state.is_shmem_initialized = true; return g_boot_handle.barrier(&g_boot_handle); } -int32_t shmem_finalize_v2() -{ - SHMEM_CHECK_RET(shm::shmemi_team_finalize()); - - if (shm::g_smem_handle != nullptr) { - int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_destroy Failed"); - return SHMEM_SMEM_ERROR; - } - shm::g_smem_handle = nullptr; - } - smem_shm_uninit(0); - smem_uninit(); - return SHMEM_SUCCESS; -} - int32_t shmem_register_decrypt_handler(const shmem_decrypt_handler handler) { return smem_register_decrypt_handler(handler); @@ -285,17 +253,8 @@ int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32 int32_t shmem_finalize() { - SHMEM_CHECK_RET(shm::shmemi_team_finalize()); - if (shm::g_smem_handle != nullptr) { - int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_destroy Failed"); - return SHMEM_SMEM_ERROR; - } - shm::g_smem_handle = nullptr; - } - smem_shm_uninit(0); - smem_uninit(); + + shmemi_bootstrap_finalize(); return SHMEM_SUCCESS; } diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index b8f1e64c..19d69dee 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -1,6 +1,6 @@ #include "shmemi_host_common.h" -extern shmemi_host_state_t g_host_state; +shmemi_host_state_t g_host_state; int32_t shmemi_transport_init() { uint32_t num_choosen_transport = 0; diff --git a/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp index ffb8524a..62f03be0 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp @@ -1,31 +1,120 @@ -#ifdef SHMEM_BOOTSTRAP_MPI + +#include +#include +#include +#include +#include #include "shmemi_host_common.h" -#include "mpi.h" -typedef struct { - MPI_COMM comm; -} shmemi_bootstrap_mpi_state_t; +static MPI_Comm shmemi_bootstrap_comm = MPI_COMM_NULL; +static int shmem_initialized_mpi = 0; -static shmemi_bootstrap_mpi_state_t shmemi_bootstrap_mpi_state; +static int shmemi_bootstrap_mpi_barrier(shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS; -int shmemi_bootstrap_plugin_init(void *args, shmemi_bootstrap_handle_t *boot_handle) { - // INIT + status = MPI_Barrier(shmemi_bootstrap_comm); + SHMEM_CHECK_RET(status); - handle->allgather = shmemi_bootstrap_mpi_allgather; - handle->barrier = shmemi_bootstrap_mpi_allgather; - handle->finalize = shmemi_bootstrap_mpi_finalize; + return status; +} + +static int shmemi_bootstrap_mpi_allgather(const void *sendbuf, void *recvbuf, int length, + shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS; + + status = MPI_Allgather(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_comm); + SHMEM_CHECK_RET(status); + + return status; +} + +static int shmemi_bootstrap_mpi_alltoall(const void *sendbuf, void *recvbuf, int length, + shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS; + + status = MPI_Alltoall(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_comm); + SHMEM_CHECK_RET(status); + + return status; } -int shmemi_bootstrap_mpi_finalize(shmemi_bootstrap_handle_t *boot_handle) { +static void shmemi_bootstrap_mpi_global_exit(int status) { + int rc = MPI_SUCCESS; + rc = MPI_Abort(shmemi_bootstrap_comm, status); + if (rc != MPI_SUCCESS) { + exit(1); + } } -int shmemi_bootstrap_mpi_allgather(void *dst, void *src, size_t size, shmemi_bootstrap_handle_t *boot_handle) { +static int shmemi_bootstrap_mpi_finalize(shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS, finalized; + + status = MPI_Finalized(&finalized); + SHMEM_CHECK_RET(status); + + if (finalized) { + if (shmem_initialized_mpi) { + status = SHMEM_INNER_ERROR; + } else { + status = 0; + } + + return status; + } + if (!finalized && shmem_initialized_mpi) { + status = MPI_Comm_free(&shmemi_bootstrap_comm); + SHMEM_CHECK_RET(status); + } + + if (shmem_initialized_mpi) MPI_Finalize(); + + return status; } -int shmemi_bootstrap_mpi_barrier(shmemi_bootstrap_handle_t *boot_handle) { +int shmemi_bootstrap_plugin_init(void *mpi_comm, shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS, initialized = 0, finalized = 0; + MPI_Comm src_comm; + if (NULL == mpi_comm) + src_comm = MPI_COMM_WORLD; + else + src_comm = *((MPI_Comm *)mpi_comm); + status = MPI_Initialized(&initialized); + SHMEM_CHECK_RET(status); + status = MPI_Finalized(&finalized); + SHMEM_CHECK_RET(status); + if (!initialized && !finalized) { + MPI_Init(NULL, NULL); + shmem_initialized_mpi = 1; + if (src_comm != MPI_COMM_WORLD && src_comm != MPI_COMM_SELF) { + status = SHMEM_INNER_ERROR; + if (shmem_initialized_mpi) { + MPI_Finalize(); + shmem_initialized_mpi = 0; + } + } + } else if (finalized) { + status = SHMEM_INNER_ERROR; + if (shmem_initialized_mpi) { + MPI_Finalize(); + shmem_initialized_mpi = 0; + } + } + status = MPI_Comm_dup(src_comm, &shmemi_bootstrap_comm); + SHMEM_CHECK_RET(status); + status = MPI_Comm_rank(shmemi_bootstrap_comm, &handle->mype); + SHMEM_CHECK_RET(status); + status = MPI_Comm_size(shmemi_bootstrap_comm, &handle->npes); + SHMEM_CHECK_RET(status); + handle->allgather = shmemi_bootstrap_mpi_allgather; + handle->alltoall = shmemi_bootstrap_mpi_alltoall; + handle->barrier = shmemi_bootstrap_mpi_barrier; + handle->global_exit = shmemi_bootstrap_mpi_global_exit; + handle->finalize = shmemi_bootstrap_mpi_finalize; + handle->pre_init_ops = NULL; + handle->bootstrap_state = &shmemi_bootstrap_comm; + return status; } -#endif // SHMEM_BOOTSTRAP_MPI \ No newline at end of file -- Gitee From 4de2616e840e3558c5dd5f038ecef88b3a7d8686 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Thu, 18 Sep 2025 20:46:12 +0800 Subject: [PATCH 26/74] shmem init reconstruct mf backend version --- examples/rdma_demo/main.cpp | 2 +- examples/rdma_perftest/main.cpp | 2 +- src/host/common/shmemi_host_types.h | 2 +- src/host/init/init_impl/shmemi_init_base.h | 24 ++ src/host/init/init_impl/shmemi_init_mf.cpp | 117 ++++++++++ src/host/init/init_impl/shmemi_init_mf.h | 30 +++ .../init/init_impl/shmemi_init_normal.cpp | 26 +++ src/host/init/init_impl/shmemi_init_normal.h | 26 +++ src/host/init/shmem_init_default.cpp | 2 +- src/host/init/shmem_init_mf.cpp | 212 +++++++++--------- src/host/init/shmemi_init.h | 6 +- src/host/mem/shmem_mm.cpp | 6 +- src/host/mem/shmem_rma.cpp | 2 +- src/host/team/shmem_team.cpp | 4 +- src/host/transport/shmemi_transport.cpp | 80 +++---- 15 files changed, 384 insertions(+), 157 deletions(-) create mode 100644 src/host/init/init_impl/shmemi_init_base.h create mode 100644 src/host/init/init_impl/shmemi_init_mf.cpp create mode 100644 src/host/init/init_impl/shmemi_init_mf.h create mode 100644 src/host/init/init_impl/shmemi_init_normal.cpp create mode 100644 src/host/init/init_impl/shmemi_init_normal.h diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index bf1ff56e..6e0d222a 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -55,7 +55,7 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size // AllGather allgather_demo(1, stream, (uint8_t *)ptr, trans_size * sizeof(int32_t)); status = aclrtSynchronizeStream(stream); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); // 结果校验打印 int32_t *y_host; diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp index 62c39bea..d29537af 100644 --- a/examples/rdma_perftest/main.cpp +++ b/examples/rdma_perftest/main.cpp @@ -203,7 +203,7 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size inHost[i + (rank_id + n_ranks) * message_length / sizeof(int64_t)] = rank_id + 10 + iter; } aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); rdma_mte_put_bw_do(1, stream, fftsConfig, gva, message_length, iter); aclrtSynchronizeStream(stream); if (rank_id == 0 && iter >= 10) { diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index b9f1c528..d6d15cd3 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -45,7 +45,7 @@ typedef struct { int32_t pe, npes; shmemi_bootstrap_mpi_options_t mpi_options; - shmemi_bootstrap_tcp_options_t tcp_options; + shmemi_bootstrap_uid_options_t tcp_options; // other options bool rdma_enabled; diff --git a/src/host/init/init_impl/shmemi_init_base.h b/src/host/init/init_impl/shmemi_init_base.h new file mode 100644 index 00000000..dbea6b74 --- /dev/null +++ b/src/host/init/init_impl/shmemi_init_base.h @@ -0,0 +1,24 @@ +#ifndef SHMEMI_INIT_BASE_H +#define SHMEMI_INIT_BASE_H + +#include + +#include "acl/acl.h" +#include "internal/host_device/shmemi_types.h" + +class init_base { +public: + virtual int init_device_state() = 0; + virtual int update_device_state(void* host_ptr, size_t size) = 0; + virtual int heap_init(shmemi_device_host_state_t &g_state) = 0; + virtual int heap_finalize() = 0; + + virtual int barrier_all() = 0; + + virtual ~init_base() { + std::cout << "init_base destructor called. " << std::endl; + } + +}; + +#endif // SHMEMI_INIT_BASE_H \ No newline at end of file diff --git a/src/host/init/init_impl/shmemi_init_mf.cpp b/src/host/init/init_impl/shmemi_init_mf.cpp new file mode 100644 index 00000000..669b1e6c --- /dev/null +++ b/src/host/init/init_impl/shmemi_init_mf.cpp @@ -0,0 +1,117 @@ +#include "shmemi_init_mf.h" + +constexpr int DEFAULT_FLAG = 0; +constexpr int DEFAULT_ID = 0; +constexpr int DEFAULT_TIMEOUT = 120; +constexpr int DEFAULT_TEVENT = 0; +constexpr int DEFAULT_BLOCK_NUM = 1; + +// smem need +static smem_shm_t g_smem_handle = nullptr; +static char *g_ipport = nullptr; + +init_mf::init_mf(shmem_init_attr_t *attr, char *ipport) +{ + attributes = attr; + g_ipport = ipport; + + aclrtGetDevice(&device_id); + int32_t status = smem_init(DEFAULT_FLAG); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_init Failed"); + } +} + +int init_mf::init_device_state() +{ + int32_t status = SHMEM_SUCCESS; + smem_shm_config_t config; + status = smem_shm_config_init(&config); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_config_init Failed"); + return SHMEM_SMEM_ERROR; + } + + status = smem_shm_init(attributes->ip_port, attributes->n_ranks, attributes->my_rank, device_id, &config); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_init Failed"); + return SHMEM_SMEM_ERROR; + } + + config.shmInitTimeout = attributes->option_attr.shm_init_timeout; + config.shmCreateTimeout = attributes->option_attr.shm_create_timeout; + config.controlOperationTimeout = attributes->option_attr.control_operation_timeout; + + return SHMEM_SUCCESS; +} + +int init_mf::update_device_state(void* host_ptr, size_t size) +{ + if (g_smem_handle == nullptr) { + SHM_LOG_ERROR("smem_shm_create Not Success, update_device_state Failed"); + return SHMEM_SMEM_ERROR; + } + return smem_shm_set_extra_context(g_smem_handle, host_ptr, size); +} + +int init_mf::heap_init(shmemi_device_host_state_t &g_state) +{ + int32_t status = SHMEM_SUCCESS; + void *gva = nullptr; + g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state.heap_size, + static_cast(attributes->option_attr.data_op_engine_type), + DEFAULT_FLAG, &gva); + + if (g_smem_handle == nullptr || gva == nullptr) { + SHM_LOG_ERROR("smem_shm_create Failed"); + return SHMEM_SMEM_ERROR; + } + g_state.heap_base = (void *)((uintptr_t)gva + g_state.heap_size * attributes->my_rank); + uint32_t reach_info = 0; + for (int32_t i = 0; i < g_state.npes; i++) { + status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); + g_state.p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); + if (reach_info & SMEMS_DATA_OP_MTE) { + g_state.topo_list[i] |= SHMEM_TRANSPORT_MTE; + } + if (reach_info & SMEMS_DATA_OP_SDMA) { + g_state.sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); + } else { + g_state.sdma_heap_base[i] = NULL; + } + if (reach_info & SMEMS_DATA_OP_RDMA) { + g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; + } + } + if (g_ipport != nullptr) { + delete[] g_ipport; + g_ipport = nullptr; + attributes->ip_port = nullptr; + } else { + SHM_LOG_WARN("my_rank:" << attributes->my_rank << " g_ipport is released in advance!"); + attributes->ip_port = nullptr; + } + g_state.is_shmem_created = true; + return status; +} + +int init_mf::heap_finalize() +{ + if (g_smem_handle != nullptr) { + int32_t status = smem_shm_destroy(g_smem_handle, 0); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_destroy Failed"); + return SHMEM_SMEM_ERROR; + } + g_smem_handle = nullptr; + } + smem_shm_uninit(0); + smem_uninit(); + return SHMEM_SUCCESS; +} + +int init_mf::barrier_all() +{ + SHM_ASSERT_RETURN(g_smem_handle != nullptr, SHMEM_INVALID_PARAM); + return smem_shm_control_barrier(g_smem_handle); +} \ No newline at end of file diff --git a/src/host/init/init_impl/shmemi_init_mf.h b/src/host/init/init_impl/shmemi_init_mf.h new file mode 100644 index 00000000..d2266e72 --- /dev/null +++ b/src/host/init/init_impl/shmemi_init_mf.h @@ -0,0 +1,30 @@ +#ifndef SHMEMI_INIT_MF_H +#define SHMEMI_INIT_MF_H + +#include + +#include "shmemi_init_base.h" +#include "shmemi_host_common.h" +#include "internal/host_device/shmemi_types.h" + +class init_mf: public init_base { +public: + int init_device_state() override; + int update_device_state(void* host_ptr, size_t size) override; + int heap_init(shmemi_device_host_state_t &g_state) override; + int heap_finalize() override; + + int barrier_all() override; + + init_mf(shmem_init_attr_t *attr, char *ipport); + ~init_mf() { + std::cout << "init_mf destructor called. " << std::endl; + } +private: + int32_t device_id; + + shmem_init_attr_t *attributes; + char *g_ipport = nullptr; +}; + +#endif // SHMEMI_INIT_MF_H \ No newline at end of file diff --git a/src/host/init/init_impl/shmemi_init_normal.cpp b/src/host/init/init_impl/shmemi_init_normal.cpp new file mode 100644 index 00000000..fbbacb0a --- /dev/null +++ b/src/host/init/init_impl/shmemi_init_normal.cpp @@ -0,0 +1,26 @@ +#include "shmemi_init_normal.h" + +int init_normal::init_device_state() +{ + return SHMEM_SUCCESS; +} + +int init_normal::update_device_state(void* host_ptr, size_t size) +{ + return SHMEM_SUCCESS; +} + +int init_normal::heap_init(shmemi_device_host_state_t &g_state) +{ + return SHMEM_SUCCESS; +} + +int init_normal::heap_finalize() +{ + return SHMEM_SUCCESS; +} + +int init_normal::barrier_all() +{ + return SHMEM_SUCCESS; +} \ No newline at end of file diff --git a/src/host/init/init_impl/shmemi_init_normal.h b/src/host/init/init_impl/shmemi_init_normal.h new file mode 100644 index 00000000..46a98dca --- /dev/null +++ b/src/host/init/init_impl/shmemi_init_normal.h @@ -0,0 +1,26 @@ +#ifndef SHMEMI_INIT_NORMAL_H +#define SHMEMI_INIT_NORMAL_H + +#include + +#include "shmemi_init_base.h" +#include "shmemi_host_common.h" +#include "internal/host_device/shmemi_types.h" + + +class init_normal: public init_base { +public: + int init_device_state() override; + int update_device_state(void* host_ptr, size_t size) override; + int heap_init(shmemi_device_host_state_t &g_state) override; + int heap_finalize() override; + + int barrier_all() override; + + ~init_normal() { + std::cout << "init_normal destructor called. " << std::endl; + } + +}; + +#endif // SHMEMI_INIT_NORMAL_H \ No newline at end of file diff --git a/src/host/init/shmem_init_default.cpp b/src/host/init/shmem_init_default.cpp index efbc0a96..4e67cc07 100644 --- a/src/host/init/shmem_init_default.cpp +++ b/src/host/init/shmem_init_default.cpp @@ -210,7 +210,7 @@ void shmem_init_attr(shmem_init_attr_t *attributes) { SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); - shmemi_bootstrap_init(); + shmemi_bootstrap_init(); shmemi_heap_init(attributes); SHMEM_CHECK_RET(memory_manager_initialize(g_state.heap_base, g_state.heap_size)); diff --git a/src/host/init/shmem_init_mf.cpp b/src/host/init/shmem_init_mf.cpp index d0d73e5f..80b70653 100644 --- a/src/host/init/shmem_init_mf.cpp +++ b/src/host/init/shmem_init_mf.cpp @@ -53,8 +53,8 @@ constexpr int DEFAULT_BLOCK_NUM = 1; shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; shmemi_host_state_t g_state_host = {nullptr, DEFAULT_TEVENT, DEFAULT_BLOCK_NUM}; shmem_init_attr_t g_attr; -static smem_shm_t g_smem_handle = nullptr; -static bool g_attr_init = false; +// static smem_shm_t g_smem_handle = nullptr; +// static bool g_attr_init = false; static char *g_ipport = nullptr; int32_t version_compatible() @@ -84,84 +84,78 @@ int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) return status; } -int32_t shmemi_heap_init(shmem_init_attr_t *attributes) -{ - void *gva = nullptr; - int32_t status = SHMEM_SUCCESS; - int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); - - status = smem_init(DEFAULT_FLAG); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_init Failed"); - return SHMEM_SMEM_ERROR; - } - smem_shm_config_t config; - status = smem_shm_config_init(&config); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_config_init Failed"); - return SHMEM_SMEM_ERROR; - } - status = smem_shm_init(attributes->ip_port, attributes->n_ranks, attributes->my_rank, device_id, &config); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_init Failed"); - return SHMEM_SMEM_ERROR; - } - - config.shmInitTimeout = attributes->option_attr.shm_init_timeout; - config.shmCreateTimeout = attributes->option_attr.shm_create_timeout; - config.controlOperationTimeout = attributes->option_attr.control_operation_timeout; - - g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state.heap_size, - static_cast(attributes->option_attr.data_op_engine_type), - DEFAULT_FLAG, &gva); - - if (g_smem_handle == nullptr || gva == nullptr) { - SHM_LOG_ERROR("smem_shm_create Failed"); - return SHMEM_SMEM_ERROR; - } - g_state.heap_base = (void *)((uintptr_t)gva + g_state.heap_size * attributes->my_rank); - uint32_t reach_info = 0; - for (int32_t i = 0; i < g_state.npes; i++) { - status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); - g_state.p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); - if (reach_info & SMEMS_DATA_OP_MTE) { - g_state.topo_list[i] |= SHMEM_TRANSPORT_MTE; - } - if (reach_info & SMEMS_DATA_OP_SDMA) { - g_state.sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); - } else { - g_state.sdma_heap_base[i] = NULL; - } - if (reach_info & SMEMS_DATA_OP_RDMA) { - g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; - } - } - if (shm::g_ipport != nullptr) { - delete[] shm::g_ipport; - shm::g_ipport = nullptr; - attributes->ip_port = nullptr; - } else { - SHM_LOG_WARN("my_rank:" << attributes->my_rank << " shm::g_ipport is released in advance!"); - attributes->ip_port = nullptr; - } - g_state.is_shmem_created = true; - return status; -} - -int32_t shmemi_control_barrier_all() -{ - SHM_ASSERT_RETURN(g_smem_handle != nullptr, SHMEM_INVALID_PARAM); - return smem_shm_control_barrier(g_smem_handle); -} - -int32_t update_device_state() -{ - if (!g_state.is_shmem_created) { - return SHMEM_NOT_INITED; - } - return smem_shm_set_extra_context(g_smem_handle, (void *)&g_state, sizeof(shmemi_device_host_state_t)); -} +// int32_t shmemi_heap_init(shmem_init_attr_t *attributes) +// { +// void *gva = nullptr; +// int32_t status = SHMEM_SUCCESS; +// int32_t device_id; +// SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); + +// status = smem_init(DEFAULT_FLAG); +// if (status != SHMEM_SUCCESS) { +// SHM_LOG_ERROR("smem_init Failed"); +// return SHMEM_SMEM_ERROR; +// } +// smem_shm_config_t config; +// status = smem_shm_config_init(&config); +// if (status != SHMEM_SUCCESS) { +// SHM_LOG_ERROR("smem_shm_config_init Failed"); +// return SHMEM_SMEM_ERROR; +// } +// status = smem_shm_init(attributes->ip_port, attributes->n_ranks, attributes->my_rank, device_id, &config); +// if (status != SHMEM_SUCCESS) { +// SHM_LOG_ERROR("smem_shm_init Failed"); +// return SHMEM_SMEM_ERROR; +// } + +// config.shmInitTimeout = attributes->option_attr.shm_init_timeout; +// config.shmCreateTimeout = attributes->option_attr.shm_create_timeout; +// config.controlOperationTimeout = attributes->option_attr.control_operation_timeout; + +// g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state.heap_size, +// static_cast(attributes->option_attr.data_op_engine_type), +// DEFAULT_FLAG, &gva); + +// if (g_smem_handle == nullptr || gva == nullptr) { +// SHM_LOG_ERROR("smem_shm_create Failed"); +// return SHMEM_SMEM_ERROR; +// } +// g_state.heap_base = (void *)((uintptr_t)gva + g_state.heap_size * attributes->my_rank); +// uint32_t reach_info = 0; +// for (int32_t i = 0; i < g_state.npes; i++) { +// status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); +// g_state.p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); +// if (reach_info & SMEMS_DATA_OP_MTE) { +// g_state.topo_list[i] |= SHMEM_TRANSPORT_MTE; +// } +// if (reach_info & SMEMS_DATA_OP_SDMA) { +// g_state.sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); +// } else { +// g_state.sdma_heap_base[i] = NULL; +// } +// if (reach_info & SMEMS_DATA_OP_RDMA) { +// g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; +// } +// } +// if (shm::g_ipport != nullptr) { +// delete[] shm::g_ipport; +// shm::g_ipport = nullptr; +// attributes->ip_port = nullptr; +// } else { +// SHM_LOG_WARN("my_rank:" << attributes->my_rank << " shm::g_ipport is released in advance!"); +// attributes->ip_port = nullptr; +// } +// g_state.is_shmem_created = true; +// return status; +// } + +// int32_t update_device_state() +// { +// if (!g_state.is_shmem_created) { +// return SHMEM_NOT_INITED; +// } +// return smem_shm_set_extra_context(g_smem_handle, (void *)&g_state, sizeof(shmemi_device_host_state_t)); +// } int32_t check_attr(shmem_init_attr_t *attributes) { @@ -184,6 +178,20 @@ int32_t check_attr(shmem_init_attr_t *attributes) } // namespace shm +init_base* init_manager; + +extern shmemi_device_host_state_t shm::g_state; + +int32_t shmemi_control_barrier_all() +{ + return init_manager->barrier_all(); +} + +int32_t update_device_state() +{ + return init_manager->update_device_state((void *)&shm::g_state, sizeof(shmemi_device_host_state_t)); +} + int32_t shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, data_op_engine_type_t value) { attributes->option_attr.data_op_engine_type = value; @@ -233,7 +241,7 @@ int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size shm::g_attr.local_mem_size = local_mem_size; shm::g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT}; - shm::g_attr_init = true; + // shm::g_attr_init = true; return SHMEM_SUCCESS; } @@ -249,15 +257,16 @@ int32_t shmem_init_status() return SHMEM_STATUS_INVALID; } -void shmem_rank_exit(int status) -{ - SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); - exit(status); -} +// void shmem_rank_exit(int status) +// { +// SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); +// exit(status); +// } int32_t shmem_init_attr(shmem_init_attr_t *attributes) { int32_t ret; + init_manager = new init_mf(attributes, shm::g_ipport); SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); @@ -266,16 +275,18 @@ int32_t shmem_init_attr(shmem_init_attr_t *attributes) SHMEM_CHECK_RET(shm::shmemi_options_init()); SHMEM_CHECK_RET(shm::shmemi_state_init_attr(attributes)); - SHMEM_CHECK_RET(shm::shmemi_heap_init(attributes)); - SHMEM_CHECK_RET(shm::update_device_state()); + SHMEM_CHECK_RET(init_manager->init_device_state()); + SHMEM_CHECK_RET(init_manager->heap_init(shm::g_state)); + // SHMEM_CHECK_RET(shm::shmemi_heap_init(attributes)); + SHMEM_CHECK_RET(update_device_state()); SHMEM_CHECK_RET(shm::memory_manager_initialize(shm::g_state.heap_base, shm::g_state.heap_size)); SHMEM_CHECK_RET(shm::shmemi_team_init(shm::g_state.mype, shm::g_state.npes)); - SHMEM_CHECK_RET(shm::update_device_state()); + SHMEM_CHECK_RET(update_device_state()); SHMEM_CHECK_RET(shm::shmemi_sync_init()); - SHMEM_CHECK_RET(smem_shm_register_exit(shm::g_smem_handle, &shmem_rank_exit)); + // SHMEM_CHECK_RET(smem_shm_register_exit(shm::g_smem_handle, &shmem_rank_exit)); shm::g_state.is_shmem_initialized = true; - SHMEM_CHECK_RET(shm::shmemi_control_barrier_all()); + SHMEM_CHECK_RET(shmemi_control_barrier_all()); return SHMEM_SUCCESS; } @@ -320,16 +331,7 @@ int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32 int32_t shmem_finalize() { SHMEM_CHECK_RET(shm::shmemi_team_finalize()); - if (shm::g_smem_handle != nullptr) { - int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_destroy Failed"); - return SHMEM_SMEM_ERROR; - } - shm::g_smem_handle = nullptr; - } - smem_shm_uninit(0); - smem_uninit(); + SHMEM_CHECK_RET(init_manager->heap_finalize()); return SHMEM_SUCCESS; } @@ -353,7 +355,7 @@ void shmem_info_get_name(char *name) name[i] = '\0'; } -void shmem_global_exit(int status) -{ - smem_shm_global_exit(shm::g_smem_handle, status); -} +// void shmem_global_exit(int status) +// { +// smem_shm_global_exit(shm::g_smem_handle, status); +// } diff --git a/src/host/init/shmemi_init.h b/src/host/init/shmemi_init.h index 73613e25..7f461b9e 100644 --- a/src/host/init/shmemi_init.h +++ b/src/host/init/shmemi_init.h @@ -12,15 +12,17 @@ #include "stdint.h" #include "internal/host_device/shmemi_types.h" +#include "init/init_impl/shmemi_init_mf.h" +#include "init/init_impl/shmemi_init_mf.h" namespace shm { extern shmemi_device_host_state_t g_state; extern shmemi_host_state_t g_state_host; -int32_t update_device_state(void); - int32_t shmemi_control_barrier_all(); } // namespace shm +int32_t update_device_state(void); + #endif // SHMEMI_INIT_H diff --git a/src/host/mem/shmem_mm.cpp b/src/host/mem/shmem_mm.cpp index 40da342c..6455c624 100644 --- a/src/host/mem/shmem_mm.cpp +++ b/src/host/mem/shmem_mm.cpp @@ -41,7 +41,7 @@ void *shmem_malloc(size_t size) void *ptr = shm::shm_memory_heap->allocate(size); SHM_LOG_DEBUG("shmem_malloc(" << size << ")"); - auto ret = shm::shmemi_control_barrier_all(); + auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("malloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { @@ -71,7 +71,7 @@ void *shmem_calloc(size_t nmemb, size_t size) } } - auto ret = shm::shmemi_control_barrier_all(); + auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("calloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { @@ -92,7 +92,7 @@ void *shmem_align(size_t alignment, size_t size) } auto ptr = shm::shm_memory_heap->aligned_allocate(alignment, size); - auto ret = shm::shmemi_control_barrier_all(); + auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("shmem_align barrier failed, ret: " << ret); if (ptr != nullptr) { diff --git a/src/host/mem/shmem_rma.cpp b/src/host/mem/shmem_rma.cpp index d842dade..635012b6 100644 --- a/src/host/mem/shmem_rma.cpp +++ b/src/host/mem/shmem_rma.cpp @@ -45,7 +45,7 @@ int32_t shmem_mte_set_ub_params(uint64_t offset, uint32_t ub_size, uint32_t even shm::g_state.mte_config.shmem_ub = offset; shm::g_state.mte_config.ub_size = ub_size; shm::g_state.mte_config.event_id = event_id; - SHMEM_CHECK_RET(shm::update_device_state()); + SHMEM_CHECK_RET(update_device_state()); return SHMEM_SUCCESS; } diff --git a/src/host/team/shmem_team.cpp b/src/host/team/shmem_team.cpp index 066c2b24..1a56a713 100644 --- a/src/host/team/shmem_team.cpp +++ b/src/host/team/shmem_team.cpp @@ -259,7 +259,7 @@ int32_t shmem_team_split_strided(shmem_team_t parent_team, int32_t pe_start, int SHM_LOG_ERROR("create team failed, malloc device state failed!"); return SHMEM_INNER_ERROR; } - if (shm::update_device_state() != 0) { + if (update_device_state() != 0) { shmem_team_destroy(my_team.team_idx); SHM_LOG_ERROR("create team failed, update state failed!"); return SHMEM_INNER_ERROR; @@ -382,7 +382,7 @@ void shmem_team_destroy(shmem_team_t team) shm::device_team_destroy(team); shm::g_team_mask ^= 1ULL << team; - if (shm::update_device_state() != SHMEM_SUCCESS) { + if (update_device_state() != SHMEM_SUCCESS) { SHM_LOG_WARN("update state failed when destroy team!"); } } diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index b8f1e64c..c58e1b18 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -1,40 +1,40 @@ -#include "shmemi_host_common.h" - -extern shmemi_host_state_t g_host_state; - -int32_t shmemi_transport_init() { - uint32_t num_choosen_transport = 0; - -// #ifdef SHMEM_CONTINUOUS_ADDRESS_SPACE -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_c; -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_c; -// #else -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_d; -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_d; -// #endif - - g_host_state.num_choosen_transport = num_choosen_transport; - - for (int i = 0; i < num_choosen_transport; i++) { - auto t = g_host_state.choosen_transports + i; - t->boot_handle = g_host_state.boot_handle; - } -} - -int32_t shmemi_build_transport_map() { - // fill p2p/rdma/sdma heap bases -} - -int32_t shmemi_transport_setup_connections() { - for (int i = 0; i < g_host_state.num_choosen_transport; i++) { - auto t = g_host_state.choosen_transports + i; - t->connect_peers(t, nullptr, 0); - } -} - -int32_t shmemi_transport_finalize() { - for (int i = g_host_state.num_choosen_transport - 1; i >= 0; i--) { - auto t = g_host_state.choosen_transports + i; - t->finalize(t); - } -} \ No newline at end of file +// #include "shmemi_host_common.h" + +// extern shmemi_host_state_t g_host_state; + +// int32_t shmemi_transport_init() { +// uint32_t num_choosen_transport = 0; + +// // #ifdef SHMEM_CONTINUOUS_ADDRESS_SPACE +// // g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_c; +// // g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_c; +// // #else +// // g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_d; +// // g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_d; +// // #endif + +// g_host_state.num_choosen_transport = num_choosen_transport; + +// for (int i = 0; i < num_choosen_transport; i++) { +// auto t = g_host_state.choosen_transports + i; +// t->boot_handle = g_host_state.boot_handle; +// } +// } + +// int32_t shmemi_build_transport_map() { +// // fill p2p/rdma/sdma heap bases +// } + +// int32_t shmemi_transport_setup_connections() { +// for (int i = 0; i < g_host_state.num_choosen_transport; i++) { +// auto t = g_host_state.choosen_transports + i; +// t->connect_peers(t, nullptr, 0); +// } +// } + +// int32_t shmemi_transport_finalize() { +// for (int i = g_host_state.num_choosen_transport - 1; i >= 0; i--) { +// auto t = g_host_state.choosen_transports + i; +// t->finalize(t); +// } +// } \ No newline at end of file -- Gitee From 46d6537a9df2ea83213f06cda4b1114ff3b2d905 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Fri, 19 Sep 2025 19:12:11 +0800 Subject: [PATCH 27/74] rectification by review comments --- examples/hellow_word/main.cpp | 9 ++ include/host/shmem_host_init.h | 2 +- src/CMakeLists.txt | 4 +- src/host/bootstrap/shmemi_bootstrap.cpp | 123 +++++++----------- src/host/common/shmemi_host_types.h | 9 ++ src/host/init/shmem_init_default.cpp | 2 +- .../bootstrap/shmemi_bootstrap_mpi.cpp | 56 +++++--- .../bootstrap/shmemi_bootstrap_uid.cpp | 9 ++ src/modules/transport/shmemi_mte.cpp | 9 ++ src/modules/transport/shmemi_rdma.cpp | 9 ++ 10 files changed, 133 insertions(+), 99 deletions(-) diff --git a/examples/hellow_word/main.cpp b/examples/hellow_word/main.cpp index 3798c73f..8ecf2220 100644 --- a/examples/hellow_word/main.cpp +++ b/examples/hellow_word/main.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include #include #include diff --git a/include/host/shmem_host_init.h b/include/host/shmem_host_init.h index ff691974..52b1ba2d 100644 --- a/include/host/shmem_host_init.h +++ b/include/host/shmem_host_init.h @@ -69,7 +69,7 @@ SHMEM_HOST_API int shmem_set_timeout(shmem_init_attr_t *attributes, uint32_t val * @param attributes [in] Pointer to the user-defined attributes. * @return Returns 0 on success or an error code on failure */ -SHMEM_HOST_API int shmem_init_attr(uint32_t bootstrap_flags, shmem_init_attr_t *attributes); +SHMEM_HOST_API int shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *attributes); /** * @brief Register a decrypt key password handler. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f1c0f54e..46cb010e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -68,10 +68,8 @@ if(SHMEM_MPI_SUPPORT) target_link_libraries(shmem_bootstrap_mpi PRIVATE MPI::MPI_CXX) target_include_directories(shmem_bootstrap_mpi PRIVATE - ${PROJECT_SOURCE_DIR}/include/ + ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src/host - ${PROJECT_SOURCE_DIR}/src/device - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ ) set_target_properties(shmem_bootstrap_mpi PROPERTIES PREFIX "") install(TARGETS shmem_bootstrap_mpi diff --git a/src/host/bootstrap/shmemi_bootstrap.cpp b/src/host/bootstrap/shmemi_bootstrap.cpp index 6d8e3dce..63e64590 100644 --- a/src/host/bootstrap/shmemi_bootstrap.cpp +++ b/src/host/bootstrap/shmemi_bootstrap.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include "shmemi_host_common.h" #include "dlfcn.h" @@ -6,45 +15,11 @@ #define BOOTSTRAP_PLUGIN_INIT_FUNC "shmemi_bootstrap_plugin_init" -#define shmemxi_error_unlikely(x) __builtin_expect(!!(x), 0) - -#define SHMEMI_NULL_ERROR_JMP(var, status, err, ...) \ - do { \ - if (shmemxi_error_unlikely(var == NULL)) { \ - fprintf(stderr, "%s:%d: NULL value ", __FILE__, __LINE__); \ - fprintf(stderr, __VA_ARGS__); \ - fprintf(stderr, "\n"); \ - status = err; \ - } \ - } while (0) - -#define GET_SYMBOL(lib_handle, name, var, status) \ - do { \ - void **var_ptr = (void **)&(var); \ - void *tmp = (void *)dlsym(lib_handle, name); \ - SHMEMI_NULL_ERROR_JMP( \ - tmp, status, SHMEM_INNER_ERROR, "Bootstrap failed to get symbol '%s'\n\t%s\n", name, dlerror()); \ - *var_ptr = tmp; \ - } while (0) - shmemi_bootstrap_handle_t g_boot_handle; static void *plugin_hdl = nullptr; static char *plugin_name = nullptr; -void _bootstrap_loader_fini_helper(void *plugin_hdl, char *plugin_name) -{ - if (plugin_hdl != nullptr) { - dlclose(plugin_hdl); - plugin_hdl = nullptr; - } - - if (plugin_name != nullptr) { - free(plugin_name); - plugin_name = nullptr; - } -} - int bootstrap_loader_finalize(shmemi_bootstrap_handle_t *handle) { int status = handle->finalize(handle); @@ -60,66 +35,68 @@ int bootstrap_loader_finalize(shmemi_bootstrap_handle_t *handle) return 0; } -static int _bootstrap_loader_init_helper(const char *plugin, shmemi_bootstrap_handle_t *handle) +// for UID +int32_t shmemi_bootstrap_pre_init() { + +} + +void shmemi_bootstrap_loader() { dlerror(); - if (plugin_name == nullptr) { - plugin_name = strdup(plugin); - if (!plugin_name) { - SHM_LOG_ERROR("Failed to strdup plugin name, err is: " << stderr); - return SHMEM_INVALID_VALUE; - } - } - if (plugin_hdl == nullptr) { - plugin_hdl = dlopen(plugin, RTLD_NOW); - } - dlerror(); - if (!plugin_hdl) { - SHM_LOG_ERROR("Bootstrap unable to load " << plugin << ", err is: " << stderr); - _bootstrap_loader_fini_helper(plugin_hdl, plugin_name); - return SHMEM_INVALID_VALUE; + plugin_hdl = dlopen(plugin_name, RTLD_NOW); } - - return SHMEM_SUCCESS; + dlerror(); } -int bootstrap_loader_init(const char *plugin, void *arg, shmemi_bootstrap_handle_t *handle) +void shmemi_bootstrap_free() { - int status = _bootstrap_loader_init_helper(plugin, handle); - if (status != 0) { - SHM_LOG_ERROR("Bootstrap library dlopen failed for " << plugin); - _bootstrap_loader_fini_helper(plugin_hdl, plugin_name); - return SHMEM_INNER_ERROR; - } - int (*bootstrap_plugin_initops)(void *arg, shmemi_bootstrap_handle_t *handle); - GET_SYMBOL(plugin_hdl, BOOTSTRAP_PLUGIN_INIT_FUNC, bootstrap_plugin_initops, status); - status = bootstrap_plugin_initops(arg, handle); - if (status != 0) { - SHM_LOG_ERROR("Bootstrap plugin init failed for " << plugin); - _bootstrap_loader_fini_helper(plugin_hdl, plugin_name); - return SHMEM_INNER_ERROR; + if (plugin_hdl != nullptr) { + dlclose(plugin_hdl); + plugin_hdl = nullptr; } - return SHMEM_SUCCESS; -} -// for UID -int32_t shmemi_bootstrap_pre_init() { + if (plugin_name != nullptr) { + free(plugin_name); + plugin_name = nullptr; + } } int32_t shmemi_bootstrap_init(int flags, shmemi_bootstrap_attr_t *attr) { - int32_t status; + int32_t status = SHMEM_SUCCESS; + void *arg; if (flags & SHMEMX_INIT_WITH_MPI) { plugin_name = BOOTSTRAP_MODULE_MPI; - status = bootstrap_loader_init(plugin_name, (attr != NULL) ? attr->mpi_comm : NULL, &g_boot_handle); + arg = (attr != NULL) ? attr->mpi_comm : NULL; } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { plugin_name = BOOTSTRAP_MODULE_UID; - status = bootstrap_loader_init(plugin_name, (attr->uid_args), &g_boot_handle); + status = shmemi_bootstrap_pre_init(); } else { SHM_LOG_ERROR("Unknown Type for bootstrap"); status = SHMEM_INVALID_PARAM; } + shmemi_bootstrap_loader(); + + if (!plugin_hdl) { + SHM_LOG_ERROR("Bootstrap unable to load " << plugin_name << ", err is: " << stderr); + shmemi_bootstrap_free(); + return SHMEM_INVALID_VALUE; + } + + int (*plugin_init)(void *, shmemi_bootstrap_handle_t *); + *((void **)&plugin_init) = dlsym(plugin_hdl, BOOTSTRAP_PLUGIN_INIT_FUNC); + if (!plugin_init) { + SHM_LOG_ERROR("Bootstrap plugin init func dlsym failed"); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } + status = plugin_init(arg, &g_boot_handle); + if (status != 0) { + SHM_LOG_ERROR("Bootstrap plugin init failed for " << plugin_name); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } return status; } diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index c7cbc6d9..9619470d 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifndef SHMEMI_HOST_TYPES_H #define SHMEMI_HOST_TYPES_H diff --git a/src/host/init/shmem_init_default.cpp b/src/host/init/shmem_init_default.cpp index a17fcf09..39afda10 100644 --- a/src/host/init/shmem_init_default.cpp +++ b/src/host/init/shmem_init_default.cpp @@ -197,7 +197,7 @@ void shmem_rank_exit(int status) exit(status); } -int32_t shmem_init_attr(uint32_t bootstrap_flags, shmem_init_attr_t *attributes) { +int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *attributes) { // namespace to be deleted using namespace shm; SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); diff --git a/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp index 62f03be0..d0ebba3d 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp @@ -1,18 +1,32 @@ - +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include #include #include #include #include -#include "shmemi_host_common.h" +#include "host/shmem_host_def.h" +#include "common/shmemi_logger.h" +#include "common/shmemi_host_types.h" +#include "bootstrap/shmemi_bootstrap.h" -static MPI_Comm shmemi_bootstrap_comm = MPI_COMM_NULL; -static int shmem_initialized_mpi = 0; +typedef struct { + MPI_Comm comm; + int mpi_initialized; +} shmemi_bootstrap_mpi_state_t; +static shmemi_bootstrap_mpi_state_t shmemi_bootstrap_mpi_state = {MPI_COMM_NULL, 0}; static int shmemi_bootstrap_mpi_barrier(shmemi_bootstrap_handle_t *handle) { int status = MPI_SUCCESS; - status = MPI_Barrier(shmemi_bootstrap_comm); + status = MPI_Barrier(shmemi_bootstrap_mpi_state.comm); SHMEM_CHECK_RET(status); return status; @@ -22,7 +36,7 @@ static int shmemi_bootstrap_mpi_allgather(const void *sendbuf, void *recvbuf, in shmemi_bootstrap_handle_t *handle) { int status = MPI_SUCCESS; - status = MPI_Allgather(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_comm); + status = MPI_Allgather(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_mpi_state.comm); SHMEM_CHECK_RET(status); return status; @@ -32,7 +46,7 @@ static int shmemi_bootstrap_mpi_alltoall(const void *sendbuf, void *recvbuf, int shmemi_bootstrap_handle_t *handle) { int status = MPI_SUCCESS; - status = MPI_Alltoall(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_comm); + status = MPI_Alltoall(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_mpi_state.comm); SHMEM_CHECK_RET(status); return status; @@ -41,7 +55,7 @@ static int shmemi_bootstrap_mpi_alltoall(const void *sendbuf, void *recvbuf, int static void shmemi_bootstrap_mpi_global_exit(int status) { int rc = MPI_SUCCESS; - rc = MPI_Abort(shmemi_bootstrap_comm, status); + rc = MPI_Abort(shmemi_bootstrap_mpi_state.comm, status); if (rc != MPI_SUCCESS) { exit(1); } @@ -54,7 +68,7 @@ static int shmemi_bootstrap_mpi_finalize(shmemi_bootstrap_handle_t *handle) { SHMEM_CHECK_RET(status); if (finalized) { - if (shmem_initialized_mpi) { + if (shmemi_bootstrap_mpi_state.mpi_initialized) { status = SHMEM_INNER_ERROR; } else { status = 0; @@ -63,12 +77,12 @@ static int shmemi_bootstrap_mpi_finalize(shmemi_bootstrap_handle_t *handle) { return status; } - if (!finalized && shmem_initialized_mpi) { - status = MPI_Comm_free(&shmemi_bootstrap_comm); + if (!finalized && shmemi_bootstrap_mpi_state.mpi_initialized) { + status = MPI_Comm_free(&shmemi_bootstrap_mpi_state.comm); SHMEM_CHECK_RET(status); } - if (shmem_initialized_mpi) MPI_Finalize(); + if (shmemi_bootstrap_mpi_state.mpi_initialized) MPI_Finalize(); return status; } @@ -86,27 +100,27 @@ int shmemi_bootstrap_plugin_init(void *mpi_comm, shmemi_bootstrap_handle_t *hand SHMEM_CHECK_RET(status); if (!initialized && !finalized) { MPI_Init(NULL, NULL); - shmem_initialized_mpi = 1; + shmemi_bootstrap_mpi_state.mpi_initialized = 1; if (src_comm != MPI_COMM_WORLD && src_comm != MPI_COMM_SELF) { status = SHMEM_INNER_ERROR; - if (shmem_initialized_mpi) { + if (shmemi_bootstrap_mpi_state.mpi_initialized) { MPI_Finalize(); - shmem_initialized_mpi = 0; + shmemi_bootstrap_mpi_state.mpi_initialized = 0; } } } else if (finalized) { status = SHMEM_INNER_ERROR; - if (shmem_initialized_mpi) { + if (shmemi_bootstrap_mpi_state.mpi_initialized) { MPI_Finalize(); - shmem_initialized_mpi = 0; + shmemi_bootstrap_mpi_state.mpi_initialized = 0; } } - status = MPI_Comm_dup(src_comm, &shmemi_bootstrap_comm); + status = MPI_Comm_dup(src_comm, &shmemi_bootstrap_mpi_state.comm); SHMEM_CHECK_RET(status); - status = MPI_Comm_rank(shmemi_bootstrap_comm, &handle->mype); + status = MPI_Comm_rank(shmemi_bootstrap_mpi_state.comm, &handle->mype); SHMEM_CHECK_RET(status); - status = MPI_Comm_size(shmemi_bootstrap_comm, &handle->npes); + status = MPI_Comm_size(shmemi_bootstrap_mpi_state.comm, &handle->npes); SHMEM_CHECK_RET(status); handle->allgather = shmemi_bootstrap_mpi_allgather; handle->alltoall = shmemi_bootstrap_mpi_alltoall; @@ -114,7 +128,7 @@ int shmemi_bootstrap_plugin_init(void *mpi_comm, shmemi_bootstrap_handle_t *hand handle->global_exit = shmemi_bootstrap_mpi_global_exit; handle->finalize = shmemi_bootstrap_mpi_finalize; handle->pre_init_ops = NULL; - handle->bootstrap_state = &shmemi_bootstrap_comm; + handle->bootstrap_state = &shmemi_bootstrap_mpi_state.comm; return status; } diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp index 34ec3185..a3f6249e 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifdef SHMEM_BOOTSTRAP_UID #include "shmemi_bootstrap.h" diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index 492c2eaf..a8a6e629 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include "shmemi_transport.h" typedef struct { diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index b35be44a..2fa76509 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include "shmemi_transport.h" typedef struct { -- Gitee From ac23fad24dd7914e3428d4147cafc693327d24d8 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Mon, 22 Sep 2025 09:39:47 +0800 Subject: [PATCH 28/74] example cmake modify --- examples/hellow_word/CMakeLists.txt | 30 ++++------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/examples/hellow_word/CMakeLists.txt b/examples/hellow_word/CMakeLists.txt index 44b020af..954b552c 100644 --- a/examples/hellow_word/CMakeLists.txt +++ b/examples/hellow_word/CMakeLists.txt @@ -29,14 +29,7 @@ else() set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}) endif() -option(USE_UNIT_TEST "USE_UNIT_TEST" OFF) -option(USE_EXAMPLES "USE_EXAMPLES" OFF) -message(STATUS "USE_UNIT_TEST:${USE_UNIT_TEST}") -message(STATUS "USE_EXAMPLES:${USE_EXAMPLES}") -set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) - -set(CMAKE_COMPILER g++) -# set(CMAKE_C_COMPILER ${CMAKE_COMPILER}) +set(CMAKE_COMPILER bisheng) set(CMAKE_CXX_COMPILER ${CMAKE_COMPILER}) add_compile_options( @@ -61,22 +54,19 @@ include_directories( ${ASCEND_HOME_PATH}/include ${ASCEND_HOME_PATH}/include/experiment/runtime ${ASCEND_HOME_PATH}/include/experiment/msprof - ${ASCEND_DRIVER_PATH}/kernel/inc ) link_directories( ${ASCEND_HOME_PATH}/lib64 - ${ASCEND_DRIVER_PATH}/lib64/driver ) -link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase ascend_hal pthread) +link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread) find_package(MPI REQUIRED) include_directories( ${MPI_INCLUDE_PATH} ${ASCEND_HOME_PATH}/lib64 - ${ASCEND_DRIVER_PATH}/lib64/driver ) link_directories( @@ -86,21 +76,9 @@ link_directories( add_executable(helloword main.cpp) target_include_directories(helloword PRIVATE - ${ASCEND_DRIVER_PATH}/kernel/inc ${ASCEND_HOME_PATH}/include/ - ${CMAKE_CURRENT_SOURCE_DIR}/../../include - ${CMAKE_CURRENT_SOURCE_DIR}/../../install/memfabric_hybrid/include/smem/host/ - ${CMAKE_CURRENT_SOURCE_DIR}/../../install/memfabric_hybrid/include/smem/device/ + ${CMAKE_CURRENT_SOURCE_DIR}/../../install/shmem/include ) target_link_libraries(helloword PRIVATE MPI::MPI_CXX) -target_link_libraries(helloword PRIVATE shmem) - -target_compile_options(helloword PRIVATE ${MPI_CXX_COMPILE_FLAGS}) - -set_target_properties(helloword PROPERTIES - CXX_COMPILER "g++" -) -target_compile_options(helloword PRIVATE - "$<$:-g;-Wall>" -) \ No newline at end of file +target_link_libraries(helloword PRIVATE shmem) \ No newline at end of file -- Gitee From 79131aae76be297dced14e874dccb3533f9f4e48 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Tue, 23 Sep 2025 15:02:24 +0800 Subject: [PATCH 29/74] Add normal shmem init && enable multi backends --- CMakeLists.txt | 3 + examples/CMakeLists.txt | 29 +- examples/allgather/main.cpp | 24 +- examples/allgather/run.sh | 6 +- examples/allgather/scripts/data_statistic.py | 2 +- include/host/shmem_host_init.h | 40 -- .../internal/device/shmemi_base_copy_api.h | 113 ++++++ .../internal/device/shmemi_device_common.h | 24 +- .../device/sync/shmemi_device_barrier.h | 2 +- include/internal/host_device/shmemi_types.h | 5 + src/CMakeLists.txt | 36 +- .../mf}/shmemi_init_mf.cpp | 20 +- .../mf}/shmemi_init_mf.h | 24 +- .../normal/shmemi_init_normal.cpp | 70 ++++ .../normal}/shmemi_init_normal.h | 24 +- .../shmemi_init_base.h | 12 +- .../init/init_impl/shmemi_init_normal.cpp | 26 -- ...{shmem_init_default.cpp => shmem_init.cpp} | 175 +++------ src/host/init/shmem_init_mf.cpp | 361 ------------------ src/host/init/shmemi_init.h | 11 +- src/host/mem/shmemi_global_state.cpp | 32 ++ src/host/mem/shmemi_global_state.h | 29 ++ src/host/mem/shmemi_heap.cpp | 133 +++++++ src/host/mem/shmemi_heap.h | 62 +++ 24 files changed, 638 insertions(+), 625 deletions(-) create mode 100644 include/internal/device/shmemi_base_copy_api.h rename src/host/init/{init_impl => init_backends/mf}/shmemi_init_mf.cpp (91%) rename src/host/init/{init_impl => init_backends/mf}/shmemi_init_mf.h (64%) create mode 100644 src/host/init/init_backends/normal/shmemi_init_normal.cpp rename src/host/init/{init_impl => init_backends/normal}/shmemi_init_normal.h (53%) rename src/host/init/{init_impl => init_backends}/shmemi_init_base.h (67%) delete mode 100644 src/host/init/init_impl/shmemi_init_normal.cpp rename src/host/init/{shmem_init_default.cpp => shmem_init.cpp} (67%) delete mode 100644 src/host/init/shmem_init_mf.cpp create mode 100644 src/host/mem/shmemi_global_state.cpp create mode 100644 src/host/mem/shmemi_global_state.h create mode 100644 src/host/mem/shmemi_heap.cpp create mode 100644 src/host/mem/shmemi_heap.h diff --git a/CMakeLists.txt b/CMakeLists.txt index afc3c38b..f02c3046 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,6 +93,9 @@ link_directories( link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase ascend_hal pthread) +# MF_BACKEND +set(USE_MF "1") + # 添加子目录 add_subdirectory(src) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3262e525..80acb8cb 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -6,6 +6,8 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. +find_package(MPI REQUIRED) + function(shmem_add_fusion_example NAME) add_executable(${NAME} ${ARGN}) @@ -18,9 +20,16 @@ function(shmem_add_fusion_example NAME) ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} ${PROJECT_SOURCE_DIR}/examples/utils + ${MPI_INCLUDE_PATH} ) target_link_options(${NAME} PRIVATE --cce-fatobj-link) - target_link_libraries(${NAME} PRIVATE shmem ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + target_link_libraries(${NAME} PRIVATE shmem ${MPI_CXX_COMPILE_FLAGS} ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + target_compile_options(${NAME} PRIVATE ${MPI_CXX_COMPILE_FLAGS}) + + if ("${USE_MF}" STREQUAL "1") + target_compile_definitions(${NAME} PRIVATE BACKEND_MF=1) + endif() + endfunction() function(shmem_add_collective_example NAME) @@ -34,6 +43,9 @@ function(shmem_add_collective_example NAME) ${PROJECT_SOURCE_DIR}/examples/${NAME} ) target_link_options(${NAME}_kernel PRIVATE --cce-fatobj-link) + if ("${USE_MF}" STREQUAL "1") + target_compile_definitions(${NAME}_kernel PRIVATE BACKEND_MF=1) + endif() add_executable(${NAME} main.cpp) target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) @@ -44,15 +56,22 @@ function(shmem_add_collective_example NAME) ${PROJECT_SOURCE_DIR}/examples/${NAME} ${PROJECT_SOURCE_DIR}/examples/utils ${PROJECT_SOURCE_DIR}/src/host + ${MPI_INCLUDE_PATH} ) - target_link_libraries(${NAME} PRIVATE shmem ${NAME}_kernel ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + target_link_libraries(${NAME} PRIVATE shmem ${NAME}_kernel MPI::MPI_CXX ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + target_compile_options(${NAME} PRIVATE ${MPI_CXX_COMPILE_FLAGS}) + + if ("${USE_MF}" STREQUAL "1") + target_compile_definitions(${NAME} PRIVATE BACKEND_MF=1) + endif() + endfunction() foreach(EXAMPLE allgather - matmul_allreduce - rdma_perftest - rdma_demo + # matmul_allreduce + # rdma_perftest + # rdma_demo ) add_subdirectory(${EXAMPLE}) endforeach() \ No newline at end of file diff --git a/examples/allgather/main.cpp b/examples/allgather/main.cpp index e631ad9f..bbb42b84 100644 --- a/examples/allgather/main.cpp +++ b/examples/allgather/main.cpp @@ -30,14 +30,16 @@ using fp16_t = op::fp16_t; using bfloat16 = op::bfloat16; +#include + #include "acl/acl.h" #include "shmem_api.h" int g_npus = 8; -const char *ipport; +const char *ipport = "tcp://127.0.0.1:8998"; int f_rank = 0; int f_npu = 0; -const char *data_type; +const char *data_type = "int"; constexpr int64_t SYNC_FLAG_INTERVAL = 16; constexpr int64_t UB_DMA_MAX_SIZE = 190 * 1024; @@ -61,7 +63,7 @@ int test_shmem_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size) shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); // Prepare FFTS address uint64_t fftsAddr = shmemx_get_ffts_config(); @@ -172,16 +174,13 @@ int test_shmem_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size) int main(int argc, char *argv[]) { int status = 0; - int n_ranks = atoi(argv[1]); - int rank_id = atoi(argv[2]); - ipport = argv[3]; - g_npus = atoi(argv[4]); - f_rank = atoi(argv[5]); - f_npu = atoi(argv[6]); - data_type = argv[7]; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; - int32_t ret = shmem_set_conf_store_tls(false, nullptr, 0); - std::cout << "init shmem tls result:" << ret << std::endl; if (std::string(data_type) == "int") { status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); } else if (std::string(data_type) == "int32_t") { @@ -197,5 +196,6 @@ int main(int argc, char *argv[]) std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } \ No newline at end of file diff --git a/examples/allgather/run.sh b/examples/allgather/run.sh index 71b6fffd..061101a9 100644 --- a/examples/allgather/run.sh +++ b/examples/allgather/run.sh @@ -95,11 +95,7 @@ python3 ./scripts/data_gen.py $RANK_SIZE $TEST_TYPE # Kernel test rm -rf ./output -export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/install/memfabric_hybrid/lib/:${ASCEND_HOME_PATH}/lib64:$LD_LIBRARY_PATH -for (( idx =0; idx < ${GNPU_NUM}; idx = idx + 1 )); do - msprof --application="${PROJECT_ROOT}/build/bin/allgather $RANK_SIZE $idx $IPPORT $GNPU_NUM $FIRST_RANK $FIRST_NPU $TEST_TYPE" --output=${PROJECT_ROOT}/examples/allgather/output/ & -done -wait +mpirun -np ${GNPU_NUM} msprof --application="${PROJECT_ROOT}/build/bin/allgather" --output="${PROJECT_ROOT}/examples/allgather/output/" # Profiling data statistic python3 ./scripts/data_statistic.py diff --git a/examples/allgather/scripts/data_statistic.py b/examples/allgather/scripts/data_statistic.py index 2f7ca17a..a0619753 100644 --- a/examples/allgather/scripts/data_statistic.py +++ b/examples/allgather/scripts/data_statistic.py @@ -12,7 +12,7 @@ def open_input_file(input_file): return df def get_time_data(df, testLineNum: int): - df = df[df['kernel_type'] == "KERNEL_AIVEC"] + df = df[df['kernel_type'] == "AI_VECTOR_CORE"] df = df.reset_index(drop=True) time_data = [] total_rows = len(df) diff --git a/include/host/shmem_host_init.h b/include/host/shmem_host_init.h index c2ea76fd..5148f3b4 100644 --- a/include/host/shmem_host_init.h +++ b/include/host/shmem_host_init.h @@ -70,40 +70,6 @@ SHMEM_HOST_API int shmem_set_timeout(shmem_init_attr_t *attributes, uint32_t val */ SHMEM_HOST_API int shmem_init_attr(shmem_init_attr_t *attributes); -/** - * @brief Register a decrypt key password handler. - * - * @param decrypt_handler decrypt function pointer - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_register_decrypt_handler(const shmem_decrypt_handler decrypt_handler); - -/** - * @brief Set the log print function for the SHMEM library. - * - * @param func the logging function, takes level and msg as parameter - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)); - -/** - * @brief Set the logging level. - * - * @param level the logging level. 0-debug, 1-info, 2-warn, 3-error - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_set_log_level(int level); - -/** - * @brief Initialize the config store tls info. - * - * @param enable whether to enable tls - * @param tls_info the format describle in memfabric SECURITYNOTE.md, if disabled tls_info won't be use - * @param tls_info_len length of tls_info, if disabled tls_info_len won't be use - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len); - /** * @brief Release all resources used by the SHMEM library. * @@ -127,12 +93,6 @@ SHMEM_HOST_API void shmem_info_get_version(int *major, int *minor); */ SHMEM_HOST_API void shmem_info_get_name(char *name); -/** - * @brief exit all ranks. - * - * @param status [IN] name - */ -SHMEM_HOST_API void shmem_global_exit(int status); #ifdef __cplusplus } #endif diff --git a/include/internal/device/shmemi_base_copy_api.h b/include/internal/device/shmemi_base_copy_api.h new file mode 100644 index 00000000..18f9fa41 --- /dev/null +++ b/include/internal/device/shmemi_base_copy_api.h @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef __SHMEMI_BASE_COPY_H__ +#define __SHMEMI_BASE_COPY_H__ + +#include "kernel_operator.h" + +#define SMEM_SHM_INLINE_AICORE __attribute__((always_inline)) inline __aicore__ + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, + uint32_t size, bool enableL2 = true) +{ + ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); + + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); + if (!enableL2) { + gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + AscendC::DataCopyPad(gmTensor, ubTensor, dataCopyParams); +} + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(const AscendC::GlobalTensor &dstGva, + const AscendC::LocalTensor &srcUb, uint32_t size) +{ + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + AscendC::DataCopyPad(dstGva, srcUb, dataCopyParams); +} + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, + AscendC::DataCopyExtParams ©Params) +{ + ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); + + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); + AscendC::DataCopyPad(gmTensor, ubTensor, copyParams); +} + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(const AscendC::GlobalTensor &dstGva, + const AscendC::LocalTensor &srcUb, AscendC::DataCopyExtParams ©Params) +{ + AscendC::DataCopyPad(dstGva, srcUb, copyParams); +} + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, + uint32_t size, bool enableL2 = true) +{ + ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); + if (!enableL2) { + gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); +} + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(const AscendC::LocalTensor &dstUb, + const AscendC::GlobalTensor &srcGva, uint32_t size) +{ + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(dstUb, srcGva, dataCopyParams, padParams); +} + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, + AscendC::DataCopyExtParams ©Params) +{ + ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(ubTensor, gmTensor, copyParams, padParams); +} + +template +SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(const AscendC::LocalTensor &dstUb, + const AscendC::GlobalTensor &srcGva, AscendC::DataCopyExtParams ©Params) +{ + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(dstUb, srcGva, copyParams, padParams); +} + +#endif // __SHMEMI_BASE_COPY_H__ \ No newline at end of file diff --git a/include/internal/device/shmemi_device_common.h b/include/internal/device/shmemi_device_common.h index 63f1b980..fa7de721 100644 --- a/include/internal/device/shmemi_device_common.h +++ b/include/internal/device/shmemi_device_common.h @@ -13,13 +13,33 @@ #include "shmemi_device_arch.h" #include "shmemi_device_def.h" -#include "smem_shm_aicore_base_api.h" - constexpr int ub_limit = 192 * 1024; +#ifdef BACKEND_MF +#include "smem_shm_aicore_base_api.h" + SHMEM_DEVICE __gm__ shmemi_device_host_state_t *shmemi_get_state() { return reinterpret_cast<__gm__ shmemi_device_host_state_t *>(smem_shm_get_extra_context_addr()); } +#else +#include "shmemi_base_copy_api.h" + +// rdma +constexpr uint64_t SMEM_SHM_DEVICE_PRE_META_SIZE = 128UL; // 128B +constexpr uint64_t SMEM_SHM_DEVICE_GLOBAL_META_SIZE = SMEM_SHM_DEVICE_PRE_META_SIZE; // 128B +constexpr uint64_t SMEM_OBJECT_NUM_MAX = 511UL; // entity最大数量 +constexpr uint64_t SMEM_SHM_DEVICE_META_SIZE = SMEM_SHM_DEVICE_PRE_META_SIZE * SMEM_OBJECT_NUM_MAX + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE; // 64K + +constexpr uint64_t SMEM_SHM_DEVICE_USER_CONTEXT_PRE_SIZE = 64UL * 1024UL; // 64K +constexpr uint64_t SMEM_SHM_DEVICE_INFO_SIZE = SMEM_SHM_DEVICE_USER_CONTEXT_PRE_SIZE * SMEM_OBJECT_NUM_MAX + + SMEM_SHM_DEVICE_META_SIZE; // 元数据+用户context,总大小32M, 对齐2M +constexpr uint64_t SMEM_SHM_DEVICE_META_ADDR = SVM_END_ADDR - SMEM_SHM_DEVICE_INFO_SIZE; + +SHMEM_DEVICE __gm__ shmemi_device_host_state_t *shmemi_get_state() { + return reinterpret_cast<__gm__ shmemi_device_host_state_t *>((__gm__ void*)(SVM_END_ADDR - GLOBAL_STATE_SIZE)); +} +#endif SHMEM_DEVICE int shmemi_get_my_pe() { return shmemi_get_state()->mype; diff --git a/include/internal/device/sync/shmemi_device_barrier.h b/include/internal/device/sync/shmemi_device_barrier.h index 4215f8a3..969de915 100644 --- a/include/internal/device/sync/shmemi_device_barrier.h +++ b/include/internal/device/sync/shmemi_device_barrier.h @@ -273,7 +273,7 @@ SHMEM_DEVICE void shmemi_barrier_npu_v3(shmemi_team_t *team) } else { // read remote int remote_pe = start + i * stride; - shmemi_signal_wait_until_eq_for_barrier((__gm__ int32_t *)shmemi_ptr(sync_array, remote_pe), count); + shmemi_signal_wait_until_eq_for_barrier((__gm__ int32_t *)shmem_ptr(sync_array, remote_pe), count); } } diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index 2df1ac71..ec46ecb1 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -46,6 +46,11 @@ extern "C" { #define SHMEM_EXTRA_SIZE_UNALIGHED SYNC_POOL_SIZE #define SHMEM_EXTRA_SIZE ALIGH_TO(SHMEM_EXTRA_SIZE_UNALIGHED, SHMEM_PAGE_SIZE) +// global_state +constexpr uint64_t DEVMM_SVM_MEM_START = 0x100000000000ULL; +constexpr uint64_t SVM_END_ADDR = 0x100000000000ULL + 0x80000000000ULL - (1UL << 30UL); // svm end +constexpr uint64_t GLOBAL_STATE_SIZE = 512UL * 1024UL * 1024UL; // global_state fixed length + // synchronization typedef int32_t shmemi_sync_bit[SHMEMI_SYNCBIT_SIZE / sizeof(int32_t)]; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0a8e2b81..08147484 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,31 +16,51 @@ target_compile_options(shmem_device PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-a target_include_directories(shmem_device PUBLIC ${PROJECT_SOURCE_DIR}/include/ - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ ) +if ("${USE_MF}" STREQUAL "1") + target_compile_definitions(shmem_device PRIVATE BACKEND_MF=1) + target_include_directories(shmem_device + PUBLIC + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ + ) +endif() + file(GLOB_RECURSE SHMEM_HOST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/host/*.cpp) list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "python_wrapper") list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "modules") -list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmem_init_default.cpp") +if ("${USE_MF}" STREQUAL "1") + list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmemi_init_normal.cpp") +else() + list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmemi_init_mf.cpp") +endif() add_library(shmem_host OBJECT ${SHMEM_HOST_FILES}) target_compile_options(shmem_host PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(shmem_host PUBLIC ${PROJECT_SOURCE_DIR}/include/ - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ ${PROJECT_SOURCE_DIR}/src/host ${PROJECT_SOURCE_DIR}/src/device ) +if ("${USE_MF}" STREQUAL "1") + target_compile_definitions(shmem_host PRIVATE BACKEND_MF=1) + target_include_directories(shmem_host + PUBLIC + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ + ) +endif() + add_library(shmem SHARED $ $) target_link_options(shmem PRIVATE --cce-fatobj-link) -target_link_libraries(shmem - PUBLIC - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_hybm_core.so -) +if ("${USE_MF}" STREQUAL "1") + target_link_libraries(shmem + PUBLIC + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_hybm_core.so + ) +endif() # 安装配置 install(TARGETS shmem diff --git a/src/host/init/init_impl/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp similarity index 91% rename from src/host/init/init_impl/shmemi_init_mf.cpp rename to src/host/init/init_backends/mf/shmemi_init_mf.cpp index 669b1e6c..f2ef068e 100644 --- a/src/host/init/init_impl/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -16,12 +16,20 @@ init_mf::init_mf(shmem_init_attr_t *attr, char *ipport) g_ipport = ipport; aclrtGetDevice(&device_id); + smem_set_conf_store_tls(false, nullptr, 0); + int32_t status = smem_init(DEFAULT_FLAG); if (status != SHMEM_SUCCESS) { SHM_LOG_ERROR("smem_init Failed"); } } +init_mf::~init_mf() +{ + finalize_device_state(); + heap_finalize(); +} + int init_mf::init_device_state() { int32_t status = SHMEM_SUCCESS; @@ -54,6 +62,12 @@ int init_mf::update_device_state(void* host_ptr, size_t size) return smem_shm_set_extra_context(g_smem_handle, host_ptr, size); } +int init_mf::finalize_device_state() +{ + // dummy function + return SHMEM_SUCCESS; +} + int init_mf::heap_init(shmemi_device_host_state_t &g_state) { int32_t status = SHMEM_SUCCESS; @@ -108,10 +122,4 @@ int init_mf::heap_finalize() smem_shm_uninit(0); smem_uninit(); return SHMEM_SUCCESS; -} - -int init_mf::barrier_all() -{ - SHM_ASSERT_RETURN(g_smem_handle != nullptr, SHMEM_INVALID_PARAM); - return smem_shm_control_barrier(g_smem_handle); } \ No newline at end of file diff --git a/src/host/init/init_impl/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h similarity index 64% rename from src/host/init/init_impl/shmemi_init_mf.h rename to src/host/init/init_backends/mf/shmemi_init_mf.h index d2266e72..f163f3c8 100644 --- a/src/host/init/init_impl/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -1,25 +1,33 @@ + #ifndef SHMEMI_INIT_MF_H #define SHMEMI_INIT_MF_H #include -#include "shmemi_init_base.h" +#include "init/init_backends/shmemi_init_base.h" #include "shmemi_host_common.h" #include "internal/host_device/shmemi_types.h" +// smem api +#include +#include +#include +#include +#include +#include +#include + class init_mf: public init_base { public: + init_mf(shmem_init_attr_t *attr, char *ipport); + ~init_mf(); + int init_device_state() override; + int finalize_device_state() override; int update_device_state(void* host_ptr, size_t size) override; + int heap_init(shmemi_device_host_state_t &g_state) override; int heap_finalize() override; - - int barrier_all() override; - - init_mf(shmem_init_attr_t *attr, char *ipport); - ~init_mf() { - std::cout << "init_mf destructor called. " << std::endl; - } private: int32_t device_id; diff --git a/src/host/init/init_backends/normal/shmemi_init_normal.cpp b/src/host/init/init_backends/normal/shmemi_init_normal.cpp new file mode 100644 index 00000000..fc42ad50 --- /dev/null +++ b/src/host/init/init_backends/normal/shmemi_init_normal.cpp @@ -0,0 +1,70 @@ +#include "shmemi_init_normal.h" + +global_state_reigister *global_state_d = nullptr; +shmem_symmetric_heap *heap_obj = nullptr; + +init_normal::init_normal(shmem_init_attr_t *attr) +{ + mype = attr->my_rank; + npes = attr->n_ranks; + + // EnablePeerAccess + for (int i = 0; i < npes; i++) { + if (i != mype) { + aclrtDeviceEnablePeerAccess(i, 0); + } + } +} + +init_normal::~init_normal() +{ + finalize_device_state(); + heap_finalize(); +} + +int init_normal::init_device_state() +{ + global_state_d = new global_state_reigister(mype); + return SHMEM_SUCCESS; +} + +int init_normal::finalize_device_state() +{ + delete global_state_d; + return SHMEM_SUCCESS; +} + +int init_normal::update_device_state(void* host_ptr, size_t size) +{ + int32_t status = SHMEM_SUCCESS; + status = aclrtMemcpy(global_state_d->get_ptr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE); + return status; +} + +int init_normal::heap_init(shmemi_device_host_state_t &g_state) +{ + heap_obj = new shmem_symmetric_heap(mype, npes); + + heap_obj->reserve_heap(g_state.heap_size); + + heap_obj->setup_heap(); + + // Assign host_global_state + g_state.heap_base = heap_obj->get_heap_base(); + for (int32_t i = 0; i < g_state.npes; i++) { + g_state.p2p_heap_base[i] = heap_obj->get_peer_heap_base_p2p(i); + } + g_state.is_shmem_created = true; + return SHMEM_SUCCESS; +} + +int init_normal::heap_finalize() +{ + int32_t status = SHMEM_SUCCESS; + + heap_obj->remove_heap(); + heap_obj->unreserve_heap(); + + delete heap_obj; + return status; +} \ No newline at end of file diff --git a/src/host/init/init_impl/shmemi_init_normal.h b/src/host/init/init_backends/normal/shmemi_init_normal.h similarity index 53% rename from src/host/init/init_impl/shmemi_init_normal.h rename to src/host/init/init_backends/normal/shmemi_init_normal.h index 46a98dca..08269270 100644 --- a/src/host/init/init_impl/shmemi_init_normal.h +++ b/src/host/init/init_backends/normal/shmemi_init_normal.h @@ -1,25 +1,31 @@ + #ifndef SHMEMI_INIT_NORMAL_H #define SHMEMI_INIT_NORMAL_H #include -#include "shmemi_init_base.h" -#include "shmemi_host_common.h" -#include "internal/host_device/shmemi_types.h" +#include "init/init_backends/shmemi_init_base.h" +#include "host/shmem_host_def.h" +#include "mem/shmemi_global_state.h" +#include "mem/shmemi_heap.h" +#include "bootstrap/shmemi_bootstrap.h" +#include "internal/host_device/shmemi_types.h" class init_normal: public init_base { public: + init_normal(shmem_init_attr_t *attr); + ~init_normal(); + int init_device_state() override; + int finalize_device_state() override; int update_device_state(void* host_ptr, size_t size) override; + int heap_init(shmemi_device_host_state_t &g_state) override; int heap_finalize() override; - - int barrier_all() override; - - ~init_normal() { - std::cout << "init_normal destructor called. " << std::endl; - } +private: + int mype; + int npes; }; diff --git a/src/host/init/init_impl/shmemi_init_base.h b/src/host/init/init_backends/shmemi_init_base.h similarity index 67% rename from src/host/init/init_impl/shmemi_init_base.h rename to src/host/init/init_backends/shmemi_init_base.h index dbea6b74..887aeb49 100644 --- a/src/host/init/init_impl/shmemi_init_base.h +++ b/src/host/init/init_backends/shmemi_init_base.h @@ -9,16 +9,12 @@ class init_base { public: virtual int init_device_state() = 0; + virtual int finalize_device_state() = 0; virtual int update_device_state(void* host_ptr, size_t size) = 0; + virtual int heap_init(shmemi_device_host_state_t &g_state) = 0; virtual int heap_finalize() = 0; - virtual int barrier_all() = 0; - - virtual ~init_base() { - std::cout << "init_base destructor called. " << std::endl; - } - -}; + virtual ~init_base() {} -#endif // SHMEMI_INIT_BASE_H \ No newline at end of file +}; \ No newline at end of file diff --git a/src/host/init/init_impl/shmemi_init_normal.cpp b/src/host/init/init_impl/shmemi_init_normal.cpp deleted file mode 100644 index fbbacb0a..00000000 --- a/src/host/init/init_impl/shmemi_init_normal.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "shmemi_init_normal.h" - -int init_normal::init_device_state() -{ - return SHMEM_SUCCESS; -} - -int init_normal::update_device_state(void* host_ptr, size_t size) -{ - return SHMEM_SUCCESS; -} - -int init_normal::heap_init(shmemi_device_host_state_t &g_state) -{ - return SHMEM_SUCCESS; -} - -int init_normal::heap_finalize() -{ - return SHMEM_SUCCESS; -} - -int init_normal::barrier_all() -{ - return SHMEM_SUCCESS; -} \ No newline at end of file diff --git a/src/host/init/shmem_init_default.cpp b/src/host/init/shmem_init.cpp similarity index 67% rename from src/host/init/shmem_init_default.cpp rename to src/host/init/shmem_init.cpp index 4e67cc07..26ecd3c3 100644 --- a/src/host/init/shmem_init_default.cpp +++ b/src/host/init/shmem_init.cpp @@ -52,10 +52,8 @@ constexpr int DEFAULT_BLOCK_NUM = 1; shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; shmemi_host_state_t g_state_host = {nullptr, DEFAULT_TEVENT, DEFAULT_BLOCK_NUM}; - shmem_init_attr_t g_attr; -static smem_shm_t g_smem_handle = nullptr; -static bool g_attr_init = false; + static char *g_ipport = nullptr; int32_t version_compatible() @@ -85,27 +83,6 @@ int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) return status; } -// TODO: use shmem native heap init -int32_t shmemi_heap_init_v2(shmem_init_attr_t *attributes) { - // 申请Physical Mem, 申请Virtual Addr,MMAP - // 连续映射, VA_size = rank * PA_size - // 非连续映射, VA_size = PA_size - - return 0; -} - -// TODO: use shmem native barrier -int32_t shmemi_control_barrier_all() -{ - return g_boot_handle.barrier(&g_boot_handle); -} - -// TODO: use shmem native global state -int32_t update_device_state() -{ - return 0; -} - int32_t check_attr(shmem_init_attr_t *attributes) { if ((attributes->my_rank < 0) || (attributes->n_ranks <= 0)) { @@ -127,6 +104,18 @@ int32_t check_attr(shmem_init_attr_t *attributes) } // namespace shm +init_base* init_manager; + +int32_t shmemi_control_barrier_all() +{ + return g_boot_handle.barrier(&g_boot_handle); +} + +int32_t update_device_state() +{ + return init_manager->update_device_state((void *)&shm::g_state, sizeof(shmemi_device_host_state_t)); +} + int32_t shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, data_op_engine_type_t value) { attributes->option_attr.data_op_engine_type = value; @@ -153,11 +142,6 @@ int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size return SHMEM_INVALID_PARAM; } // 安全警告:此处strlen依赖ip_port以null结尾,如果ip_port不是合法的C字符串,将导致越界读取 - if (ip_port == nullptr) { - SHM_LOG_ERROR("my_rank:" << my_rank << " ip_port is NULL!"); - return SHMEM_INVALID_PARAM; - } - // 安全警告:此处strlen依赖ip_port以null结尾,如果ip_port不是合法的C字符串,将导致越界读取 size_t ip_len = strlen(ip_port); shm::g_ipport = new (std::nothrow) char[ip_len + 1]; if (shm::g_ipport == nullptr) { @@ -176,7 +160,7 @@ int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size shm::g_attr.local_mem_size = local_mem_size; shm::g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT}; - shm::g_attr_init = true; + // shm::g_attr_init = true; return SHMEM_SUCCESS; } @@ -192,110 +176,48 @@ int32_t shmem_init_status() return SHMEM_STATUS_INVALID; } -void shmem_rank_exit(int status) +int32_t shmem_init_attr(shmem_init_attr_t *attributes) { - SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); - exit(status); -} + int32_t ret; -void shmem_init_attr(shmem_init_attr_t *attributes) { - // namespace to be deleted - using namespace shm; - + // config init SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); - SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); - SHMEM_CHECK_RET(check_attr(attributes)); - SHMEM_CHECK_RET(version_compatible()); - SHMEM_CHECK_RET(shmemi_options_init()); - - SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); - - shmemi_bootstrap_init(); - - shmemi_heap_init(attributes); - SHMEM_CHECK_RET(memory_manager_initialize(g_state.heap_base, g_state.heap_size)); - - shmemi_build_transport_map(); - - shmemi_transport_init(); - - SHMEM_CHECK_RET(shmemi_team_init(g_state.mype, g_state.npes)); + // SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); + SHMEM_CHECK_RET(shm::check_attr(attributes)); + SHMEM_CHECK_RET(shm::version_compatible()); + SHMEM_CHECK_RET(shm::shmemi_options_init()); + + // bootstrap init + shmemi_bootstrap_attr_t attr = {}; + SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, &attr)); + + // shmem basic init +#ifdef BACKEND_MF + init_manager = new init_mf(attributes, shm::g_ipport); +#else + init_manager = new init_normal(attributes); +#endif + SHMEM_CHECK_RET(shm::shmemi_state_init_attr(attributes)); + SHMEM_CHECK_RET(init_manager->init_device_state()); + SHMEM_CHECK_RET(init_manager->heap_init(shm::g_state)); SHMEM_CHECK_RET(update_device_state()); - SHMEM_CHECK_RET(shmemi_sync_init()); - g_state.is_shmem_initialized = true; - - return g_boot_handle.barrier(&g_boot_handle); -} - -int32_t shmem_finalize_v2() -{ - SHMEM_CHECK_RET(shm::shmemi_team_finalize()); - - if (shm::g_smem_handle != nullptr) { - int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_destroy Failed"); - return SHMEM_SMEM_ERROR; - } - shm::g_smem_handle = nullptr; - } - smem_shm_uninit(0); - smem_uninit(); + // shmem submodules init + SHMEM_CHECK_RET(shm::memory_manager_initialize(shm::g_state.heap_base, shm::g_state.heap_size)); + SHMEM_CHECK_RET(shm::shmemi_team_init(shm::g_state.mype, shm::g_state.npes)); + SHMEM_CHECK_RET(update_device_state()); + SHMEM_CHECK_RET(shm::shmemi_sync_init()); + // SHMEM_CHECK_RET(smem_shm_register_exit(shm::g_smem_handle, &shmem_rank_exit)); + shm::g_state.is_shmem_initialized = true; + SHMEM_CHECK_RET(shmemi_control_barrier_all()); return SHMEM_SUCCESS; } -int32_t shmem_register_decrypt_handler(const shmem_decrypt_handler handler) -{ - return smem_register_decrypt_handler(handler); -} - -int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)) -{ - shm::shm_out_logger::Instance().set_extern_log_func(func, true); - return smem_set_extern_logger(func); -} - -int32_t shmem_set_log_level(int level) -{ - // use env first, input level secondly, user may change level from env instead call func - const char *in_level = std::getenv("SHMEM_LOG_LEVEL"); - if (in_level != nullptr) { - auto tmp_level = std::string(in_level); - if (tmp_level == "DEBUG") { - level = shm::DEBUG_LEVEL; - } else if (tmp_level == "INFO") { - level = shm::INFO_LEVEL; - } else if (tmp_level == "WARN") { - level = shm::WARN_LEVEL; - } else if (tmp_level == "ERROR") { - level = shm::ERROR_LEVEL; - } else if (tmp_level == "FATAL") { - level = shm::FATAL_LEVEL; - } - } - shm::shm_out_logger::Instance().set_log_level(static_cast(level)); - return smem_set_log_level(level); -} - -int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len) -{ - return smem_set_conf_store_tls(enable, tls_info, tls_info_len); -} - int32_t shmem_finalize() { SHMEM_CHECK_RET(shm::shmemi_team_finalize()); - if (shm::g_smem_handle != nullptr) { - int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_destroy Failed"); - return SHMEM_SMEM_ERROR; - } - shm::g_smem_handle = nullptr; - } - smem_shm_uninit(0); - smem_uninit(); + delete init_manager; + shmemi_bootstrap_finalize(); return SHMEM_SUCCESS; } @@ -317,9 +239,4 @@ void shmem_info_get_name(char *name) name[i] = version_str[i]; } name[i] = '\0'; -} - -void shmem_global_exit(int status) -{ - smem_shm_global_exit(shm::g_smem_handle, status); -} +} \ No newline at end of file diff --git a/src/host/init/shmem_init_mf.cpp b/src/host/init/shmem_init_mf.cpp deleted file mode 100644 index 80b70653..00000000 --- a/src/host/init/shmem_init_mf.cpp +++ /dev/null @@ -1,361 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include -#include -#include -#include "acl/acl.h" -#include "shmemi_host_common.h" - -using namespace std; - -namespace shm { - -#define DEFAULT_MY_PE (-1) -#define DEFAULT_N_PES (-1) - -constexpr int DEFAULT_FLAG = 0; -constexpr int DEFAULT_ID = 0; -constexpr int DEFAULT_TIMEOUT = 120; -constexpr int DEFAULT_TEVENT = 0; -constexpr int DEFAULT_BLOCK_NUM = 1; - -// initializer -#define SHMEM_DEVICE_HOST_STATE_INITIALIZER \ - { \ - (1 << 16) + sizeof(shmemi_device_host_state_t), /* version */ \ - (DEFAULT_MY_PE), /* mype */ \ - (DEFAULT_N_PES), /* npes */ \ - NULL, /* heap_base */ \ - {NULL}, /* p2p_heap_base */ \ - {NULL}, /* sdma_heap_base */ \ - {}, /* topo_list */ \ - SIZE_MAX, /* heap_size */ \ - {NULL}, /* team_pools */ \ - 0, /* sync_pool */ \ - 0, /* sync_counter */ \ - 0, /* core_sync_pool */ \ - 0, /* core_sync_counter */ \ - false, /* shmem_is_shmem_initialized */ \ - false, /* shmem_is_shmem_created */ \ - {0, 16 * 1024, 0}, /* shmem_mte_config */ \ - } - -shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; -shmemi_host_state_t g_state_host = {nullptr, DEFAULT_TEVENT, DEFAULT_BLOCK_NUM}; -shmem_init_attr_t g_attr; -// static smem_shm_t g_smem_handle = nullptr; -// static bool g_attr_init = false; -static char *g_ipport = nullptr; - -int32_t version_compatible() -{ - int32_t status = SHMEM_SUCCESS; - return status; -} - -int32_t shmemi_options_init() -{ - int32_t status = SHMEM_SUCCESS; - return status; -} - -int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) -{ - int32_t status = SHMEM_SUCCESS; - g_state.mype = attributes->my_rank; - g_state.npes = attributes->n_ranks; - g_state.heap_size = attributes->local_mem_size + SHMEM_EXTRA_SIZE; - - aclrtStream stream = nullptr; - SHMEM_CHECK_RET(aclrtCreateStream(&stream)); - g_state_host.default_stream = stream; - g_state_host.default_event_id = DEFAULT_TEVENT; - g_state_host.default_block_num = DEFAULT_BLOCK_NUM; - return status; -} - -// int32_t shmemi_heap_init(shmem_init_attr_t *attributes) -// { -// void *gva = nullptr; -// int32_t status = SHMEM_SUCCESS; -// int32_t device_id; -// SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); - -// status = smem_init(DEFAULT_FLAG); -// if (status != SHMEM_SUCCESS) { -// SHM_LOG_ERROR("smem_init Failed"); -// return SHMEM_SMEM_ERROR; -// } -// smem_shm_config_t config; -// status = smem_shm_config_init(&config); -// if (status != SHMEM_SUCCESS) { -// SHM_LOG_ERROR("smem_shm_config_init Failed"); -// return SHMEM_SMEM_ERROR; -// } -// status = smem_shm_init(attributes->ip_port, attributes->n_ranks, attributes->my_rank, device_id, &config); -// if (status != SHMEM_SUCCESS) { -// SHM_LOG_ERROR("smem_shm_init Failed"); -// return SHMEM_SMEM_ERROR; -// } - -// config.shmInitTimeout = attributes->option_attr.shm_init_timeout; -// config.shmCreateTimeout = attributes->option_attr.shm_create_timeout; -// config.controlOperationTimeout = attributes->option_attr.control_operation_timeout; - -// g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state.heap_size, -// static_cast(attributes->option_attr.data_op_engine_type), -// DEFAULT_FLAG, &gva); - -// if (g_smem_handle == nullptr || gva == nullptr) { -// SHM_LOG_ERROR("smem_shm_create Failed"); -// return SHMEM_SMEM_ERROR; -// } -// g_state.heap_base = (void *)((uintptr_t)gva + g_state.heap_size * attributes->my_rank); -// uint32_t reach_info = 0; -// for (int32_t i = 0; i < g_state.npes; i++) { -// status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); -// g_state.p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); -// if (reach_info & SMEMS_DATA_OP_MTE) { -// g_state.topo_list[i] |= SHMEM_TRANSPORT_MTE; -// } -// if (reach_info & SMEMS_DATA_OP_SDMA) { -// g_state.sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); -// } else { -// g_state.sdma_heap_base[i] = NULL; -// } -// if (reach_info & SMEMS_DATA_OP_RDMA) { -// g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; -// } -// } -// if (shm::g_ipport != nullptr) { -// delete[] shm::g_ipport; -// shm::g_ipport = nullptr; -// attributes->ip_port = nullptr; -// } else { -// SHM_LOG_WARN("my_rank:" << attributes->my_rank << " shm::g_ipport is released in advance!"); -// attributes->ip_port = nullptr; -// } -// g_state.is_shmem_created = true; -// return status; -// } - -// int32_t update_device_state() -// { -// if (!g_state.is_shmem_created) { -// return SHMEM_NOT_INITED; -// } -// return smem_shm_set_extra_context(g_smem_handle, (void *)&g_state, sizeof(shmemi_device_host_state_t)); -// } - -int32_t check_attr(shmem_init_attr_t *attributes) -{ - if ((attributes->my_rank < 0) || (attributes->n_ranks <= 0)) { - SHM_LOG_ERROR("my_rank:" << attributes->my_rank << " and n_ranks: " << attributes->n_ranks - << " cannot be less 0 , n_ranks still cannot be equal 0"); - return SHMEM_INVALID_VALUE; - } else if (attributes->n_ranks > SHMEM_MAX_RANKS) { - SHM_LOG_ERROR("n_ranks: " << attributes->n_ranks << " cannot be more than " << SHMEM_MAX_RANKS); - return SHMEM_INVALID_VALUE; - } else if (attributes->my_rank >= attributes->n_ranks) { - SHM_LOG_ERROR("n_ranks:" << attributes->n_ranks << " cannot be less than my_rank:" << attributes->my_rank); - return SHMEM_INVALID_PARAM; - } else if (attributes->local_mem_size <= 0) { - SHM_LOG_ERROR("local_mem_size:" << attributes->local_mem_size << " cannot be less or equal 0"); - return SHMEM_INVALID_VALUE; - } - return SHMEM_SUCCESS; -} - -} // namespace shm - -init_base* init_manager; - -extern shmemi_device_host_state_t shm::g_state; - -int32_t shmemi_control_barrier_all() -{ - return init_manager->barrier_all(); -} - -int32_t update_device_state() -{ - return init_manager->update_device_state((void *)&shm::g_state, sizeof(shmemi_device_host_state_t)); -} - -int32_t shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, data_op_engine_type_t value) -{ - attributes->option_attr.data_op_engine_type = value; - return SHMEM_SUCCESS; -} - -int32_t shmem_set_timeout(shmem_init_attr_t *attributes, uint32_t value) -{ - attributes->option_attr.shm_init_timeout = value; - attributes->option_attr.shm_create_timeout = value; - attributes->option_attr.control_operation_timeout = value; - return SHMEM_SUCCESS; -} - -int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size, const char *ip_port, - shmem_init_attr_t **attributes) -{ - SHM_ASSERT_RETURN(local_mem_size <= SHMEM_MAX_LOCAL_SIZE, SHMEM_INVALID_VALUE); - SHM_ASSERT_RETURN(n_ranks <= SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); - SHM_ASSERT_RETURN(my_rank <= SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); - *attributes = &shm::g_attr; - if (ip_port == nullptr) { - SHM_LOG_ERROR("my_rank:" << my_rank << " ip_port is NULL!"); - return SHMEM_INVALID_PARAM; - } - // 安全警告:此处strlen依赖ip_port以null结尾,如果ip_port不是合法的C字符串,将导致越界读取 - if (ip_port == nullptr) { - SHM_LOG_ERROR("my_rank:" << my_rank << " ip_port is NULL!"); - return SHMEM_INVALID_PARAM; - } - // 安全警告:此处strlen依赖ip_port以null结尾,如果ip_port不是合法的C字符串,将导致越界读取 - size_t ip_len = strlen(ip_port); - shm::g_ipport = new (std::nothrow) char[ip_len + 1]; - if (shm::g_ipport == nullptr) { - SHM_LOG_ERROR("my_rank:" << my_rank << " failed to allocate IP port string!"); - return SHMEM_INNER_ERROR; - } - std::copy(ip_port, ip_port + ip_len + 1, shm::g_ipport); - if (shm::g_ipport == nullptr) { - SHM_LOG_ERROR("my_rank:" << my_rank << " shm::g_ipport is nullptr!"); - return SHMEM_INVALID_VALUE; - } - int attr_version = (1 << 16) + sizeof(shmem_init_attr_t); - shm::g_attr.my_rank = my_rank; - shm::g_attr.n_ranks = n_ranks; - shm::g_attr.ip_port = shm::g_ipport; - shm::g_attr.local_mem_size = local_mem_size; - shm::g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, shm::DEFAULT_TIMEOUT, - shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT}; - // shm::g_attr_init = true; - return SHMEM_SUCCESS; -} - -int32_t shmem_init_status() -{ - if (!shm::g_state.is_shmem_created) - return SHMEM_STATUS_NOT_INITIALIZED; - else if (!shm::g_state.is_shmem_initialized) - return SHMEM_STATUS_SHM_CREATED; - else if (shm::g_state.is_shmem_initialized) - return SHMEM_STATUS_IS_INITIALIZED; - else - return SHMEM_STATUS_INVALID; -} - -// void shmem_rank_exit(int status) -// { -// SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); -// exit(status); -// } - -int32_t shmem_init_attr(shmem_init_attr_t *attributes) -{ - int32_t ret; - init_manager = new init_mf(attributes, shm::g_ipport); - - SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); - SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); - SHMEM_CHECK_RET(shm::check_attr(attributes)); - SHMEM_CHECK_RET(shm::version_compatible()); - SHMEM_CHECK_RET(shm::shmemi_options_init()); - - SHMEM_CHECK_RET(shm::shmemi_state_init_attr(attributes)); - SHMEM_CHECK_RET(init_manager->init_device_state()); - SHMEM_CHECK_RET(init_manager->heap_init(shm::g_state)); - // SHMEM_CHECK_RET(shm::shmemi_heap_init(attributes)); - SHMEM_CHECK_RET(update_device_state()); - - SHMEM_CHECK_RET(shm::memory_manager_initialize(shm::g_state.heap_base, shm::g_state.heap_size)); - SHMEM_CHECK_RET(shm::shmemi_team_init(shm::g_state.mype, shm::g_state.npes)); - SHMEM_CHECK_RET(update_device_state()); - SHMEM_CHECK_RET(shm::shmemi_sync_init()); - // SHMEM_CHECK_RET(smem_shm_register_exit(shm::g_smem_handle, &shmem_rank_exit)); - shm::g_state.is_shmem_initialized = true; - SHMEM_CHECK_RET(shmemi_control_barrier_all()); - return SHMEM_SUCCESS; -} - -int32_t shmem_register_decrypt_handler(const shmem_decrypt_handler handler) -{ - return smem_register_decrypt_handler(handler); -} - -int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)) -{ - shm::shm_out_logger::Instance().set_extern_log_func(func, true); - return smem_set_extern_logger(func); -} - -int32_t shmem_set_log_level(int level) -{ - // use env first, input level secondly, user may change level from env instead call func - const char *in_level = std::getenv("SHMEM_LOG_LEVEL"); - if (in_level != nullptr) { - auto tmp_level = std::string(in_level); - if (tmp_level == "DEBUG") { - level = shm::DEBUG_LEVEL; - } else if (tmp_level == "INFO") { - level = shm::INFO_LEVEL; - } else if (tmp_level == "WARN") { - level = shm::WARN_LEVEL; - } else if (tmp_level == "ERROR") { - level = shm::ERROR_LEVEL; - } else if (tmp_level == "FATAL") { - level = shm::FATAL_LEVEL; - } - } - shm::shm_out_logger::Instance().set_log_level(static_cast(level)); - return smem_set_log_level(level); -} - -int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len) -{ - return smem_set_conf_store_tls(enable, tls_info, tls_info_len); -} - -int32_t shmem_finalize() -{ - SHMEM_CHECK_RET(shm::shmemi_team_finalize()); - SHMEM_CHECK_RET(init_manager->heap_finalize()); - return SHMEM_SUCCESS; -} - -void shmem_info_get_version(int *major, int *minor) -{ - SHM_ASSERT_RET_VOID(major != nullptr && minor != nullptr); - *major = SHMEM_MAJOR_VERSION; - *minor = SHMEM_MINOR_VERSION; -} - -void shmem_info_get_name(char *name) -{ - SHM_ASSERT_RET_VOID(name != nullptr); - std::ostringstream oss; - oss << "SHMEM v" << SHMEM_VENDOR_MAJOR_VER << "." << SHMEM_VENDOR_MINOR_VER << "." << SHMEM_VENDOR_PATCH_VER; - auto version_str = oss.str(); - size_t i; - for (i = 0; i < SHMEM_MAX_NAME_LEN - 1 && version_str[i] != '\0'; i++) { - name[i] = version_str[i]; - } - name[i] = '\0'; -} - -// void shmem_global_exit(int status) -// { -// smem_shm_global_exit(shm::g_smem_handle, status); -// } diff --git a/src/host/init/shmemi_init.h b/src/host/init/shmemi_init.h index 7f461b9e..f9ae9d29 100644 --- a/src/host/init/shmemi_init.h +++ b/src/host/init/shmemi_init.h @@ -12,17 +12,20 @@ #include "stdint.h" #include "internal/host_device/shmemi_types.h" -#include "init/init_impl/shmemi_init_mf.h" -#include "init/init_impl/shmemi_init_mf.h" + +#ifdef BACKEND_MF +#include "init/init_backends/mf/shmemi_init_mf.h" +#else +#include "init/init_backends/normal/shmemi_init_normal.h" +#endif namespace shm { extern shmemi_device_host_state_t g_state; extern shmemi_host_state_t g_state_host; +} // namespace shm int32_t shmemi_control_barrier_all(); -} // namespace shm - int32_t update_device_state(void); #endif // SHMEMI_INIT_H diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp new file mode 100644 index 00000000..a6d175eb --- /dev/null +++ b/src/host/mem/shmemi_global_state.cpp @@ -0,0 +1,32 @@ +#include "shmemi_global_state.h" + +global_state_reigister::global_state_reigister(int device_id): device_id_{device_id} +{ + halMemAddressReserve(&device_ptr, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1); + + drv_mem_prop memprop; + memprop.side = 1; + memprop.devid = device_id_; + memprop.module_id = 0; + memprop.pg_type = 0; + memprop.mem_type = 0; + memprop.reserve = 0; + + halMemCreate(&alloc_handle, GLOBAL_STATE_SIZE, &memprop, 0); + + halMemMap(device_ptr, GLOBAL_STATE_SIZE, 0, alloc_handle, 0); +} + +global_state_reigister::~global_state_reigister() +{ + halMemUnmap(device_ptr); + + halMemRelease(alloc_handle); + + halMemAddressFree(device_ptr); +} + +void *global_state_reigister::get_ptr() +{ + return device_ptr; +} diff --git a/src/host/mem/shmemi_global_state.h b/src/host/mem/shmemi_global_state.h new file mode 100644 index 00000000..a9363238 --- /dev/null +++ b/src/host/mem/shmemi_global_state.h @@ -0,0 +1,29 @@ +#ifndef SHMEMI_GLOBAL_STATE_H +#define SHMEMI_GLOBAL_STATE_H + +#include +#include + +#include +#include + +#include "internal/host_device/shmemi_types.h" + +class global_state_reigister { +public: + global_state_reigister(); + global_state_reigister(int device_id); + + ~global_state_reigister(); + + void *get_ptr(); +private: + void *device_ptr = nullptr; + + drv_mem_handle_t *alloc_handle; + + int device_id_; +}; + + +#endif // SHMEMI_GLOBAL_STATE_H \ No newline at end of file diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp new file mode 100644 index 00000000..67073bf2 --- /dev/null +++ b/src/host/mem/shmemi_heap.cpp @@ -0,0 +1,133 @@ +#include "shmem_heap.h" + +shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size): mype(pe_id), npes(pe_size) +{ + physical_handle_list.resize(pe_size); + share_handle_list.resize(pe_size); + pid_list.resize(pe_size); + + memprop.handleType = ACL_MEM_HANDLE_TYPE_NONE; + memprop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; + memprop.memAttr = ACL_HBM_MEM_HUGE; + memprop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; + memprop.location.id = pe_id; + memprop.reserve = 0; +} + +int shmem_symmetric_heap::reserve_heap(size_t size) +{ + int status = 0; + device_ptrs = (void **)std::calloc(npes, sizeof(void *)); + peer_heap_base_p2p_ = (void **)std::calloc(npes, sizeof(void *)); + + // reserve virtual ptrs + for (int i = 0; i < npes; i++) { + aclrtReserveMemAddress(&(device_ptrs[i]), size, 0, nullptr, 1); + peer_heap_base_p2p_[i] = device_ptrs[i]; + } + heap_base_ = peer_heap_base_p2p_[mype]; + + // alloc local physical memory + aclrtMallocPhysical(&local_handle, size, &memprop, 0); + + alloc_size = size; + return status; +} + +int shmem_symmetric_heap::export_memory() +{ + int status = 0; + // Get share_handle + status = aclrtMemExportToShareableHandle(local_handle, memprop.handleType, 0, &share_handle); + return status; +} + +int shmem_symmetric_heap::export_pid() +{ + int status = 0; + + // Get local pid + status = aclrtDeviceGetBareTgid(&my_pid); + return status; +} + +int shmem_symmetric_heap::import_pid() +{ + int status = 0; + + // Get all pids + g_boot_handle.allgather(&my_pid, pid_list.data(), 1 * sizeof(int), &g_boot_handle); + + // Add Pid into white list + std::vector share_pid = {}; + for (int i = 0; i < npes; i++) { + if (i == mype) { + continue; + } + share_pid.push_back(pid_list[i]); + } + + status = aclrtMemSetPidToShareableHandle(share_handle, share_pid.data(), npes - 1); + return status; +} + +int shmem_symmetric_heap::import_memory() +{ + int status = 0; + g_boot_handle.allgather(&share_handle, share_handle_list.data(), 1 * sizeof(uint64_t), &g_boot_handle); + for (int i = 0; i < npes; i++) { + if (i == mype) { + physical_handle_list[i] = local_handle; + continue; + } + status = aclrtMemImportFromShareableHandle(share_handle_list[i], mype, &(physical_handle_list[i])); + } + + return status; +} + +int shmem_symmetric_heap::setup_heap() +{ + int status = 0; + status = export_memory(); + status = export_pid(); + status = import_pid(); + status = import_memory(); + + // Shareable Handle Map + for (int i = 0; i < npes; i++) { + status = aclrtMapMem(device_ptrs[i], alloc_size, 0, physical_handle_list[i], 0); + } + return status; +} + +int shmem_symmetric_heap::remove_heap() +{ + int status = 0; + for (int i = 0; i < npes; i++) { + status = aclrtUnmapMem(device_ptrs[i]); + } + + return status; +} + +int shmem_symmetric_heap::unreserve_heap() +{ + int status = 0; + for (int i = 0; i < npes; i++) { + status = aclrtReleaseMemAddress(device_ptrs[i]); + } + + status = aclrtFreePhysical(local_handle); + return status; +} + +void *shmem_symmetric_heap::get_heap_base() +{ + return heap_base_; +} + +void *shmem_symmetric_heap::get_peer_heap_base_p2p(int pe_id) +{ + return peer_heap_base_p2p_[pe_id]; +} \ No newline at end of file diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h new file mode 100644 index 00000000..ee543f72 --- /dev/null +++ b/src/host/mem/shmemi_heap.h @@ -0,0 +1,62 @@ +#ifndef SHMEMI_HEAP_H +#define SHMEMI_HEAP_H + +#include +#include +#include + +#include + +#include "common/shmemi_host_types.h" +#include "bootstrap/shmemi_bootstrap.h" + +class shmem_symmetric_heap { +public: + shmem_symmetric_heap() {} + shmem_symmetric_heap(int pe_id, int pe_size); + ~shmem_symmetric_heap() {}; + + int reserve_heap(size_t size); // aclrtReserveMemAddress && aclrtMallocPhysical + int unreserve_heap(); // halMemAddressFree && aclrtFreePhysical + + int setup_heap(); // export && import p2p memories && aclrtMapMem + int remove_heap(); // aclrtUnmapMem + + int *heap_alloc(); // ptr pretend alloc + int *heap_free(); // ptr pretend free + + void *get_heap_base(); // return heap_base_ + void *get_peer_heap_base_p2p(int pe_id); // peer_heap_base_p2p_ + +private: + int export_memory(); + int import_memory(); + + int export_pid(); + int import_pid(); + + int32_t mype; + int32_t npes; + + void **device_ptrs; + uint64_t alloc_size; + + void *heap_base_; + void **peer_heap_base_p2p_; + + // handle used to map local virtual ptr + aclrtPhysicalMemProp memprop; + aclrtDrvMemHandle local_handle; + std::vector physical_handle_list = {}; + + // pid used to set white list + int32_t my_pid = 0UL; + std::vector pid_list = {}; + + // handle used to share physical memory + uint64_t share_handle = 0UL; + std::vector share_handle_list = {}; +}; + + +#endif // SHMEMI_HEAP_H \ No newline at end of file -- Gitee From b95fde8993fa76d6c9e281b384b1f0b93d285a1c Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Tue, 23 Sep 2025 15:53:20 +0800 Subject: [PATCH 30/74] Reconstruct Example Allgather OK --- .../init/init_backends/shmemi_init_base.h | 4 ++- src/host/init/shmem_init.cpp | 26 +------------------ src/host/mem/shmemi_heap.cpp | 2 +- src/host/shmemi_host_common.h | 9 ------- 4 files changed, 5 insertions(+), 36 deletions(-) diff --git a/src/host/init/init_backends/shmemi_init_base.h b/src/host/init/init_backends/shmemi_init_base.h index 887aeb49..e30c7ab6 100644 --- a/src/host/init/init_backends/shmemi_init_base.h +++ b/src/host/init/init_backends/shmemi_init_base.h @@ -17,4 +17,6 @@ public: virtual ~init_base() {} -}; \ No newline at end of file +}; + +#endif // SHMEMI_INIT_BASE_H \ No newline at end of file diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index e025aa87..06360d67 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -83,30 +83,6 @@ int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) return status; } -<<<<<<< HEAD:src/host/init/shmem_init.cpp -======= -// TODO: use shmem native heap init -int32_t shmemi_heap_init(shmem_init_attr_t *attributes) { - // 申请Physical Mem, 申请Virtual Addr,MMAP - // 连续映射, VA_size = rank * PA_size - // 非连续映射, VA_size = PA_size - - return 0; -} - -// TODO: use shmem native barrier -int32_t shmemi_control_barrier_all() -{ - return g_boot_handle.barrier(&g_boot_handle); -} - -// TODO: use shmem native global state -int32_t update_device_state() -{ - return 0; -} - ->>>>>>> aee098b745adc448670cbf2097037ef86bf88d9a:src/host/init/shmem_init_default.cpp int32_t check_attr(shmem_init_attr_t *attributes) { if ((attributes->my_rank < 0) || (attributes->n_ranks <= 0)) { @@ -200,7 +176,7 @@ int32_t shmem_init_status() return SHMEM_STATUS_INVALID; } -int32_t shmem_init_attr(shmem_init_attr_t *attributes) +int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *attributes) { int32_t ret; diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index 67073bf2..2f8f3be1 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -1,4 +1,4 @@ -#include "shmem_heap.h" +#include "shmemi_heap.h" shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size): mype(pe_id), npes(pe_size) { diff --git a/src/host/shmemi_host_common.h b/src/host/shmemi_host_common.h index 907c01b0..ed872c2f 100644 --- a/src/host/shmemi_host_common.h +++ b/src/host/shmemi_host_common.h @@ -22,13 +22,4 @@ #include "bootstrap/shmemi_bootstrap.h" #include "transport/shmemi_transport.h" -// smem api -#include -#include -#include -#include -#include -#include -#include - #endif // SHMEM_SHMEMI_HOST_COMMON_H -- Gitee From 27fd9ad99c802f4e1f2b9d37da45f3d909e6ec4e Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Tue, 23 Sep 2025 16:14:54 +0800 Subject: [PATCH 31/74] Fix compile error --- include/host/shmem_host_def.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/host/shmem_host_def.h b/include/host/shmem_host_def.h index c5cbbb15..46600395 100644 --- a/include/host/shmem_host_def.h +++ b/include/host/shmem_host_def.h @@ -10,6 +10,7 @@ #ifndef SHMEM_HOST_DEF_H #define SHMEM_HOST_DEF_H #include +#include #include "host_device/shmem_types.h" #ifdef __cplusplus -- Gitee From 7c50d2cab17269cb2c51df91f2e1336d8d4cc21f Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Fri, 26 Sep 2025 17:50:45 +0800 Subject: [PATCH 32/74] transport Add 1.0 --- CMakeLists.txt | 2 +- examples/hellow_word/main.cpp | 22 ++- include/internal/host_device/shmemi_types.h | 1 + src/CMakeLists.txt | 17 +++ src/host/common/shmemi_host_types.h | 2 +- .../shmemi_init_default.cpp} | 66 +++++---- .../shmemi_init_default.h} | 29 +++- .../init/init_backends/mf/shmemi_init_mf.cpp | 24 +++- .../init/init_backends/mf/shmemi_init_mf.h | 9 +- .../init/init_backends/shmemi_init_base.h | 13 +- src/host/init/shmem_init.cpp | 47 +++++-- src/host/init/shmemi_init.h | 2 +- src/host/mem/shmemi_heap.cpp | 27 +++- src/host/mem/shmemi_heap.h | 3 +- src/host/transport/shmemi_transport.cpp | 133 ++++++++++++++---- src/host/transport/shmemi_transport.h | 8 +- src/modules/transport/shmemi_mte.cpp | 56 +++++--- 17 files changed, 344 insertions(+), 117 deletions(-) rename src/host/init/init_backends/{normal/shmemi_init_normal.cpp => default/shmemi_init_default.cpp} (37%) rename src/host/init/init_backends/{normal/shmemi_init_normal.h => default/shmemi_init_default.h} (43%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f02c3046..f5dfb4ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,7 +94,7 @@ link_directories( link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase ascend_hal pthread) # MF_BACKEND -set(USE_MF "1") +set(USE_MF "0") # 添加子目录 add_subdirectory(src) diff --git a/examples/hellow_word/main.cpp b/examples/hellow_word/main.cpp index 8ecf2220..3526a4fe 100644 --- a/examples/hellow_word/main.cpp +++ b/examples/hellow_word/main.cpp @@ -12,27 +12,35 @@ #include #include #include "shmem_api.h" + +const char *ipport = "tcp://127.0.0.1:8998"; +int f_rank = 0; +int f_npu = 0; + int main(int argc, char* argv[]) { - // 初始化MPI环境 + // MPI Init MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); - // 获取当前进程的编号(rank) - int rank; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); int status = SHMEM_SUCCESS; aclInit(nullptr); - aclrtSetDevice(rank); + aclrtSetDevice(rank_id); + + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); - status = shmem_finalize(); if ( status != SHMEM_SUCCESS) { std::cout << "[ERROR] demo run failed!" << std::endl; std::exit(status); } - aclrtResetDevice(rank); + aclrtResetDevice(rank_id); aclFinalize(); MPI_Finalize(); std::cout << "[SUCCESS] demo run success!" << std::endl; diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index ec46ecb1..7d86360d 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -90,6 +90,7 @@ typedef struct { uint64_t sync_counter; uint64_t core_sync_pool; uint64_t core_sync_counter; + uint64_t host_hash; bool is_shmem_initialized; bool is_shmem_created; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f98b1706..3f096b2a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -73,6 +73,23 @@ if ("${USE_MF}" STREQUAL "1") ) endif() +set(SHMEM_MTE_SUPPORT ON) +if(SHMEM_MTE_SUPPORT) + add_library(shmem_transport_mte SHARED) + + target_sources(shmem_transport_mte PRIVATE + modules/transport/shmemi_mte.cpp) + + target_include_directories(shmem_transport_mte PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/host) + + set_target_properties(shmem_transport_mte PROPERTIES PREFIX "") + + install(TARGETS shmem_transport_mte + LIBRARY DESTINATION lib) +endif() + # MPI if(SHMEM_MPI_SUPPORT) separate_arguments(SHMEM_CXX_LINK_FLAGS NATIVE_COMMAND "${MPI_CXX_LINK_FLAGS}") diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index 9619470d..fa784722 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -92,5 +92,5 @@ typedef struct { int32_t num_choosen_transport; } shmemi_host_state_t; extern shmemi_bootstrap_handle_t g_boot_handle; -extern shmemi_host_state_t g_host_state; +// extern shmemi_host_state_t g_host_state; #endif // SHMEMI_HOST_TYPES_H \ No newline at end of file diff --git a/src/host/init/init_backends/normal/shmemi_init_normal.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp similarity index 37% rename from src/host/init/init_backends/normal/shmemi_init_normal.cpp rename to src/host/init/init_backends/default/shmemi_init_default.cpp index fc42ad50..6888aff7 100644 --- a/src/host/init/init_backends/normal/shmemi_init_normal.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -1,70 +1,86 @@ -#include "shmemi_init_normal.h" +#include "shmemi_init_default.h" -global_state_reigister *global_state_d = nullptr; -shmem_symmetric_heap *heap_obj = nullptr; - -init_normal::init_normal(shmem_init_attr_t *attr) +shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr) { mype = attr->my_rank; npes = attr->n_ranks; - // EnablePeerAccess - for (int i = 0; i < npes; i++) { - if (i != mype) { - aclrtDeviceEnablePeerAccess(i, 0); - } - } + transport_map = (int *)calloc(npes * npes, sizeof(int)); } -init_normal::~init_normal() +shmemi_init_default::~shmemi_init_default() { finalize_device_state(); - heap_finalize(); + remove_heap(); + release_heap(); + transport_finalize(); } -int init_normal::init_device_state() +int shmemi_init_default::init_device_state() { global_state_d = new global_state_reigister(mype); return SHMEM_SUCCESS; } -int init_normal::finalize_device_state() +int shmemi_init_default::finalize_device_state() { delete global_state_d; return SHMEM_SUCCESS; } -int init_normal::update_device_state(void* host_ptr, size_t size) +int shmemi_init_default::update_device_state(void* host_ptr, size_t size) { int32_t status = SHMEM_SUCCESS; status = aclrtMemcpy(global_state_d->get_ptr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE); return status; } -int init_normal::heap_init(shmemi_device_host_state_t &g_state) +int shmemi_init_default::reserve_heap(shmemi_device_host_state_t &g_state) { heap_obj = new shmem_symmetric_heap(mype, npes); heap_obj->reserve_heap(g_state.heap_size); + g_state.heap_base = heap_obj->get_heap_base(); - heap_obj->setup_heap(); + return SHMEM_SUCCESS; +} + +int shmemi_init_default::setup_heap(shmemi_device_host_state_t &g_state) +{ + heap_obj->setup_heap(transport_map); - // Assign host_global_state - g_state.heap_base = heap_obj->get_heap_base(); for (int32_t i = 0; i < g_state.npes; i++) { g_state.p2p_heap_base[i] = heap_obj->get_peer_heap_base_p2p(i); } g_state.is_shmem_created = true; + return SHMEM_SUCCESS; } -int init_normal::heap_finalize() +int shmemi_init_default::remove_heap() { - int32_t status = SHMEM_SUCCESS; - heap_obj->remove_heap(); + + return SHMEM_SUCCESS; +} + +int shmemi_init_default::release_heap() +{ heap_obj->unreserve_heap(); - delete heap_obj; - return status; + return SHMEM_SUCCESS; +} + +int shmemi_init_default::transport_init(shmemi_device_host_state_t &g_state) +{ + shmemi_transport_init(g_state); // mte init && rdma init + shmemi_build_transport_map(transport_map, g_state); // returns transport_map + shmemi_transport_setup_connections(transport_map, g_state); // connect_endpoints by transpost_map + return SHMEM_SUCCESS; +} + +int shmemi_init_default::transport_finalize() +{ + shmemi_transport_finalize(); + return SHMEM_SUCCESS; } \ No newline at end of file diff --git a/src/host/init/init_backends/normal/shmemi_init_normal.h b/src/host/init/init_backends/default/shmemi_init_default.h similarity index 43% rename from src/host/init/init_backends/normal/shmemi_init_normal.h rename to src/host/init/init_backends/default/shmemi_init_default.h index 08269270..826e15a1 100644 --- a/src/host/init/init_backends/normal/shmemi_init_normal.h +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -7,26 +7,43 @@ #include "init/init_backends/shmemi_init_base.h" #include "host/shmem_host_def.h" +#include "internal/host_device/shmemi_types.h" + #include "mem/shmemi_global_state.h" #include "mem/shmemi_heap.h" + #include "bootstrap/shmemi_bootstrap.h" -#include "internal/host_device/shmemi_types.h" -class init_normal: public init_base { +#include "transport/shmemi_transport.h" + +class shmemi_init_default: public shmemi_init_base { public: - init_normal(shmem_init_attr_t *attr); - ~init_normal(); + shmemi_init_default(shmem_init_attr_t *attr); + ~shmemi_init_default(); int init_device_state() override; int finalize_device_state() override; int update_device_state(void* host_ptr, size_t size) override; - int heap_init(shmemi_device_host_state_t &g_state) override; - int heap_finalize() override; + int reserve_heap(shmemi_device_host_state_t &g_state) override; + int setup_heap(shmemi_device_host_state_t &g_state) override; + int remove_heap() override; + int release_heap() override; + + int transport_init(shmemi_device_host_state_t &g_state) override; + int transport_finalize() override; private: int mype; int npes; + // global_state + global_state_reigister *global_state_d = nullptr; + + // heap_obj + shmem_symmetric_heap *heap_obj = nullptr; + + // transport_map + int *transport_map = NULL; }; #endif // SHMEMI_INIT_NORMAL_H \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index f2ef068e..dfee3e78 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -10,7 +10,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; static smem_shm_t g_smem_handle = nullptr; static char *g_ipport = nullptr; -init_mf::init_mf(shmem_init_attr_t *attr, char *ipport) +shmemi_init_mf::shmemi_init_mf(shmem_init_attr_t *attr, char *ipport) { attributes = attr; g_ipport = ipport; @@ -24,13 +24,13 @@ init_mf::init_mf(shmem_init_attr_t *attr, char *ipport) } } -init_mf::~init_mf() +shmemi_init_mf::~shmemi_init_mf() { finalize_device_state(); heap_finalize(); } -int init_mf::init_device_state() +int shmemi_init_mf::init_device_state() { int32_t status = SHMEM_SUCCESS; smem_shm_config_t config; @@ -53,7 +53,7 @@ int init_mf::init_device_state() return SHMEM_SUCCESS; } -int init_mf::update_device_state(void* host_ptr, size_t size) +int shmemi_init_mf::update_device_state(void* host_ptr, size_t size) { if (g_smem_handle == nullptr) { SHM_LOG_ERROR("smem_shm_create Not Success, update_device_state Failed"); @@ -62,13 +62,13 @@ int init_mf::update_device_state(void* host_ptr, size_t size) return smem_shm_set_extra_context(g_smem_handle, host_ptr, size); } -int init_mf::finalize_device_state() +int shmemi_init_mf::finalize_device_state() { // dummy function return SHMEM_SUCCESS; } -int init_mf::heap_init(shmemi_device_host_state_t &g_state) +int shmemi_init_mf::heap_init(shmemi_device_host_state_t &g_state) { int32_t status = SHMEM_SUCCESS; void *gva = nullptr; @@ -109,7 +109,7 @@ int init_mf::heap_init(shmemi_device_host_state_t &g_state) return status; } -int init_mf::heap_finalize() +int shmemi_init_mf::heap_finalize() { if (g_smem_handle != nullptr) { int32_t status = smem_shm_destroy(g_smem_handle, 0); @@ -122,4 +122,14 @@ int init_mf::heap_finalize() smem_shm_uninit(0); smem_uninit(); return SHMEM_SUCCESS; +} + +int shmemi_init_mf::transport_init() +{ + return SHMEM_SUCCESS; +} + +int shmemi_init_mf::transport_finalize() +{ + return SHMEM_SUCCESS; } \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index f163f3c8..952d5c8c 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -17,10 +17,10 @@ #include #include -class init_mf: public init_base { +class shmemi_init_mf: public shmemi_init_base { public: - init_mf(shmem_init_attr_t *attr, char *ipport); - ~init_mf(); + shmemi_init_mf(shmem_init_attr_t *attr, char *ipport); + ~shmemi_init_mf(); int init_device_state() override; int finalize_device_state() override; @@ -28,6 +28,9 @@ public: int heap_init(shmemi_device_host_state_t &g_state) override; int heap_finalize() override; + + int transport_init() override; + int transport_finalize() override; private: int32_t device_id; diff --git a/src/host/init/init_backends/shmemi_init_base.h b/src/host/init/init_backends/shmemi_init_base.h index e30c7ab6..3f1aa394 100644 --- a/src/host/init/init_backends/shmemi_init_base.h +++ b/src/host/init/init_backends/shmemi_init_base.h @@ -6,16 +6,21 @@ #include "acl/acl.h" #include "internal/host_device/shmemi_types.h" -class init_base { +class shmemi_init_base { public: virtual int init_device_state() = 0; virtual int finalize_device_state() = 0; virtual int update_device_state(void* host_ptr, size_t size) = 0; - virtual int heap_init(shmemi_device_host_state_t &g_state) = 0; - virtual int heap_finalize() = 0; + virtual int reserve_heap(shmemi_device_host_state_t &g_state) = 0; + virtual int setup_heap(shmemi_device_host_state_t &g_state) = 0; + virtual int remove_heap() = 0; + virtual int release_heap() = 0; - virtual ~init_base() {} + virtual int transport_init(shmemi_device_host_state_t &g_state) = 0; + virtual int transport_finalize() = 0; + + virtual ~shmemi_init_base() {} }; diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index 06360d67..4918b85d 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -13,6 +13,10 @@ #include #include #include +#include +#include +#include + #include "acl/acl.h" #include "shmemi_host_common.h" @@ -45,6 +49,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; 0, /* sync_counter */ \ 0, /* core_sync_pool */ \ 0, /* core_sync_counter */ \ + 0, /* host_hash */ \ false, /* shmem_is_shmem_initialized */ \ false, /* shmem_is_shmem_created */ \ {0, 16 * 1024, 0}, /* shmem_mte_config */ \ @@ -68,12 +73,37 @@ int32_t shmemi_options_init() return status; } +uint64_t shmemi_get_host_hash() +{ + char hostname[128]; + struct hostent *he; + + if (gethostname(hostname, sizeof(hostname)) != 0) { + perror("gethostname"); + return 0; + } + + if ((he = gethostbyname(hostname)) == NULL) { + perror("gethostbyname"); + return 0; + } + + // Host IP Address + for (int i = 0; he->h_addr_list[i] != NULL; i++) { + char *ip = inet_ntoa(*(struct in_addr*)he->h_addr_list[i]); + } + + std::size_t host_hash = std::hash{}(hostname); + return host_hash; +} + int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) { int32_t status = SHMEM_SUCCESS; g_state.mype = attributes->my_rank; g_state.npes = attributes->n_ranks; g_state.heap_size = attributes->local_mem_size + SHMEM_EXTRA_SIZE; + g_state.host_hash = shmemi_get_host_hash(); aclrtStream stream = nullptr; SHMEM_CHECK_RET(aclrtCreateStream(&stream)); @@ -104,7 +134,7 @@ int32_t check_attr(shmem_init_attr_t *attributes) } // namespace shm -init_base* init_manager; +shmemi_init_base* init_manager; int32_t shmemi_control_barrier_all() { @@ -182,7 +212,6 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a // config init SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); - // SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); SHMEM_CHECK_RET(shm::check_attr(attributes)); SHMEM_CHECK_RET(shm::version_compatible()); SHMEM_CHECK_RET(shm::shmemi_options_init()); @@ -193,22 +222,22 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a // shmem basic init #ifdef BACKEND_MF - init_manager = new init_mf(attributes, shm::g_ipport); + init_manager = new shmemi_init_mf(attributes, shm::g_ipport); #else - init_manager = new init_normal(attributes); + init_manager = new shmemi_init_default(attributes); #endif SHMEM_CHECK_RET(shm::shmemi_state_init_attr(attributes)); SHMEM_CHECK_RET(init_manager->init_device_state()); - SHMEM_CHECK_RET(init_manager->heap_init(shm::g_state)); - SHMEM_CHECK_RET(update_device_state()); - + SHMEM_CHECK_RET(init_manager->reserve_heap(shm::g_state)); + SHMEM_CHECK_RET(init_manager->transport_init(shm::g_state)); + SHMEM_CHECK_RET(init_manager->setup_heap(shm::g_state)); + // shmem submodules init SHMEM_CHECK_RET(shm::memory_manager_initialize(shm::g_state.heap_base, shm::g_state.heap_size)); SHMEM_CHECK_RET(shm::shmemi_team_init(shm::g_state.mype, shm::g_state.npes)); - SHMEM_CHECK_RET(update_device_state()); SHMEM_CHECK_RET(shm::shmemi_sync_init()); - // SHMEM_CHECK_RET(smem_shm_register_exit(shm::g_smem_handle, &shmem_rank_exit)); shm::g_state.is_shmem_initialized = true; + SHMEM_CHECK_RET(update_device_state()); SHMEM_CHECK_RET(shmemi_control_barrier_all()); return SHMEM_SUCCESS; } diff --git a/src/host/init/shmemi_init.h b/src/host/init/shmemi_init.h index f9ae9d29..b14c8cc5 100644 --- a/src/host/init/shmemi_init.h +++ b/src/host/init/shmemi_init.h @@ -16,7 +16,7 @@ #ifdef BACKEND_MF #include "init/init_backends/mf/shmemi_init_mf.h" #else -#include "init/init_backends/normal/shmemi_init_normal.h" +#include "init/init_backends/default/shmemi_init_default.h" #endif namespace shm { diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index 2f8f3be1..7aa9eb10 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -17,14 +17,15 @@ shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size): mype(pe_id), int shmem_symmetric_heap::reserve_heap(size_t size) { int status = 0; - device_ptrs = (void **)std::calloc(npes, sizeof(void *)); peer_heap_base_p2p_ = (void **)std::calloc(npes, sizeof(void *)); // reserve virtual ptrs for (int i = 0; i < npes; i++) { - aclrtReserveMemAddress(&(device_ptrs[i]), size, 0, nullptr, 1); - peer_heap_base_p2p_[i] = device_ptrs[i]; + peer_heap_base_p2p_[i] = NULL; } + + // reserve local heap_base_ + aclrtReserveMemAddress(&(peer_heap_base_p2p_[mype]), size, 0, nullptr, 1); heap_base_ = peer_heap_base_p2p_[mype]; // alloc local physical memory @@ -86,9 +87,21 @@ int shmem_symmetric_heap::import_memory() return status; } -int shmem_symmetric_heap::setup_heap() +int shmem_symmetric_heap::setup_heap(int *transport_map) { int status = 0; + + // MTE p2p_heap_base_ reserve + int local_offset = mype * npes; + for (int i = 0; i < npes; i++) { + if (i == mype) + continue; + + if (transport_map[local_offset + i] == 1) { + aclrtReserveMemAddress(&(peer_heap_base_p2p_[i]), alloc_size, 0, nullptr, 1); + } + } + status = export_memory(); status = export_pid(); status = import_pid(); @@ -96,7 +109,7 @@ int shmem_symmetric_heap::setup_heap() // Shareable Handle Map for (int i = 0; i < npes; i++) { - status = aclrtMapMem(device_ptrs[i], alloc_size, 0, physical_handle_list[i], 0); + status = aclrtMapMem(peer_heap_base_p2p_[i], alloc_size, 0, physical_handle_list[i], 0); } return status; } @@ -105,7 +118,7 @@ int shmem_symmetric_heap::remove_heap() { int status = 0; for (int i = 0; i < npes; i++) { - status = aclrtUnmapMem(device_ptrs[i]); + status = aclrtUnmapMem(peer_heap_base_p2p_[i]); } return status; @@ -115,7 +128,7 @@ int shmem_symmetric_heap::unreserve_heap() { int status = 0; for (int i = 0; i < npes; i++) { - status = aclrtReleaseMemAddress(device_ptrs[i]); + status = aclrtReleaseMemAddress(peer_heap_base_p2p_[i]); } status = aclrtFreePhysical(local_handle); diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h index ee543f72..0248af5a 100644 --- a/src/host/mem/shmemi_heap.h +++ b/src/host/mem/shmemi_heap.h @@ -19,7 +19,7 @@ public: int reserve_heap(size_t size); // aclrtReserveMemAddress && aclrtMallocPhysical int unreserve_heap(); // halMemAddressFree && aclrtFreePhysical - int setup_heap(); // export && import p2p memories && aclrtMapMem + int setup_heap(int *transport_map); // export && import p2p memories && aclrtMapMem int remove_heap(); // aclrtUnmapMem int *heap_alloc(); // ptr pretend alloc @@ -38,7 +38,6 @@ private: int32_t mype; int32_t npes; - void **device_ptrs; uint64_t alloc_size; void *heap_base_; diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index b96dc954..b5908ae9 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -1,40 +1,125 @@ #include "shmemi_host_common.h" +#include "dlfcn.h" -shmemi_host_state_t g_host_state; +#include "transport/shmemi_transport.h" + +// extern shmemi_device_host_state_t g_state; + +#define TRANSPORT_MODULE_MTE "shmem_transport_mte.so" + +static void *mte_plugin_hdl = nullptr; +static char *mte_plugin_name = nullptr; + +int (*shmemi_mte_init)(uint64_t *hash_list, int pe_id, int pe_size); +int (*shmemi_mte_can_access_peer)(int *access, int pe_id); +int (*shmemi_mte_connect_peers)(int *selected_dev_ids, int num_selected_devs); +int (*shmemi_mte_finalize)(); + +uint64_t *host_hash_list; + +void shmemi_transport_load() +{ + dlerror(); + if (mte_plugin_hdl == nullptr) { + + mte_plugin_hdl = dlopen(mte_plugin_name, RTLD_NOW); + } + dlerror(); +} + +void shmemi_transport_unload() +{ + if (mte_plugin_hdl != nullptr) { + dlclose(mte_plugin_hdl); + mte_plugin_hdl = nullptr; + } + + if (mte_plugin_name != nullptr) { + free(mte_plugin_name); + mte_plugin_name = nullptr; + } +} -int32_t shmemi_transport_init() { - uint32_t num_choosen_transport = 0; -// #ifdef SHMEM_CONTINUOUS_ADDRESS_SPACE -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_c; -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_c; -// #else -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_mte_d; -// g_host_state.choosen_transports[num_choosen_transport++] = &g_transport_rdma_d; -// #endif +int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { + int status = 0; - g_host_state.num_choosen_transport = num_choosen_transport; + uint32_t num_choosen_transport = 0; + + mte_plugin_name = TRANSPORT_MODULE_MTE; + shmemi_transport_load(); - for (int i = 0; i < num_choosen_transport; i++) { - auto t = g_host_state.choosen_transports + i; - t->boot_handle = g_host_state.boot_handle; + if (!mte_plugin_hdl) { + SHM_LOG_ERROR("Transport unable to load " << mte_plugin_name << ", err is: " << stderr); + shmemi_transport_unload(); + return SHMEM_INVALID_VALUE; } + + *((void **)&shmemi_mte_init) = dlsym(mte_plugin_hdl, "shmemi_mte_init"); + *((void **)&shmemi_mte_can_access_peer) = dlsym(mte_plugin_hdl, "shmemi_mte_can_access_peer"); + *((void **)&shmemi_mte_connect_peers) = dlsym(mte_plugin_hdl, "shmemi_mte_connect_peers"); + *((void **)&shmemi_mte_finalize) = dlsym(mte_plugin_hdl, "shmemi_mte_finalize"); + + host_hash_list = (uint64_t *)calloc(g_state.npes, sizeof(uint64_t)); + g_boot_handle.allgather(&g_state.host_hash, host_hash_list, 1 * sizeof(uint64_t), &g_boot_handle); + + status = shmemi_mte_init(host_hash_list, g_state.mype, g_state.npes); + + return status; } -int32_t shmemi_build_transport_map() { - // fill p2p/rdma/sdma heap bases +int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_t &g_state) { + int status = SHMEM_SUCCESS; + + int *local_map = NULL; + local_map = (int *)calloc(g_state.npes, sizeof(int)); + + // Every selected transport must be access to all pe. + // If any can_reach_peer returns false, build_map should return failed. + for (int i = 0; i < g_state.npes; i++) { + int num_choosen_transport = 1; // now only mte. + + // Loop can_access_peer, j = 0 means MTE, j = 1 means RDMA ... + for (int j = 0; j < num_choosen_transport; j++) { + int reach = 0; + // Judge mte peer access + status = shmemi_mte_can_access_peer(&reach, i); + + if (reach) { + int m = 1 << j; + local_map[i] |= m; + } + } + } + + g_boot_handle.allgather(local_map, transport_map, g_state.npes * sizeof(int), &g_boot_handle); + + if (local_map) free(local_map); + return status; } -int32_t shmemi_transport_setup_connections() { - for (int i = 0; i < g_host_state.num_choosen_transport; i++) { - auto t = g_host_state.choosen_transports + i; - t->connect_peers(t, nullptr, 0); +int32_t shmemi_transport_setup_connections(int *transport_map, shmemi_device_host_state_t &g_state) { + + int *mte_peer_list; + int mte_peer_num = 0; + mte_peer_list = (int *)calloc(g_state.npes, sizeof(int)); + + int local_offset = g_state.mype * g_state.npes; + for (int i = 0; i < g_state.npes; i++) { + if (i == g_state.mype) + continue; + if (transport_map[i] == 1) { + mte_peer_list[mte_peer_num] = i; + ++mte_peer_num; + } } + + shmemi_mte_connect_peers(mte_peer_list, mte_peer_num); + + return 0; } int32_t shmemi_transport_finalize() { - for (int i = g_host_state.num_choosen_transport - 1; i >= 0; i--) { - auto t = g_host_state.choosen_transports + i; - t->finalize(t); - } + dlclose(mte_plugin_hdl); + return 0; } diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h index c08ecb6e..0fb09e70 100644 --- a/src/host/transport/shmemi_transport.h +++ b/src/host/transport/shmemi_transport.h @@ -1,12 +1,12 @@ #ifndef SHMEMI_TRANSPORT_H #define SHMEMI_TRANSPORT_H -int32_t shmemi_transport_init(); +int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state); -int32_t shmemi_build_transport_map(); +int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_t &g_state); -int32_t shmemi_transport_setup_connections(); +int32_t shmemi_transport_setup_connections(int *transport_map, shmemi_device_host_state_t &g_state); int32_t shmemi_transport_finalize(); -#endif \ No newline at end of file +#endif // SHMEMI_TRANSPORT_H \ No newline at end of file diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index a8a6e629..97257bfe 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -7,35 +7,59 @@ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ -#include "shmemi_transport.h" -typedef struct { +#include +#include +#include +#include -} shmemi_mted_transport_state_t; +#include "host/shmem_host_def.h" +#include "internal/host_device/shmemi_types.h" +#include "transport/shmemi_transport.h" -static shmemi_mted_transport_state_t shmemi_mted_transport_state; +#ifdef __cplusplus +extern "C" { +#endif + +static uint64_t *host_hash_list; +static int mype; +static int npes; // control plane -int shmemi_mted_init(shmemi_host_state_t *state, shmemi_transport_t *t) { +int shmemi_mte_init(uint64_t *hash_list, int pe_id, int pe_size) { + + host_hash_list = hash_list; + mype = pe_id; + npes = pe_size; + return 0; } -int shmemi_mted_can_access_peer(int *access, shmemi_transport_pe_info_t *peer, shmemi_transport_t *t) { - // host相同——true,否则false +int shmemi_mte_can_access_peer(int *access, int pe_id) { + // host_id same return 1, otherwise 0 + if (host_hash_list[mype] == host_hash_list[pe_id]) { + *access = 1; + } else { + *access = 0; + } + return 0; } -int shmemi_mted_connect_peers(shmemi_transport_t *t, int *selected_dev_ids, int num_selected_devs) { +int shmemi_mte_connect_peers(int *selected_dev_ids, int num_selected_devs) { + + // EnablePeerAccess + for (int i = 0; i < num_selected_devs; i++) { + aclrtDeviceEnablePeerAccess(selected_dev_ids[i], 0); + } + return 0; } -int shmemi_mted_finalize(shmemi_transport_t *t) { +int shmemi_mte_finalize() { + return 0; } - -shmemi_transport_t shmemi_mted_transport_state = { - .init = shmemi_mted_init, - .finalize = shmemi_mted_finalize, - .can_access_peer = shmemi_mted_can_access_peer, - .connect_peers = shmemi_mted_connect_peers, -} \ No newline at end of file +#ifdef __cplusplus +} +#endif \ No newline at end of file -- Gitee From d7c788b18bebdd6e815e08cbbbf7ae8153efea4c Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Sun, 28 Sep 2025 11:20:13 +0800 Subject: [PATCH 33/74] transport Add 1.0 adapt MF --- .../init/init_backends/mf/shmemi_init_mf.cpp | 21 +++++++++++++++---- .../init/init_backends/mf/shmemi_init_mf.h | 8 ++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index dfee3e78..f35e2ab9 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -27,7 +27,8 @@ shmemi_init_mf::shmemi_init_mf(shmem_init_attr_t *attr, char *ipport) shmemi_init_mf::~shmemi_init_mf() { finalize_device_state(); - heap_finalize(); + remove_heap(); + release_heap(); } int shmemi_init_mf::init_device_state() @@ -68,7 +69,7 @@ int shmemi_init_mf::finalize_device_state() return SHMEM_SUCCESS; } -int shmemi_init_mf::heap_init(shmemi_device_host_state_t &g_state) +int shmemi_init_mf::reserve_heap(shmemi_device_host_state_t &g_state) { int32_t status = SHMEM_SUCCESS; void *gva = nullptr; @@ -109,7 +110,19 @@ int shmemi_init_mf::heap_init(shmemi_device_host_state_t &g_state) return status; } -int shmemi_init_mf::heap_finalize() +int shmemi_init_mf::setup_heap(shmemi_device_host_state_t &g_state) +{ + int32_t status = SHMEM_SUCCESS; + return status; +} + +int shmemi_init_mf::remove_heap() +{ + int32_t status = SHMEM_SUCCESS; + return status; +} + +int shmemi_init_mf::release_heap() { if (g_smem_handle != nullptr) { int32_t status = smem_shm_destroy(g_smem_handle, 0); @@ -124,7 +137,7 @@ int shmemi_init_mf::heap_finalize() return SHMEM_SUCCESS; } -int shmemi_init_mf::transport_init() +int shmemi_init_mf::transport_init(shmemi_device_host_state_t &g_state) { return SHMEM_SUCCESS; } diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index 952d5c8c..455e5cc9 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -26,10 +26,12 @@ public: int finalize_device_state() override; int update_device_state(void* host_ptr, size_t size) override; - int heap_init(shmemi_device_host_state_t &g_state) override; - int heap_finalize() override; + int reserve_heap(shmemi_device_host_state_t &g_state) override; + int setup_heap(shmemi_device_host_state_t &g_state) override; + int remove_heap() override; + int release_heap() override; - int transport_init() override; + int transport_init(shmemi_device_host_state_t &g_state) override; int transport_finalize() override; private: int32_t device_id; -- Gitee From 301435208f8db1fb090ac114ec6bb90970849db2 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Mon, 29 Sep 2025 19:29:24 +0800 Subject: [PATCH 34/74] Decouple driver && Add dynamic load hal --- CMakeLists.txt | 6 +-- scripts/set_env.sh | 1 + src/host/mem/shmemi_global_state.cpp | 71 +++++++++++++++++++++++++--- src/host/mem/shmemi_global_state.h | 13 ++++- 4 files changed, 79 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f5dfb4ec..0350eae2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,8 +37,6 @@ message(STATUS "USE_EXAMPLES:${USE_EXAMPLES}") option(USE_FUZZ_TEST "USE_FUZZ_TEST" OFF) message(STATUS "USE_FUZZ_TEST:${USE_FUZZ_TEST}") -set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) - set(CMAKE_COMPILER bisheng) set(CMAKE_C_COMPILER ${CMAKE_COMPILER}) set(CMAKE_CXX_COMPILER ${CMAKE_COMPILER}) @@ -83,15 +81,13 @@ include_directories( ${ASCEND_HOME_PATH}/include ${ASCEND_HOME_PATH}/include/experiment/runtime ${ASCEND_HOME_PATH}/include/experiment/msprof - ${ASCEND_DRIVER_PATH}/kernel/inc ) link_directories( ${ASCEND_HOME_PATH}/lib64 - ${ASCEND_DRIVER_PATH}/lib64/driver ) -link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase ascend_hal pthread) +link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread) # MF_BACKEND set(USE_MF "0") diff --git a/scripts/set_env.sh b/scripts/set_env.sh index d7fd60ac..41181a88 100644 --- a/scripts/set_env.sh +++ b/scripts/set_env.sh @@ -13,5 +13,6 @@ if [[ -f "$set_env_path" ]] && [[ "$set_env_path" =~ "set_env.sh" ]]; then shmem_path=$(cd $(dirname $set_env_path); pwd) export SHMEM_HOME_PATH="$shmem_path" export LD_LIBRARY_PATH=$SHMEM_HOME_PATH/shmem/lib:$SHMEM_HOME_PATH/memfabric_hybrid/lib:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH export PATH=$SHMEM_HOME_PATH/bin:$PATH fi \ No newline at end of file diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp index a6d175eb..c091f38b 100644 --- a/src/host/mem/shmemi_global_state.cpp +++ b/src/host/mem/shmemi_global_state.cpp @@ -1,8 +1,64 @@ +#include +#include #include "shmemi_global_state.h" +#define HAL_LOAD_SYM(TARGET_FUNC, FILE_HANDLE, SYMBOL_NAME) \ + dlerror(); \ + *((void **)&TARGET_FUNC) = dlsym(FILE_HANDLE, SYMBOL_NAME); \ + error = dlerror(); \ + if (error != NULL) { \ + fprintf(stderr, "dlsym failed: %s\n", error); \ + dlclose(hal_handle); \ + } + +int (*halMemAddressReserveFunc)(void **ptr, size_t size, size_t alignment, void *addr, uint64_t flag); +int (*halMemAddressFreeFunc)(void *ptr); +int (*halMemCreateFunc)(drv_mem_handle_t **handle, size_t size, const struct drv_mem_prop *prop, uint64_t flag); +int (*halMemReleaseFunc)(drv_mem_handle_t *handle); +int (*halMemMapFunc)(void *ptr, size_t size, size_t offset, drv_mem_handle_t *handle, uint64_t flag); +int (*halMemUnmapFunc)(void *ptr); + +std::mutex g_mutex; +bool g_loaded = false; +static void *hal_handle; +const char *g_hal_lib_name = "libascend_hal.so"; + +int32_t load_hal_library() +{ + char *error; + std::lock_guard guard(g_mutex); + if (g_loaded) { + return 0; + } + + dlerror(); + + hal_handle = dlopen(g_hal_lib_name, RTLD_NOW); + if (!hal_handle) { + fprintf(stderr, "dlopen failed: %s\n", dlerror()); + return 1; + } + + HAL_LOAD_SYM(halMemAddressReserveFunc, hal_handle, "halMemAddressReserve"); + HAL_LOAD_SYM(halMemAddressFreeFunc, hal_handle, "halMemAddressFree"); + HAL_LOAD_SYM(halMemCreateFunc, hal_handle, "halMemCreate"); + HAL_LOAD_SYM(halMemReleaseFunc, hal_handle, "halMemRelease"); + HAL_LOAD_SYM(halMemMapFunc, hal_handle, "halMemMap"); + HAL_LOAD_SYM(halMemUnmapFunc, hal_handle, "halMemUnmap"); + + g_loaded = true; + return 0; +} + global_state_reigister::global_state_reigister(int device_id): device_id_{device_id} { - halMemAddressReserve(&device_ptr, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1); + int32_t status = load_hal_library(); + if (status != 0) { + std::cout << "load_hal_library failed " << std::endl; + return; + } + + halMemAddressReserveFunc(&device_ptr, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1); drv_mem_prop memprop; memprop.side = 1; @@ -12,18 +68,21 @@ global_state_reigister::global_state_reigister(int device_id): device_id_{device memprop.mem_type = 0; memprop.reserve = 0; - halMemCreate(&alloc_handle, GLOBAL_STATE_SIZE, &memprop, 0); + halMemCreateFunc(&alloc_handle, GLOBAL_STATE_SIZE, &memprop, 0); - halMemMap(device_ptr, GLOBAL_STATE_SIZE, 0, alloc_handle, 0); + halMemMapFunc(device_ptr, GLOBAL_STATE_SIZE, 0, alloc_handle, 0); } global_state_reigister::~global_state_reigister() { - halMemUnmap(device_ptr); + halMemUnmapFunc(device_ptr); + + halMemReleaseFunc(alloc_handle); - halMemRelease(alloc_handle); + halMemAddressFreeFunc(device_ptr); - halMemAddressFree(device_ptr); + if (hal_handle != nullptr) + dlclose(hal_handle); } void *global_state_reigister::get_ptr() diff --git a/src/host/mem/shmemi_global_state.h b/src/host/mem/shmemi_global_state.h index a9363238..2119fbf9 100644 --- a/src/host/mem/shmemi_global_state.h +++ b/src/host/mem/shmemi_global_state.h @@ -5,10 +5,21 @@ #include #include -#include #include "internal/host_device/shmemi_types.h" +typedef struct drv_mem_handle drv_mem_handle_t; + +struct drv_mem_prop { + uint32_t side; + uint32_t devid; + uint32_t module_id; + + uint32_t pg_type; + uint32_t mem_type; + uint64_t reserve; +}; + class global_state_reigister { public: global_state_reigister(); -- Gitee From 1fb7ba6c8b13eb44af9c26f5f53fc870f097ce52 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Mon, 29 Sep 2025 20:04:14 +0800 Subject: [PATCH 35/74] Remove namespace shm --- src/host/init/shmem_init.cpp | 64 ++++++++++++++++----------------- src/host/init/shmemi_init.h | 2 -- src/host/mem/shmem_mm.cpp | 28 +++++++-------- src/host/mem/shmem_rma.cpp | 52 +++++++++++++-------------- src/host/mem/shmemi_mm.h | 2 -- src/host/mem/shmemi_mm_heap.cpp | 8 ++--- src/host/mem/shmemi_mm_heap.h | 2 -- src/host/sync/shmemi_sync.cpp | 5 +-- src/host/sync/shmemi_sync.h | 4 --- src/host/team/shmem_team.cpp | 49 ++++++++++++------------- src/host/team/shmemi_team.h | 4 --- 11 files changed, 96 insertions(+), 124 deletions(-) diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index 4918b85d..bd0f0805 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -22,8 +22,6 @@ using namespace std; -namespace shm { - #define DEFAULT_MY_PE (-1) #define DEFAULT_N_PES (-1) @@ -132,8 +130,6 @@ int32_t check_attr(shmem_init_attr_t *attributes) return SHMEM_SUCCESS; } -} // namespace shm - shmemi_init_base* init_manager; int32_t shmemi_control_barrier_all() @@ -143,7 +139,7 @@ int32_t shmemi_control_barrier_all() int32_t update_device_state() { - return init_manager->update_device_state((void *)&shm::g_state, sizeof(shmemi_device_host_state_t)); + return init_manager->update_device_state((void *)&g_state, sizeof(shmemi_device_host_state_t)); } int32_t shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, data_op_engine_type_t value) @@ -166,41 +162,41 @@ int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size SHM_ASSERT_RETURN(local_mem_size <= SHMEM_MAX_LOCAL_SIZE, SHMEM_INVALID_VALUE); SHM_ASSERT_RETURN(n_ranks <= SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); SHM_ASSERT_RETURN(my_rank <= SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); - *attributes = &shm::g_attr; + *attributes = &g_attr; if (ip_port == nullptr) { SHM_LOG_ERROR("my_rank:" << my_rank << " ip_port is NULL!"); return SHMEM_INVALID_PARAM; } // 安全警告:此处strlen依赖ip_port以null结尾,如果ip_port不是合法的C字符串,将导致越界读取 size_t ip_len = strlen(ip_port); - shm::g_ipport = new (std::nothrow) char[ip_len + 1]; - if (shm::g_ipport == nullptr) { + g_ipport = new (std::nothrow) char[ip_len + 1]; + if (g_ipport == nullptr) { SHM_LOG_ERROR("my_rank:" << my_rank << " failed to allocate IP port string!"); return SHMEM_INNER_ERROR; } - std::copy(ip_port, ip_port + ip_len + 1, shm::g_ipport); - if (shm::g_ipport == nullptr) { - SHM_LOG_ERROR("my_rank:" << my_rank << " shm::g_ipport is nullptr!"); + std::copy(ip_port, ip_port + ip_len + 1, g_ipport); + if (g_ipport == nullptr) { + SHM_LOG_ERROR("my_rank:" << my_rank << " g_ipport is nullptr!"); return SHMEM_INVALID_VALUE; } int attr_version = (1 << 16) + sizeof(shmem_init_attr_t); - shm::g_attr.my_rank = my_rank; - shm::g_attr.n_ranks = n_ranks; - shm::g_attr.ip_port = shm::g_ipport; - shm::g_attr.local_mem_size = local_mem_size; - shm::g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, shm::DEFAULT_TIMEOUT, - shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT}; - // shm::g_attr_init = true; + g_attr.my_rank = my_rank; + g_attr.n_ranks = n_ranks; + g_attr.ip_port = g_ipport; + g_attr.local_mem_size = local_mem_size; + g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, DEFAULT_TIMEOUT, + DEFAULT_TIMEOUT, DEFAULT_TIMEOUT}; + // g_attr_init = true; return SHMEM_SUCCESS; } int32_t shmem_init_status() { - if (!shm::g_state.is_shmem_created) + if (!g_state.is_shmem_created) return SHMEM_STATUS_NOT_INITIALIZED; - else if (!shm::g_state.is_shmem_initialized) + else if (!g_state.is_shmem_initialized) return SHMEM_STATUS_SHM_CREATED; - else if (shm::g_state.is_shmem_initialized) + else if (g_state.is_shmem_initialized) return SHMEM_STATUS_IS_INITIALIZED; else return SHMEM_STATUS_INVALID; @@ -212,9 +208,9 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a // config init SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); - SHMEM_CHECK_RET(shm::check_attr(attributes)); - SHMEM_CHECK_RET(shm::version_compatible()); - SHMEM_CHECK_RET(shm::shmemi_options_init()); + SHMEM_CHECK_RET(check_attr(attributes)); + SHMEM_CHECK_RET(version_compatible()); + SHMEM_CHECK_RET(shmemi_options_init()); // bootstrap init shmemi_bootstrap_attr_t attr = {}; @@ -222,21 +218,21 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a // shmem basic init #ifdef BACKEND_MF - init_manager = new shmemi_init_mf(attributes, shm::g_ipport); + init_manager = new shmemi_init_mf(attributes, g_ipport); #else init_manager = new shmemi_init_default(attributes); #endif - SHMEM_CHECK_RET(shm::shmemi_state_init_attr(attributes)); + SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); SHMEM_CHECK_RET(init_manager->init_device_state()); - SHMEM_CHECK_RET(init_manager->reserve_heap(shm::g_state)); - SHMEM_CHECK_RET(init_manager->transport_init(shm::g_state)); - SHMEM_CHECK_RET(init_manager->setup_heap(shm::g_state)); + SHMEM_CHECK_RET(init_manager->reserve_heap(g_state)); + SHMEM_CHECK_RET(init_manager->transport_init(g_state)); + SHMEM_CHECK_RET(init_manager->setup_heap(g_state)); // shmem submodules init - SHMEM_CHECK_RET(shm::memory_manager_initialize(shm::g_state.heap_base, shm::g_state.heap_size)); - SHMEM_CHECK_RET(shm::shmemi_team_init(shm::g_state.mype, shm::g_state.npes)); - SHMEM_CHECK_RET(shm::shmemi_sync_init()); - shm::g_state.is_shmem_initialized = true; + SHMEM_CHECK_RET(memory_manager_initialize(g_state.heap_base, g_state.heap_size)); + SHMEM_CHECK_RET(shmemi_team_init(g_state.mype, g_state.npes)); + SHMEM_CHECK_RET(shmemi_sync_init()); + g_state.is_shmem_initialized = true; SHMEM_CHECK_RET(update_device_state()); SHMEM_CHECK_RET(shmemi_control_barrier_all()); return SHMEM_SUCCESS; @@ -244,7 +240,7 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a int32_t shmem_finalize() { - SHMEM_CHECK_RET(shm::shmemi_team_finalize()); + SHMEM_CHECK_RET(shmemi_team_finalize()); delete init_manager; shmemi_bootstrap_finalize(); diff --git a/src/host/init/shmemi_init.h b/src/host/init/shmemi_init.h index b14c8cc5..256542a6 100644 --- a/src/host/init/shmemi_init.h +++ b/src/host/init/shmemi_init.h @@ -19,10 +19,8 @@ #include "init/init_backends/default/shmemi_init_default.h" #endif -namespace shm { extern shmemi_device_host_state_t g_state; extern shmemi_host_state_t g_state_host; -} // namespace shm int32_t shmemi_control_barrier_all(); diff --git a/src/host/mem/shmem_mm.cpp b/src/host/mem/shmem_mm.cpp index 6455c624..e522af16 100644 --- a/src/host/mem/shmem_mm.cpp +++ b/src/host/mem/shmem_mm.cpp @@ -12,7 +12,6 @@ #include "shmemi_host_common.h" #include "shmemi_mm_heap.h" -namespace shm { namespace { std::shared_ptr shm_memory_heap; } @@ -30,22 +29,21 @@ void memory_manager_destroy() { shm_memory_heap.reset(); } -} // namespace shm void *shmem_malloc(size_t size) { - if (shm::shm_memory_heap == nullptr) { + if (shm_memory_heap == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - void *ptr = shm::shm_memory_heap->allocate(size); + void *ptr = shm_memory_heap->allocate(size); SHM_LOG_DEBUG("shmem_malloc(" << size << ")"); auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("malloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { - shm::shm_memory_heap->release(ptr); + shm_memory_heap->release(ptr); ptr = nullptr; } } @@ -54,19 +52,19 @@ void *shmem_malloc(size_t size) void *shmem_calloc(size_t nmemb, size_t size) { - if (shm::shm_memory_heap == nullptr) { + if (shm_memory_heap == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - SHM_ASSERT_MULTIPLY_OVERFLOW(nmemb, size, shm::g_state.heap_size, nullptr); + SHM_ASSERT_MULTIPLY_OVERFLOW(nmemb, size, g_state.heap_size, nullptr); auto total_size = nmemb * size; - auto ptr = shm::shm_memory_heap->allocate(total_size); + auto ptr = shm_memory_heap->allocate(total_size); if (ptr != nullptr) { auto ret = aclrtMemset(ptr, size, 0, size); if (ret != 0) { SHM_LOG_ERROR("shmem_calloc(" << nmemb << ", " << size << ") memset failed: " << ret); - shm::shm_memory_heap->release(ptr); + shm_memory_heap->release(ptr); ptr = nullptr; } } @@ -75,7 +73,7 @@ void *shmem_calloc(size_t nmemb, size_t size) if (ret != 0) { SHM_LOG_ERROR("calloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { - shm::shm_memory_heap->release(ptr); + shm_memory_heap->release(ptr); ptr = nullptr; } } @@ -86,17 +84,17 @@ void *shmem_calloc(size_t nmemb, size_t size) void *shmem_align(size_t alignment, size_t size) { - if (shm::shm_memory_heap == nullptr) { + if (shm_memory_heap == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - auto ptr = shm::shm_memory_heap->aligned_allocate(alignment, size); + auto ptr = shm_memory_heap->aligned_allocate(alignment, size); auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("shmem_align barrier failed, ret: " << ret); if (ptr != nullptr) { - shm::shm_memory_heap->release(ptr); + shm_memory_heap->release(ptr); ptr = nullptr; } } @@ -106,7 +104,7 @@ void *shmem_align(size_t alignment, size_t size) void shmem_free(void *ptr) { - if (shm::shm_memory_heap == nullptr) { + if (shm_memory_heap == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return; } @@ -114,7 +112,7 @@ void shmem_free(void *ptr) return; } - auto ret = shm::shm_memory_heap->release(ptr); + auto ret = shm_memory_heap->release(ptr); if (ret != 0) { SHM_LOG_ERROR("release failed: " << ret); } diff --git a/src/host/mem/shmem_rma.cpp b/src/host/mem/shmem_rma.cpp index 635012b6..7a6349c4 100644 --- a/src/host/mem/shmem_rma.cpp +++ b/src/host/mem/shmem_rma.cpp @@ -21,15 +21,15 @@ void *shmem_ptr(void *ptr, int32_t pe) SHM_LOG_ERROR("shmem_ptr Failed. PE: " << shmem_my_pe() << " Got Ilegal PE !!"); return nullptr; } - uint64_t lower_bound = (uint64_t)shm::g_state.heap_base; - uint64_t upper_bound = lower_bound + shm::g_state.heap_size; + uint64_t lower_bound = (uint64_t)g_state.heap_base; + uint64_t upper_bound = lower_bound + g_state.heap_size; if (uint64_t(ptr) < lower_bound || uint64_t(ptr) >= upper_bound) { SHM_LOG_ERROR("shmem_ptr Failed. PE: " << shmem_my_pe() << " Got Ilegal Address !!"); return nullptr; } - uint64_t offset = (uint64_t)ptr - (uint64_t)shm::g_state.heap_base; - void *symm_ptr = shm::g_state.p2p_heap_base[pe]; + uint64_t offset = (uint64_t)ptr - (uint64_t)g_state.heap_base; + void *symm_ptr = g_state.p2p_heap_base[pe]; if (symm_ptr != nullptr) { symm_ptr = (void *)((uint64_t)symm_ptr + offset); return symm_ptr; @@ -42,9 +42,9 @@ void *shmem_ptr(void *ptr, int32_t pe) // Set Memcpy Interfaces necessary UB Buffer. int32_t shmem_mte_set_ub_params(uint64_t offset, uint32_t ub_size, uint32_t event_id) { - shm::g_state.mte_config.shmem_ub = offset; - shm::g_state.mte_config.ub_size = ub_size; - shm::g_state.mte_config.event_id = event_id; + g_state.mte_config.shmem_ub = offset; + g_state.mte_config.ub_size = ub_size; + g_state.mte_config.event_id = event_id; SHMEM_CHECK_RET(update_device_state()); return SHMEM_SUCCESS; } @@ -62,7 +62,7 @@ int32_t shmem_mte_set_ub_params(uint64_t offset, uint32_t ub_size, uint32_t even { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem", SHMEMI_OP_PUT, NO_NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -84,7 +84,7 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_PUT) { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem_nbi", SHMEMI_OP_PUT, NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -107,7 +107,7 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_PUT_NBI) { \ int ret = shmemi_prepare_and_post_rma("shmem_get_" #NAME "_mem", SHMEMI_OP_GET, NO_NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -130,7 +130,7 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_GET) { \ int ret = shmemi_prepare_and_post_rma("shmem_get_" #NAME "_mem_nbi", SHMEMI_OP_GET, NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -157,8 +157,8 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_GET_NBI) { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem_signal", SHMEMI_OP_PUT_SIGNAL, NO_NBI, \ (uint8_t *)dst, (uint8_t *)src, elem_size, sizeof(TYPE), pe, sig_addr, \ - signal, sig_op, 1, 1, shm::g_state_host.default_stream, \ - shm::g_state_host.default_block_num); \ + signal, sig_op, 1, 1, g_state_host.default_stream, \ + g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -185,8 +185,8 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_SIGNAL) { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem_signal_nbi", SHMEMI_OP_PUT_SIGNAL, NBI, \ (uint8_t *)dst, (uint8_t *)src, elem_size, sizeof(TYPE), pe, sig_addr, \ - signal, sig_op, 1, 1, shm::g_state_host.default_stream, \ - shm::g_state_host.default_block_num); \ + signal, sig_op, 1, 1, g_state_host.default_stream, \ + g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -206,7 +206,7 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_SIGNAL_NBI) SHMEM_HOST_API void shmem_##NAME##_p(TYPE *dst, const TYPE value, int pe) \ { \ shmemi_prepare_and_post_rma_##NAME##_p("shmem_" #NAME "_p", (uint8_t *)dst, value, pe, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ } SHMEM_TYPE_FUNC(SHMEM_TYPENAME_P) @@ -242,8 +242,8 @@ SHMEM_TYPE_FUNC(SHMEM_TYPENAME_G) void shmem_putmem(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem putmem", SHMEMI_OP_PUT, NO_NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_putmem failed"); } @@ -252,8 +252,8 @@ void shmem_putmem(void *dst, void *src, size_t elem_size, int32_t pe) void shmem_getmem(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem getmem", SHMEMI_OP_GET, NO_NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_getmem failed"); } @@ -262,8 +262,8 @@ void shmem_getmem(void *dst, void *src, size_t elem_size, int32_t pe) void shmem_putmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem_putmem_nbi", SHMEMI_OP_PUT, NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_putmem_nbi failed"); } @@ -272,8 +272,8 @@ void shmem_putmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe) void shmem_getmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem_getmem_nbi", SHMEMI_OP_GET, NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_getmem_nbi failed"); } @@ -283,7 +283,7 @@ void shmem_putmem_signal_nbi(void *dst, void *src, size_t elem_size, void *sig_a { int ret = shmemi_prepare_and_post_rma("shmem_putmem_signal_nbi", SHMEMI_OP_PUT_SIGNAL, NBI, (uint8_t *)dst, (uint8_t *)src, elem_size, 1, pe, (uint8_t *)sig_addr, signal, sig_op, 1, 1, - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); + g_state_host.default_stream, g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("device calling transfer failed"); } @@ -293,7 +293,7 @@ void shmem_putmem_signal(void *dst, void *src, size_t elem_size, void *sig_addr, { int ret = shmemi_prepare_and_post_rma("shmem_putmem_signal", SHMEMI_OP_PUT_SIGNAL, NO_NBI, (uint8_t *)dst, (uint8_t *)src, elem_size, 1, pe, (uint8_t *)sig_addr, signal, sig_op, 1, 1, - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); + g_state_host.default_stream, g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("device calling transfer failed"); } diff --git a/src/host/mem/shmemi_mm.h b/src/host/mem/shmemi_mm.h index 1867e60b..3a9316a4 100644 --- a/src/host/mem/shmemi_mm.h +++ b/src/host/mem/shmemi_mm.h @@ -12,9 +12,7 @@ #include "host/shmem_host_def.h" -namespace shm { int32_t memory_manager_initialize(void *base, uint64_t size); void memory_manager_destroy(); -} // namespace shm #endif // SHMEMI_MM_H diff --git a/src/host/mem/shmemi_mm_heap.cpp b/src/host/mem/shmemi_mm_heap.cpp index c48c0e34..02284684 100644 --- a/src/host/mem/shmemi_mm_heap.cpp +++ b/src/host/mem/shmemi_mm_heap.cpp @@ -10,7 +10,6 @@ #include "shmemi_host_common.h" #include "shmemi_mm_heap.h" -namespace shm { bool range_size_first_comparator::operator()(const memory_range &mr1, const memory_range &mr2) const noexcept { if (mr1.size != mr2.size) { @@ -34,7 +33,7 @@ memory_heap::~memory_heap() noexcept void *memory_heap::allocate(uint64_t size) noexcept { - if (size == 0 || size > shm::g_state.heap_size) { + if (size == 0 || size > g_state.heap_size) { SHM_LOG_ERROR("cannot allocate with size " << size); return nullptr; } @@ -74,7 +73,7 @@ void *memory_heap::allocate(uint64_t size) noexcept void *memory_heap::aligned_allocate(uint64_t alignment, uint64_t size) noexcept { - if (size == 0 || alignment == 0 || size > shm::g_state.heap_size) { + if (size == 0 || alignment == 0 || size > g_state.heap_size) { SHM_LOG_ERROR("invalid input, align=" << alignment << ", size=" << size); return nullptr; } @@ -296,5 +295,4 @@ bool memory_heap::expend_size_in_lock(const std::map::iterat } return true; -} -} // namespace shm \ No newline at end of file +} \ No newline at end of file diff --git a/src/host/mem/shmemi_mm_heap.h b/src/host/mem/shmemi_mm_heap.h index 27fd0135..7107547b 100644 --- a/src/host/mem/shmemi_mm_heap.h +++ b/src/host/mem/shmemi_mm_heap.h @@ -15,7 +15,6 @@ #include #include -namespace shm { struct memory_range { const uint64_t offset; const uint64_t size; @@ -56,6 +55,5 @@ private: std::map address_used_tree_; std::set size_idle_tree_; }; -} // namespace shm #endif // SHMEMI_MM_HEAP_H diff --git a/src/host/sync/shmemi_sync.cpp b/src/host/sync/shmemi_sync.cpp index 62fe991f..fe6bdd1f 100644 --- a/src/host/sync/shmemi_sync.cpp +++ b/src/host/sync/shmemi_sync.cpp @@ -20,7 +20,6 @@ extern "C" int rtGetC2cCtrlAddr(uint64_t *config, uint32_t *len); -namespace shm { static uint64_t ffts_config; int32_t shmemi_sync_init() @@ -29,11 +28,9 @@ int32_t shmemi_sync_init() return rtGetC2cCtrlAddr(&ffts_config, &len); } -} // namespace shm - uint64_t shmemx_get_ffts_config() { - return shm::ffts_config; + return ffts_config; } void shmem_barrier(shmem_team_t tid) diff --git a/src/host/sync/shmemi_sync.h b/src/host/sync/shmemi_sync.h index e89ae5b8..b8b3d546 100644 --- a/src/host/sync/shmemi_sync.h +++ b/src/host/sync/shmemi_sync.h @@ -10,10 +10,6 @@ #ifndef SHMEMI_SYNC_H #define SHMEMI_SYNC_H -namespace shm { - int32_t shmemi_sync_init(); -} - #endif // SHMEMI_TEAM_H diff --git a/src/host/team/shmem_team.cpp b/src/host/team/shmem_team.cpp index 1a56a713..c97d7070 100644 --- a/src/host/team/shmem_team.cpp +++ b/src/host/team/shmem_team.cpp @@ -20,7 +20,6 @@ #include "shmemi_device_intf.h" using namespace std; -namespace shm { uint64_t g_team_mask = 0; shmemi_team_t *g_shmem_team_pool = nullptr; @@ -191,8 +190,6 @@ int32_t shmemi_team_finalize() return 0; } -} // namespace shm - int32_t shmem_team_split_strided(shmem_team_t parent_team, int32_t pe_start, int32_t pe_stride, int32_t pe_size, shmem_team_t *new_team) { @@ -202,18 +199,18 @@ int32_t shmem_team_split_strided(shmem_team_t parent_team, int32_t pe_start, int } *new_team = SHMEM_TEAM_INVALID; - if (!shm::is_valid_team(parent_team)) { + if (!is_valid_team(parent_team)) { SHM_LOG_ERROR("input parent team is invalid!, team: " << parent_team); return SHMEM_INVALID_PARAM; } shmemi_team_t my_team; - shmemi_team_t *src_team = &shm::g_shmem_team_pool[parent_team]; + shmemi_team_t *src_team = &g_shmem_team_pool[parent_team]; if (pe_start >= SHMEM_MAX_RANKS || pe_stride >= SHMEM_MAX_RANKS || pe_size > SHMEM_MAX_RANKS) { SHM_LOG_ERROR("create team failed, input invalid, pe_start:" << pe_start << " pe_size:" << pe_size << " pe_stride:" << pe_stride << " parent:" - << shm::team_config2string(src_team)); + << team_config2string(src_team)); return SHMEM_INVALID_PARAM; } @@ -225,14 +222,14 @@ int32_t shmem_team_split_strided(shmem_team_t parent_team, int32_t pe_start, int if (pe_start < 0 || pe_start >= src_team->size || pe_size <= 0 || pe_size > src_team->size || pe_stride < 1) { SHM_LOG_ERROR("create team failed, input invalid, pe_start:" << pe_start << " pe_size:" << pe_size << " pe_stride:" << pe_stride << " parent:" - << shm::team_config2string(src_team)); + << team_config2string(src_team)); return SHMEM_INVALID_PARAM; } if (global_pe_start >= shmem_n_pes() || global_pe_end >= shmem_n_pes()) { SHM_LOG_ERROR("create team failed, large than world size, pe_start:" << pe_start << " pe_size:" << pe_size << " pe_stride:" << pe_stride - << " world_size:" << shmem_n_pes() << " parent:" << shm::team_config2string(src_team)); + << " world_size:" << shmem_n_pes() << " parent:" << team_config2string(src_team)); return SHMEM_INVALID_PARAM; } @@ -247,14 +244,14 @@ int32_t shmem_team_split_strided(shmem_team_t parent_team, int32_t pe_start, int my_team.stride = global_pe_stride; my_team.size = pe_size; - my_team.team_idx = shm::first_free_idx_fetch(); + my_team.team_idx = first_free_idx_fetch(); if (my_team.team_idx == -1) { SHM_LOG_ERROR("create team failed, team num is full!"); return SHMEM_INNER_ERROR; } - shm::g_shmem_team_pool[my_team.team_idx] = my_team; - if (shm::device_team_update(my_team.team_idx, &shm::g_shmem_team_pool[my_team.team_idx]) != 0) { + g_shmem_team_pool[my_team.team_idx] = my_team; + if (device_team_update(my_team.team_idx, &g_shmem_team_pool[my_team.team_idx]) != 0) { shmem_team_destroy(my_team.team_idx); SHM_LOG_ERROR("create team failed, malloc device state failed!"); return SHMEM_INNER_ERROR; @@ -282,12 +279,12 @@ int shmem_team_split_2d(shmem_team_t parent_team, int x_range, shmem_team_t *x_t *x_team = SHMEM_TEAM_INVALID; *y_team = SHMEM_TEAM_INVALID; - if (!shm::is_valid_team(parent_team)) { + if (!is_valid_team(parent_team)) { SHM_LOG_ERROR("input parent team is invalid!, team: " << parent_team); return SHMEM_INVALID_PARAM; } - shmemi_team_t *src_team = &shm::g_shmem_team_pool[parent_team]; + shmemi_team_t *src_team = &g_shmem_team_pool[parent_team]; int32_t src_start = src_team->start; int32_t src_stride = src_team->stride; @@ -349,12 +346,12 @@ int shmem_team_split_2d(shmem_team_t parent_team, int x_range, shmem_team_t *x_t int32_t shmem_team_translate_pe(shmem_team_t src_team, int32_t src_pe, shmem_team_t dest_team) { - if (!shm::is_valid_team(src_team) || !shm::is_valid_team(dest_team)) { + if (!is_valid_team(src_team) || !is_valid_team(dest_team)) { return -1; } - shmemi_team_t *src_team_ptr = &shm::g_shmem_team_pool[src_team]; - shmemi_team_t *dest_team_ptr = &shm::g_shmem_team_pool[dest_team]; + shmemi_team_t *src_team_ptr = &g_shmem_team_pool[src_team]; + shmemi_team_t *dest_team_ptr = &g_shmem_team_pool[dest_team]; if (src_pe > src_team_ptr->size) { return -1; @@ -375,13 +372,13 @@ int32_t shmem_team_translate_pe(shmem_team_t src_team, int32_t src_pe, shmem_tea void shmem_team_destroy(shmem_team_t team) { - if (!shm::is_valid_team(team)) { + if (!is_valid_team(team)) { SHM_LOG_WARN("input team is invalid!, team: " << team); return; } - shm::device_team_destroy(team); - shm::g_team_mask ^= 1ULL << team; + device_team_destroy(team); + g_team_mask ^= 1ULL << team; if (update_device_state() != SHMEM_SUCCESS) { SHM_LOG_WARN("update state failed when destroy team!"); } @@ -389,18 +386,18 @@ void shmem_team_destroy(shmem_team_t team) int32_t shmem_my_pe() { - return shm::g_state.mype; + return g_state.mype; } int32_t shmem_n_pes() { - return shm::g_state.npes; + return g_state.npes; } int32_t shmem_team_my_pe(shmem_team_t team) { - if (shm::is_valid_team(team)) { - return shm::g_shmem_team_pool[team].mype; + if (is_valid_team(team)) { + return g_shmem_team_pool[team].mype; } else { return -1; } @@ -408,8 +405,8 @@ int32_t shmem_team_my_pe(shmem_team_t team) int32_t shmem_team_n_pes(shmem_team_t team) { - if (shm::is_valid_team(team)) { - return shm::g_shmem_team_pool[team].size; + if (is_valid_team(team)) { + return g_shmem_team_pool[team].size; } else { return -1; } @@ -418,7 +415,7 @@ int32_t shmem_team_n_pes(shmem_team_t team) int shmem_team_get_config(shmem_team_t team, shmem_team_config_t *config) { SHMEM_CHECK_RET(config == nullptr); - if (shm::is_valid_team(team)) { + if (is_valid_team(team)) { config->num_contexts = 0; return 0; } else { diff --git a/src/host/team/shmemi_team.h b/src/host/team/shmemi_team.h index be55d2ae..02ec1482 100644 --- a/src/host/team/shmemi_team.h +++ b/src/host/team/shmemi_team.h @@ -12,12 +12,8 @@ #include "stdint.h" -namespace shm { - int32_t shmemi_team_init(int32_t rank, int32_t size); int32_t shmemi_team_finalize(); -} // namespace shm - #endif // SHMEMI_TEAM_H -- Gitee From 0c451acaa10f930871b52aecaa09e98871a9bdc0 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Tue, 30 Sep 2025 15:03:27 +0800 Subject: [PATCH 36/74] Shrink CMakeLists USE_MF Code --- CMakeLists.txt | 14 +++++++++ examples/CMakeLists.txt | 21 ++----------- src/CMakeLists.txt | 30 +------------------ .../init/init_backends/mf/shmemi_init_mf.cpp | 15 +++++++++- .../init/init_backends/mf/shmemi_init_mf.h | 9 ------ 5 files changed, 32 insertions(+), 57 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0350eae2..4aef7718 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,20 @@ link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase p # MF_BACKEND set(USE_MF "0") +if ("${USE_MF}" STREQUAL "1") + add_compile_definitions(BACKEND_MF=1) + include_directories( + ${PROJECT_SOURCE_DIR}/include/ + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ + ) + + link_libraries( + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_hybm_core.so + ) +endif() + # 添加子目录 add_subdirectory(src) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 80acb8cb..f09b1614 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -14,8 +14,6 @@ function(shmem_add_fusion_example NAME) target_compile_options(${NAME} PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220) target_include_directories(${NAME} PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ ${PROJECT_SOURCE_DIR}/3rdparty/catlass/include ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} @@ -23,13 +21,9 @@ function(shmem_add_fusion_example NAME) ${MPI_INCLUDE_PATH} ) target_link_options(${NAME} PRIVATE --cce-fatobj-link) - target_link_libraries(${NAME} PRIVATE shmem ${MPI_CXX_COMPILE_FLAGS} ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + target_link_libraries(${NAME} PRIVATE shmem ${MPI_CXX_COMPILE_FLAGS}) target_compile_options(${NAME} PRIVATE ${MPI_CXX_COMPILE_FLAGS}) - if ("${USE_MF}" STREQUAL "1") - target_compile_definitions(${NAME} PRIVATE BACKEND_MF=1) - endif() - endfunction() function(shmem_add_collective_example NAME) @@ -37,34 +31,25 @@ function(shmem_add_collective_example NAME) target_compile_options(${NAME}_kernel PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220-vec) target_include_directories(${NAME}_kernel PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ ${PROJECT_SOURCE_DIR}/3rdparty/catlass/include ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} ) target_link_options(${NAME}_kernel PRIVATE --cce-fatobj-link) - if ("${USE_MF}" STREQUAL "1") - target_compile_definitions(${NAME}_kernel PRIVATE BACKEND_MF=1) - endif() - + add_executable(${NAME} main.cpp) target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(${NAME} PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} ${PROJECT_SOURCE_DIR}/examples/utils ${PROJECT_SOURCE_DIR}/src/host ${MPI_INCLUDE_PATH} ) - target_link_libraries(${NAME} PRIVATE shmem ${NAME}_kernel MPI::MPI_CXX ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + target_link_libraries(${NAME} PRIVATE shmem ${NAME}_kernel MPI::MPI_CXX) target_compile_options(${NAME} PRIVATE ${MPI_CXX_COMPILE_FLAGS}) - if ("${USE_MF}" STREQUAL "1") - target_compile_definitions(${NAME} PRIVATE BACKEND_MF=1) - endif() - endfunction() foreach(EXAMPLE diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3f096b2a..3b992d65 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -29,22 +29,10 @@ target_include_directories(shmem_device ${PROJECT_SOURCE_DIR}/include/ ) -if ("${USE_MF}" STREQUAL "1") - target_compile_definitions(shmem_device PRIVATE BACKEND_MF=1) - target_include_directories(shmem_device - PUBLIC - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ - ) -endif() - file(GLOB_RECURSE SHMEM_HOST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/host/*.cpp) list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "python_wrapper") list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "modules") -if ("${USE_MF}" STREQUAL "1") - list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmemi_init_normal.cpp") -else() - list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "shmemi_init_mf.cpp") -endif() + add_library(shmem_host OBJECT ${SHMEM_HOST_FILES}) target_compile_options(shmem_host PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(shmem_host @@ -54,25 +42,9 @@ target_include_directories(shmem_host ${PROJECT_SOURCE_DIR}/src/device ) -if ("${USE_MF}" STREQUAL "1") - target_compile_definitions(shmem_host PRIVATE BACKEND_MF=1) - target_include_directories(shmem_host - PUBLIC - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ - ) -endif() - add_library(shmem SHARED $ $) target_link_options(shmem PRIVATE --cce-fatobj-link) -if ("${USE_MF}" STREQUAL "1") - target_link_libraries(shmem - PUBLIC - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_hybm_core.so - ) -endif() - set(SHMEM_MTE_SUPPORT ON) if(SHMEM_MTE_SUPPORT) add_library(shmem_transport_mte SHARED) diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index f35e2ab9..00909dbe 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -1,5 +1,16 @@ #include "shmemi_init_mf.h" +#ifdef BACKEND_MF + +// smem api +#include +#include +#include +#include +#include +#include +#include + constexpr int DEFAULT_FLAG = 0; constexpr int DEFAULT_ID = 0; constexpr int DEFAULT_TIMEOUT = 120; @@ -145,4 +156,6 @@ int shmemi_init_mf::transport_init(shmemi_device_host_state_t &g_state) int shmemi_init_mf::transport_finalize() { return SHMEM_SUCCESS; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index 455e5cc9..bd4e1acb 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -8,15 +8,6 @@ #include "shmemi_host_common.h" #include "internal/host_device/shmemi_types.h" -// smem api -#include -#include -#include -#include -#include -#include -#include - class shmemi_init_mf: public shmemi_init_base { public: shmemi_init_mf(shmem_init_attr_t *attr, char *ipport); -- Gitee From 38074e843c6329a3c470ae750ef9c8d89e058f71 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Tue, 30 Sep 2025 15:50:53 +0800 Subject: [PATCH 37/74] Rename interfaces in shmemi_base_copy_api.h --- .../low_level/shmem_device_low_level_rma.h | 64 +++++++++---------- .../low_level/shmemx_device_low_level_rma.h | 16 ++--- .../internal/device/shmemi_base_copy_api.h | 27 ++++---- .../internal/device/shmemi_device_common.h | 2 +- 4 files changed, 54 insertions(+), 55 deletions(-) diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index e4a674f6..a3f37671 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -65,20 +65,20 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size); + shmemi_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + i * repeat_elem, buf, block_size); + shmemi_copy_ub2gm(dst + i * repeat_elem, buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain); + shmemi_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain); + shmemi_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain); } } @@ -140,7 +140,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); + shmemi_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -148,7 +148,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); } /** @@ -179,20 +179,20 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, remote_buff[i * repeat_elem], block_size); + shmemi_copy_gm2ub(buf, remote_buff[i * repeat_elem], block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst[i * repeat_elem], buf, block_size); + shmemi_copy_ub2gm(dst[i * repeat_elem], buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, remote_buff[repeat_times * repeat_elem], remain); + shmemi_copy_gm2ub(buf, remote_buff[repeat_times * repeat_elem], remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst[repeat_times * repeat_elem], buf, remain); + shmemi_copy_ub2gm(dst[repeat_times * repeat_elem], buf, remain); } } @@ -246,7 +246,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(buf, remote_buff, data_copy_params_gm2ub); + shmemi_copy_gm2ub(buf, remote_buff, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -254,7 +254,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst, buf, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst, buf, data_copy_params_ub2gm); } /** @@ -283,20 +283,20 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, src + i * repeat_elem, block_size); + shmemi_copy_gm2ub(buf, src + i * repeat_elem, block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size); + shmemi_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain); + shmemi_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain); + shmemi_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain); } } @@ -357,7 +357,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); + shmemi_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -365,7 +365,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); } /** @@ -396,20 +396,20 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, src[i * repeat_elem], block_size); + shmemi_copy_gm2ub(buf, src[i * repeat_elem], block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_buff[i * repeat_elem], buf, block_size); + shmemi_copy_ub2gm(remote_buff[i * repeat_elem], buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, src[repeat_times * repeat_elem], remain); + shmemi_copy_gm2ub(buf, src[repeat_times * repeat_elem], remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_buff[repeat_times * repeat_elem], buf, remain); + shmemi_copy_ub2gm(remote_buff[repeat_times * repeat_elem], buf, remain); } } @@ -463,7 +463,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(buf, src, data_copy_params_gm2ub); + shmemi_copy_gm2ub(buf, src, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -471,7 +471,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(remote_buff, buf, data_copy_params_ub2gm); + shmemi_copy_ub2gm(remote_buff, buf, data_copy_params_ub2gm); } /** @@ -497,7 +497,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__ubuf__ T *dst, __gm__ T *src, uint32_t __gm__ T *remote_ptr = reinterpret_cast<__gm__ T *>(ptr); - smem_shm_copy_gm2ub(dst, remote_ptr, elem_size * sizeof(T)); + shmemi_copy_gm2ub(dst, remote_ptr, elem_size * sizeof(T)); } /** @@ -525,7 +525,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::LocalTensor dst, AscendC::Gl AscendC::GlobalTensor remote_buff; remote_buff.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(ptr)); - smem_shm_copy_gm2ub(dst, remote_buff, elem_size * sizeof(T)); + shmemi_copy_gm2ub(dst, remote_buff, elem_size * sizeof(T)); } /** @@ -562,7 +562,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__ubuf__ T *dst, __gm__ T *src, const no AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (copy_params.dst_ld - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); + shmemi_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); } /** @@ -596,7 +596,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::LocalTensor dst, AscendC::Gl AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (copy_params.dst_ld - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(dst, remote_buff, data_copy_params_gm2ub); + shmemi_copy_gm2ub(dst, remote_buff, data_copy_params_gm2ub); } /** @@ -622,7 +622,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __ubuf__ T *src, uint32_t __gm__ T *remote_ptr = reinterpret_cast<__gm__ T *>(ptr); - smem_shm_copy_ub2gm(remote_ptr, src, elem_size * sizeof(T)); + shmemi_copy_ub2gm(remote_ptr, src, elem_size * sizeof(T)); } /** @@ -650,7 +650,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::L AscendC::GlobalTensor remote_buff; remote_buff.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(ptr)); - smem_shm_copy_ub2gm(remote_buff, src, elem_size * sizeof(T)); + shmemi_copy_ub2gm(remote_buff, src, elem_size * sizeof(T)); } /** @@ -688,7 +688,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __ubuf__ T *src, const no AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); } /** @@ -722,7 +722,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::L AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(remote_buff, src, data_copy_params_ub2gm); + shmemi_copy_ub2gm(remote_buff, src, data_copy_params_ub2gm); } #endif \ No newline at end of file diff --git a/include/device/low_level/shmemx_device_low_level_rma.h b/include/device/low_level/shmemx_device_low_level_rma.h index b160e0a6..4f2eab7f 100644 --- a/include/device/low_level/shmemx_device_low_level_rma.h +++ b/include/device/low_level/shmemx_device_low_level_rma.h @@ -40,20 +40,20 @@ SHMEM_DEVICE void shmemx_mte_get_mem_nbi_low_level(__gm__ int8_t* dst, __gm__ in uint64_t repeat_elem = block_size; uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size, enable_L2); + shmemi_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + i * repeat_elem, buf, block_size, enable_L2); + shmemi_copy_ub2gm(dst + i * repeat_elem, buf, block_size, enable_L2); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain, enable_L2); + shmemi_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain, enable_L2); + shmemi_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain, enable_L2); } } @@ -83,20 +83,20 @@ SHMEM_DEVICE void shmemx_mte_put_mem_nbi_low_level(__gm__ int8_t* dst, __gm__ in uint64_t repeat_elem = block_size; uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, src + i * repeat_elem, block_size, enable_L2); + shmemi_copy_gm2ub(buf, src + i * repeat_elem, block_size, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size, enable_L2); + shmemi_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size, enable_L2); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain, enable_L2); + shmemi_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain, enable_L2); + shmemi_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain, enable_L2); } } diff --git a/include/internal/device/shmemi_base_copy_api.h b/include/internal/device/shmemi_base_copy_api.h index 18f9fa41..5303a64c 100644 --- a/include/internal/device/shmemi_base_copy_api.h +++ b/include/internal/device/shmemi_base_copy_api.h @@ -11,12 +11,11 @@ #define __SHMEMI_BASE_COPY_H__ #include "kernel_operator.h" - -#define SMEM_SHM_INLINE_AICORE __attribute__((always_inline)) inline __aicore__ +#include "host_device/shmem_types.h" template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, - uint32_t size, bool enableL2 = true) +SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, + uint32_t size, bool toL2Cache = true) { ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); @@ -26,14 +25,14 @@ SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* sr ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); - if (!enableL2) { + if (!toL2Cache) { gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } AscendC::DataCopyPad(gmTensor, ubTensor, dataCopyParams); } template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(const AscendC::GlobalTensor &dstGva, +SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, const AscendC::LocalTensor &srcUb, uint32_t size) { AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); @@ -41,7 +40,7 @@ SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(const AscendC::GlobalTensor & } template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, +SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, AscendC::DataCopyExtParams ©Params) { ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); @@ -55,15 +54,15 @@ SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* sr } template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_ub2gm(const AscendC::GlobalTensor &dstGva, +SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, const AscendC::LocalTensor &srcUb, AscendC::DataCopyExtParams ©Params) { AscendC::DataCopyPad(dstGva, srcUb, copyParams); } template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, - uint32_t size, bool enableL2 = true) +SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, + uint32_t size, bool toL2Cache = true) { ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); AscendC::LocalTensor ubTensor; @@ -72,7 +71,7 @@ SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* src ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); - if (!enableL2) { + if (!toL2Cache) { gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } AscendC::DataCopyPadExtParams padParams; @@ -80,7 +79,7 @@ SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* src } template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(const AscendC::LocalTensor &dstUb, +SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, const AscendC::GlobalTensor &srcGva, uint32_t size) { AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); @@ -89,7 +88,7 @@ SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(const AscendC::LocalTensor &d } template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, +SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, AscendC::DataCopyExtParams ©Params) { ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); @@ -103,7 +102,7 @@ SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* src } template -SMEM_SHM_INLINE_AICORE void smem_shm_copy_gm2ub(const AscendC::LocalTensor &dstUb, +SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, const AscendC::GlobalTensor &srcGva, AscendC::DataCopyExtParams ©Params) { AscendC::DataCopyPadExtParams padParams; diff --git a/include/internal/device/shmemi_device_common.h b/include/internal/device/shmemi_device_common.h index fa7de721..b16fb28d 100644 --- a/include/internal/device/shmemi_device_common.h +++ b/include/internal/device/shmemi_device_common.h @@ -12,6 +12,7 @@ #include "shmemi_device_arch.h" #include "shmemi_device_def.h" +#include "shmemi_base_copy_api.h" constexpr int ub_limit = 192 * 1024; @@ -22,7 +23,6 @@ SHMEM_DEVICE __gm__ shmemi_device_host_state_t *shmemi_get_state() { return reinterpret_cast<__gm__ shmemi_device_host_state_t *>(smem_shm_get_extra_context_addr()); } #else -#include "shmemi_base_copy_api.h" // rdma constexpr uint64_t SMEM_SHM_DEVICE_PRE_META_SIZE = 128UL; // 128B -- Gitee From de36ccffbc23c1c85487f1c484bdd17eef3466b0 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Thu, 9 Oct 2025 10:40:59 +0800 Subject: [PATCH 38/74] Move shmemi_mm_heap into shmem_mm --- src/host/mem/shmem_mm.cpp | 278 +++++++++++++++++++++++++++-- src/host/mem/shmemi_heap.h | 3 - src/host/mem/shmemi_mm.h | 44 +++++ src/host/mem/shmemi_mm_heap.cpp | 298 -------------------------------- src/host/mem/shmemi_mm_heap.h | 59 ------- 5 files changed, 305 insertions(+), 377 deletions(-) delete mode 100644 src/host/mem/shmemi_mm_heap.cpp delete mode 100644 src/host/mem/shmemi_mm_heap.h diff --git a/src/host/mem/shmem_mm.cpp b/src/host/mem/shmem_mm.cpp index e522af16..507afed0 100644 --- a/src/host/mem/shmem_mm.cpp +++ b/src/host/mem/shmem_mm.cpp @@ -10,16 +10,260 @@ #include #include "acl/acl.h" #include "shmemi_host_common.h" -#include "shmemi_mm_heap.h" + +bool range_size_first_comparator::operator()(const memory_range &mr1, const memory_range &mr2) const noexcept +{ + if (mr1.size != mr2.size) { + return mr1.size < mr2.size; + } + + return mr1.offset < mr2.offset; +} + +memory_heap::memory_heap(void *base, uint64_t size) noexcept : base_{reinterpret_cast(base)}, size_{size} +{ + pthread_spin_init(&spinlock_, 0); + address_idle_tree_[0] = size; + size_idle_tree_.insert({0, size}); +} + +memory_heap::~memory_heap() noexcept +{ + pthread_spin_destroy(&spinlock_); +} + +void *memory_heap::allocate(uint64_t size) noexcept +{ + if (size == 0 || size > g_state.heap_size) { + SHM_LOG_ERROR("cannot allocate with size " << size); + return nullptr; + } + + auto aligned_size = allocated_size_align_up(size); + memory_range anchor{0, aligned_size}; + + pthread_spin_lock(&spinlock_); + auto size_pos = size_idle_tree_.lower_bound(anchor); + if (size_pos == size_idle_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("cannot allocate with size: " << size); + return nullptr; + } + + auto target_offset = size_pos->offset; + auto target_size = size_pos->size; + auto addr_pos = address_idle_tree_.find(target_offset); + if (addr_pos == address_idle_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("offset(" << target_offset << ") size(" << target_size << ") in size tree, not in address tree."); + return nullptr; + } + + size_idle_tree_.erase(size_pos); + address_idle_tree_.erase(addr_pos); + address_used_tree_.emplace(target_offset, aligned_size); + if (target_size > aligned_size) { + memory_range left{target_offset + aligned_size, target_size - aligned_size}; + address_idle_tree_.emplace(left.offset, left.size); + size_idle_tree_.emplace(left); + } + pthread_spin_unlock(&spinlock_); + + return base_ + target_offset; +} + +void *memory_heap::aligned_allocate(uint64_t alignment, uint64_t size) noexcept +{ + if (size == 0 || alignment == 0 || size > g_state.heap_size) { + SHM_LOG_ERROR("invalid input, align=" << alignment << ", size=" << size); + return nullptr; + } + + if ((alignment & (alignment - 1UL)) != 0) { + SHM_LOG_ERROR("alignment should be power of 2, but real " << alignment); + return nullptr; + } + + uint64_t head_skip = 0; + auto aligned_size = allocated_size_align_up(size); + memory_range anchor{0, aligned_size}; + + pthread_spin_lock(&spinlock_); + auto size_pos = size_idle_tree_.lower_bound(anchor); + while (size_pos != size_idle_tree_.end() && !alignment_matches(*size_pos, alignment, aligned_size, head_skip)) { + ++size_pos; + } + + if (size_pos == size_idle_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("cannot allocate with size: " << size << ", alignment: " << alignment); + return nullptr; + } + + auto target_offset = size_pos->offset; + auto target_size = size_pos->size; + memory_range result_range{size_pos->offset + head_skip, aligned_size}; + size_idle_tree_.erase(size_pos); + + if (head_skip > 0) { + size_idle_tree_.emplace(memory_range{target_offset, head_skip}); + address_idle_tree_.emplace(target_offset, head_skip); + } + + if (head_skip + aligned_size < target_size) { + memory_range leftMR{target_offset + head_skip + aligned_size, target_size - head_skip - aligned_size}; + size_idle_tree_.emplace(leftMR); + address_idle_tree_.emplace(leftMR.offset, leftMR.size); + } + + address_used_tree_.emplace(result_range.offset, result_range.size); + pthread_spin_unlock(&spinlock_); + + return base_ + result_range.offset; +} + +int32_t memory_heap::release(void *address) noexcept +{ + auto u8a = reinterpret_cast(address); + if (u8a < base_ || u8a >= base_ + size_) { + SHM_LOG_ERROR("release invalid address " << address); + return -1; + } + + auto offset = u8a - base_; + pthread_spin_lock(&spinlock_); + auto pos = address_used_tree_.find(offset); + if (pos == address_used_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("release address " << address << " not allocated."); + return -1; + } + + auto size = pos->second; + uint64_t final_offset = offset; + uint64_t final_size = size; + address_used_tree_.erase(pos); + + auto prev_addr_pos = address_idle_tree_.lower_bound(offset); + if (prev_addr_pos != address_idle_tree_.begin()) { + --prev_addr_pos; + if (prev_addr_pos != address_idle_tree_.end() && prev_addr_pos->first + prev_addr_pos->second == offset) { + // 合并前一个range + final_offset = prev_addr_pos->first; + final_size += prev_addr_pos->second; + address_idle_tree_.erase(prev_addr_pos); + size_idle_tree_.erase(memory_range{prev_addr_pos->first, prev_addr_pos->second}); + } + } + + auto next_addr_pos = address_idle_tree_.find(offset + size); + if (next_addr_pos != address_idle_tree_.end()) { // 合并后一个range + final_size += next_addr_pos->second; + address_idle_tree_.erase(next_addr_pos); + size_idle_tree_.erase(memory_range{next_addr_pos->first, next_addr_pos->second}); + } + address_idle_tree_.emplace(final_offset, final_size); + size_idle_tree_.emplace(memory_range{final_offset, final_size}); + pthread_spin_unlock(&spinlock_); + + return 0; +} + +bool memory_heap::allocated_size(void *address, uint64_t &size) const noexcept +{ + auto u8a = reinterpret_cast(address); + if (u8a < base_ || u8a >= base_ + size_) { + SHM_LOG_ERROR("release invalid address " << address); + return false; + } + + auto offset = u8a - base_; + bool exist = false; + pthread_spin_lock(&spinlock_); + auto pos = address_used_tree_.find(offset); + if (pos != address_used_tree_.end()) { + exist = true; + size = pos->second; + } + pthread_spin_unlock(&spinlock_); + + return exist; +} + +uint64_t memory_heap::allocated_size_align_up(uint64_t input_size) noexcept +{ + constexpr uint64_t align_size = 16UL; + constexpr uint64_t align_size_mask = ~(align_size - 1UL); + return (input_size + align_size - 1UL) & align_size_mask; +} + +bool memory_heap::alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, + uint64_t &head_skip) noexcept +{ + if (mr.size < size) { + return false; + } + + if ((mr.offset & (alignment - 1UL)) == 0UL) { + head_skip = 0; + return true; + } + + auto aligned_offset = ((mr.offset + alignment - 1UL) & (~(alignment - 1UL))); + head_skip = aligned_offset - mr.offset; + return mr.size >= size + head_skip; +} + +void memory_heap::reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept +{ + auto offset = pos->first; + auto old_size = pos->second; + pos->second = new_size; + auto next_addr_pos = address_idle_tree_.find(offset + old_size); + if (next_addr_pos == address_idle_tree_.end()) { + address_idle_tree_.emplace(offset + new_size, old_size - new_size); + size_idle_tree_.emplace(memory_range{offset + new_size, old_size - new_size}); + } else { + auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); + size_idle_tree_.erase(next_size_pos); + next_addr_pos->second += (old_size - new_size); + size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); + } +} + +bool memory_heap::expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept +{ + auto offset = pos->first; + auto old_size = pos->second; + auto delta = new_size - old_size; + + auto next_addr_pos = address_idle_tree_.find(offset + old_size); + if (next_addr_pos == address_idle_tree_.end() || next_addr_pos->second < delta) { + return false; + } + + pos->second = new_size; + auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); + if (next_addr_pos->second == delta) { + size_idle_tree_.erase(next_size_pos); + address_idle_tree_.erase(next_addr_pos); + } else { + size_idle_tree_.erase(next_size_pos); + next_addr_pos->second -= delta; + size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); + } + + return true; +} namespace { -std::shared_ptr shm_memory_heap; +std::shared_ptr shmemi_memory_manager; } int32_t memory_manager_initialize(void *base, uint64_t size) { - shm_memory_heap = std::make_shared(base, size); - if (shm_memory_heap == nullptr) { + shmemi_memory_manager = std::make_shared(base, size); + if (shmemi_memory_manager == nullptr) { return SHMEM_INNER_ERROR; } return SHMEM_SUCCESS; @@ -27,23 +271,23 @@ int32_t memory_manager_initialize(void *base, uint64_t size) void memory_manager_destroy() { - shm_memory_heap.reset(); + shmemi_memory_manager.reset(); } void *shmem_malloc(size_t size) { - if (shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - void *ptr = shm_memory_heap->allocate(size); + void *ptr = shmemi_memory_manager->allocate(size); SHM_LOG_DEBUG("shmem_malloc(" << size << ")"); auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("malloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { - shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } @@ -52,19 +296,19 @@ void *shmem_malloc(size_t size) void *shmem_calloc(size_t nmemb, size_t size) { - if (shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } SHM_ASSERT_MULTIPLY_OVERFLOW(nmemb, size, g_state.heap_size, nullptr); auto total_size = nmemb * size; - auto ptr = shm_memory_heap->allocate(total_size); + auto ptr = shmemi_memory_manager->allocate(total_size); if (ptr != nullptr) { auto ret = aclrtMemset(ptr, size, 0, size); if (ret != 0) { SHM_LOG_ERROR("shmem_calloc(" << nmemb << ", " << size << ") memset failed: " << ret); - shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } @@ -73,7 +317,7 @@ void *shmem_calloc(size_t nmemb, size_t size) if (ret != 0) { SHM_LOG_ERROR("calloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { - shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } @@ -84,17 +328,17 @@ void *shmem_calloc(size_t nmemb, size_t size) void *shmem_align(size_t alignment, size_t size) { - if (shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - auto ptr = shm_memory_heap->aligned_allocate(alignment, size); + auto ptr = shmemi_memory_manager->aligned_allocate(alignment, size); auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("shmem_align barrier failed, ret: " << ret); if (ptr != nullptr) { - shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } @@ -104,7 +348,7 @@ void *shmem_align(size_t alignment, size_t size) void shmem_free(void *ptr) { - if (shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return; } @@ -112,7 +356,7 @@ void shmem_free(void *ptr) return; } - auto ret = shm_memory_heap->release(ptr); + auto ret = shmemi_memory_manager->release(ptr); if (ret != 0) { SHM_LOG_ERROR("release failed: " << ret); } diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h index 0248af5a..e5ebe537 100644 --- a/src/host/mem/shmemi_heap.h +++ b/src/host/mem/shmemi_heap.h @@ -22,9 +22,6 @@ public: int setup_heap(int *transport_map); // export && import p2p memories && aclrtMapMem int remove_heap(); // aclrtUnmapMem - int *heap_alloc(); // ptr pretend alloc - int *heap_free(); // ptr pretend free - void *get_heap_base(); // return heap_base_ void *get_peer_heap_base_p2p(int pe_id); // peer_heap_base_p2p_ diff --git a/src/host/mem/shmemi_mm.h b/src/host/mem/shmemi_mm.h index 3a9316a4..3d18c471 100644 --- a/src/host/mem/shmemi_mm.h +++ b/src/host/mem/shmemi_mm.h @@ -10,8 +10,52 @@ #ifndef SHMEMI_MM_H #define SHMEMI_MM_H +#include +#include +#include +#include + #include "host/shmem_host_def.h" +struct memory_range { + const uint64_t offset; + const uint64_t size; + + memory_range(uint64_t o, uint64_t s) noexcept : offset{o}, size{s} + {} +}; + +struct range_size_first_comparator { + bool operator()(const memory_range &mr1, const memory_range &mr2) const noexcept; +}; + +class memory_heap { +public: + memory_heap(void *base, uint64_t size) noexcept; + ~memory_heap() noexcept; + +public: + void *allocate(uint64_t size) noexcept; + void *aligned_allocate(uint64_t alignment, uint64_t size) noexcept; + int32_t release(void *address) noexcept; + bool allocated_size(void *address, uint64_t &size) const noexcept; + +private: + static uint64_t allocated_size_align_up(uint64_t input_size) noexcept; + static bool alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, + uint64_t &head_skip) noexcept; + void reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; + bool expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; + +private: + uint8_t *const base_; + const uint64_t size_; + mutable pthread_spinlock_t spinlock_{}; + std::map address_idle_tree_; + std::map address_used_tree_; + std::set size_idle_tree_; +}; + int32_t memory_manager_initialize(void *base, uint64_t size); void memory_manager_destroy(); diff --git a/src/host/mem/shmemi_mm_heap.cpp b/src/host/mem/shmemi_mm_heap.cpp deleted file mode 100644 index 02284684..00000000 --- a/src/host/mem/shmemi_mm_heap.cpp +++ /dev/null @@ -1,298 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include "shmemi_host_common.h" -#include "shmemi_mm_heap.h" - -bool range_size_first_comparator::operator()(const memory_range &mr1, const memory_range &mr2) const noexcept -{ - if (mr1.size != mr2.size) { - return mr1.size < mr2.size; - } - - return mr1.offset < mr2.offset; -} - -memory_heap::memory_heap(void *base, uint64_t size) noexcept : base_{reinterpret_cast(base)}, size_{size} -{ - pthread_spin_init(&spinlock_, 0); - address_idle_tree_[0] = size; - size_idle_tree_.insert({0, size}); -} - -memory_heap::~memory_heap() noexcept -{ - pthread_spin_destroy(&spinlock_); -} - -void *memory_heap::allocate(uint64_t size) noexcept -{ - if (size == 0 || size > g_state.heap_size) { - SHM_LOG_ERROR("cannot allocate with size " << size); - return nullptr; - } - - auto aligned_size = allocated_size_align_up(size); - memory_range anchor{0, aligned_size}; - - pthread_spin_lock(&spinlock_); - auto size_pos = size_idle_tree_.lower_bound(anchor); - if (size_pos == size_idle_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("cannot allocate with size: " << size); - return nullptr; - } - - auto target_offset = size_pos->offset; - auto target_size = size_pos->size; - auto addr_pos = address_idle_tree_.find(target_offset); - if (addr_pos == address_idle_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("offset(" << target_offset << ") size(" << target_size << ") in size tree, not in address tree."); - return nullptr; - } - - size_idle_tree_.erase(size_pos); - address_idle_tree_.erase(addr_pos); - address_used_tree_.emplace(target_offset, aligned_size); - if (target_size > aligned_size) { - memory_range left{target_offset + aligned_size, target_size - aligned_size}; - address_idle_tree_.emplace(left.offset, left.size); - size_idle_tree_.emplace(left); - } - pthread_spin_unlock(&spinlock_); - - return base_ + target_offset; -} - -void *memory_heap::aligned_allocate(uint64_t alignment, uint64_t size) noexcept -{ - if (size == 0 || alignment == 0 || size > g_state.heap_size) { - SHM_LOG_ERROR("invalid input, align=" << alignment << ", size=" << size); - return nullptr; - } - - if ((alignment & (alignment - 1UL)) != 0) { - SHM_LOG_ERROR("alignment should be power of 2, but real " << alignment); - return nullptr; - } - - uint64_t head_skip = 0; - auto aligned_size = allocated_size_align_up(size); - memory_range anchor{0, aligned_size}; - - pthread_spin_lock(&spinlock_); - auto size_pos = size_idle_tree_.lower_bound(anchor); - while (size_pos != size_idle_tree_.end() && !alignment_matches(*size_pos, alignment, aligned_size, head_skip)) { - ++size_pos; - } - - if (size_pos == size_idle_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("cannot allocate with size: " << size << ", alignment: " << alignment); - return nullptr; - } - - auto target_offset = size_pos->offset; - auto target_size = size_pos->size; - memory_range result_range{size_pos->offset + head_skip, aligned_size}; - size_idle_tree_.erase(size_pos); - - if (head_skip > 0) { - size_idle_tree_.emplace(memory_range{target_offset, head_skip}); - address_idle_tree_.emplace(target_offset, head_skip); - } - - if (head_skip + aligned_size < target_size) { - memory_range leftMR{target_offset + head_skip + aligned_size, target_size - head_skip - aligned_size}; - size_idle_tree_.emplace(leftMR); - address_idle_tree_.emplace(leftMR.offset, leftMR.size); - } - - address_used_tree_.emplace(result_range.offset, result_range.size); - pthread_spin_unlock(&spinlock_); - - return base_ + result_range.offset; -} - -bool memory_heap::change_size(void *address, uint64_t size) noexcept -{ - auto u8a = reinterpret_cast(address); - if (u8a < base_ || u8a >= base_ + size_) { - SHM_LOG_ERROR("release invalid address " << address); - return false; - } - - if (size == 0) { - release(address); - return true; - } - - auto offset = u8a - base_; - pthread_spin_lock(&spinlock_); - auto pos = address_used_tree_.find(offset); - if (pos == address_used_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("change size for address " << address << " not allocated."); - return false; - } - - // size不变 - if (pos->second == size) { - pthread_spin_unlock(&spinlock_); - return true; - } - - // 缩小size - if (pos->second > size) { - reduce_size_in_lock(pos, size); - pthread_spin_unlock(&spinlock_); - return true; - } - - // 扩大size - auto success = expend_size_in_lock(pos, size); - pthread_spin_unlock(&spinlock_); - - return success; -} - -int32_t memory_heap::release(void *address) noexcept -{ - auto u8a = reinterpret_cast(address); - if (u8a < base_ || u8a >= base_ + size_) { - SHM_LOG_ERROR("release invalid address " << address); - return -1; - } - - auto offset = u8a - base_; - pthread_spin_lock(&spinlock_); - auto pos = address_used_tree_.find(offset); - if (pos == address_used_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("release address " << address << " not allocated."); - return -1; - } - - auto size = pos->second; - uint64_t final_offset = offset; - uint64_t final_size = size; - address_used_tree_.erase(pos); - - auto prev_addr_pos = address_idle_tree_.lower_bound(offset); - if (prev_addr_pos != address_idle_tree_.begin()) { - --prev_addr_pos; - if (prev_addr_pos != address_idle_tree_.end() && prev_addr_pos->first + prev_addr_pos->second == offset) { - // 合并前一个range - final_offset = prev_addr_pos->first; - final_size += prev_addr_pos->second; - address_idle_tree_.erase(prev_addr_pos); - size_idle_tree_.erase(memory_range{prev_addr_pos->first, prev_addr_pos->second}); - } - } - - auto next_addr_pos = address_idle_tree_.find(offset + size); - if (next_addr_pos != address_idle_tree_.end()) { // 合并后一个range - final_size += next_addr_pos->second; - address_idle_tree_.erase(next_addr_pos); - size_idle_tree_.erase(memory_range{next_addr_pos->first, next_addr_pos->second}); - } - address_idle_tree_.emplace(final_offset, final_size); - size_idle_tree_.emplace(memory_range{final_offset, final_size}); - pthread_spin_unlock(&spinlock_); - - return 0; -} - -bool memory_heap::allocated_size(void *address, uint64_t &size) const noexcept -{ - auto u8a = reinterpret_cast(address); - if (u8a < base_ || u8a >= base_ + size_) { - SHM_LOG_ERROR("release invalid address " << address); - return false; - } - - auto offset = u8a - base_; - bool exist = false; - pthread_spin_lock(&spinlock_); - auto pos = address_used_tree_.find(offset); - if (pos != address_used_tree_.end()) { - exist = true; - size = pos->second; - } - pthread_spin_unlock(&spinlock_); - - return exist; -} - -uint64_t memory_heap::allocated_size_align_up(uint64_t input_size) noexcept -{ - constexpr uint64_t align_size = 16UL; - constexpr uint64_t align_size_mask = ~(align_size - 1UL); - return (input_size + align_size - 1UL) & align_size_mask; -} - -bool memory_heap::alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, - uint64_t &head_skip) noexcept -{ - if (mr.size < size) { - return false; - } - - if ((mr.offset & (alignment - 1UL)) == 0UL) { - head_skip = 0; - return true; - } - - auto aligned_offset = ((mr.offset + alignment - 1UL) & (~(alignment - 1UL))); - head_skip = aligned_offset - mr.offset; - return mr.size >= size + head_skip; -} - -void memory_heap::reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept -{ - auto offset = pos->first; - auto old_size = pos->second; - pos->second = new_size; - auto next_addr_pos = address_idle_tree_.find(offset + old_size); - if (next_addr_pos == address_idle_tree_.end()) { - address_idle_tree_.emplace(offset + new_size, old_size - new_size); - size_idle_tree_.emplace(memory_range{offset + new_size, old_size - new_size}); - } else { - auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); - size_idle_tree_.erase(next_size_pos); - next_addr_pos->second += (old_size - new_size); - size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); - } -} - -bool memory_heap::expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept -{ - auto offset = pos->first; - auto old_size = pos->second; - auto delta = new_size - old_size; - - auto next_addr_pos = address_idle_tree_.find(offset + old_size); - if (next_addr_pos == address_idle_tree_.end() || next_addr_pos->second < delta) { - return false; - } - - pos->second = new_size; - auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); - if (next_addr_pos->second == delta) { - size_idle_tree_.erase(next_size_pos); - address_idle_tree_.erase(next_addr_pos); - } else { - size_idle_tree_.erase(next_size_pos); - next_addr_pos->second -= delta; - size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); - } - - return true; -} \ No newline at end of file diff --git a/src/host/mem/shmemi_mm_heap.h b/src/host/mem/shmemi_mm_heap.h deleted file mode 100644 index 7107547b..00000000 --- a/src/host/mem/shmemi_mm_heap.h +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#ifndef SHMEMI_MM_HEAP_H -#define SHMEMI_MM_HEAP_H - -#include -#include -#include -#include - -struct memory_range { - const uint64_t offset; - const uint64_t size; - - memory_range(uint64_t o, uint64_t s) noexcept : offset{o}, size{s} - { - } -}; - -struct range_size_first_comparator { - bool operator()(const memory_range &mr1, const memory_range &mr2) const noexcept; -}; - -class memory_heap { -public: - memory_heap(void *base, uint64_t size) noexcept; - ~memory_heap() noexcept; - -public: - void *allocate(uint64_t size) noexcept; - void *aligned_allocate(uint64_t alignment, uint64_t size) noexcept; - bool change_size(void *address, uint64_t size) noexcept; - int32_t release(void *address) noexcept; - bool allocated_size(void *address, uint64_t &size) const noexcept; - -private: - static uint64_t allocated_size_align_up(uint64_t input_size) noexcept; - static bool alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, - uint64_t &head_skip) noexcept; - void reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; - bool expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; - -private: - uint8_t *const base_; - const uint64_t size_; - mutable pthread_spinlock_t spinlock_{}; - std::map address_idle_tree_; - std::map address_used_tree_; - std::set size_idle_tree_; -}; - -#endif // SHMEMI_MM_HEAP_H -- Gitee From beea18e5a98d4a615d56976b471c54fc713630e7 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Fri, 17 Oct 2025 14:45:33 +0800 Subject: [PATCH 39/74] Add rt Interfaces Check && Add p2p connectivity check --- include/internal/host_device/shmemi_types.h | 2 +- src/host/common/shmemi_logger.h | 9 +++ .../default/shmemi_init_default.cpp | 6 +- src/host/mem/shmemi_global_state.cpp | 25 +++--- src/host/mem/shmemi_global_state.h | 2 +- src/host/mem/shmemi_heap.cpp | 77 ++++++++++--------- src/host/transport/shmemi_transport.cpp | 14 +--- src/modules/transport/shmemi_mte.cpp | 3 +- 8 files changed, 73 insertions(+), 65 deletions(-) diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index 7d86360d..a82ac8be 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -49,7 +49,7 @@ extern "C" { // global_state constexpr uint64_t DEVMM_SVM_MEM_START = 0x100000000000ULL; constexpr uint64_t SVM_END_ADDR = 0x100000000000ULL + 0x80000000000ULL - (1UL << 30UL); // svm end -constexpr uint64_t GLOBAL_STATE_SIZE = 512UL * 1024UL * 1024UL; // global_state fixed length +constexpr uint64_t GLOBAL_STATE_SIZE = 4UL * 1024UL * 1024UL; // global_state fixed length // synchronization typedef int32_t shmemi_sync_bit[SHMEMI_SYNCBIT_SIZE / sizeof(int32_t)]; diff --git a/src/host/common/shmemi_logger.h b/src/host/common/shmemi_logger.h index e2c61b90..a5bbf3fc 100644 --- a/src/host/common/shmemi_logger.h +++ b/src/host/common/shmemi_logger.h @@ -181,4 +181,13 @@ private: } \ } while (0) +#define SHMEM_CHECK(x) \ + do { \ + int32_t check_ret = x; \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" return shmem error: " << check_ret); \ + return ; \ + } \ + } while (0) + #endif // SHMEM_SHM_OUT_LOGGER_H diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index 6888aff7..1a4ba83b 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -1,4 +1,5 @@ #include "shmemi_init_default.h" +#include "common/shmemi_logger.h" shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr) { @@ -30,9 +31,8 @@ int shmemi_init_default::finalize_device_state() int shmemi_init_default::update_device_state(void* host_ptr, size_t size) { - int32_t status = SHMEM_SUCCESS; - status = aclrtMemcpy(global_state_d->get_ptr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE); - return status; + SHMEM_CHECK_RET(aclrtMemcpy(global_state_d->get_ptr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE)); + return SHMEM_SUCCESS; } int shmemi_init_default::reserve_heap(shmemi_device_host_state_t &g_state) diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp index c091f38b..7828f5f4 100644 --- a/src/host/mem/shmemi_global_state.cpp +++ b/src/host/mem/shmemi_global_state.cpp @@ -1,6 +1,9 @@ #include #include #include "shmemi_global_state.h" +#include "host/shmem_host_def.h" +#include "common/shmemi_host_types.h" +#include "common/shmemi_logger.h" #define HAL_LOAD_SYM(TARGET_FUNC, FILE_HANDLE, SYMBOL_NAME) \ dlerror(); \ @@ -52,13 +55,9 @@ int32_t load_hal_library() global_state_reigister::global_state_reigister(int device_id): device_id_{device_id} { - int32_t status = load_hal_library(); - if (status != 0) { - std::cout << "load_hal_library failed " << std::endl; - return; - } - - halMemAddressReserveFunc(&device_ptr, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1); + SHMEM_CHECK(load_hal_library()); + + SHMEM_CHECK(halMemAddressReserveFunc(&device_ptr_, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1)); drv_mem_prop memprop; memprop.side = 1; @@ -68,18 +67,18 @@ global_state_reigister::global_state_reigister(int device_id): device_id_{device memprop.mem_type = 0; memprop.reserve = 0; - halMemCreateFunc(&alloc_handle, GLOBAL_STATE_SIZE, &memprop, 0); + SHMEM_CHECK(halMemCreateFunc(&alloc_handle, GLOBAL_STATE_SIZE, &memprop, 0)); - halMemMapFunc(device_ptr, GLOBAL_STATE_SIZE, 0, alloc_handle, 0); + SHMEM_CHECK(halMemMapFunc(device_ptr_, GLOBAL_STATE_SIZE, 0, alloc_handle, 0)); } global_state_reigister::~global_state_reigister() { - halMemUnmapFunc(device_ptr); + SHMEM_CHECK(halMemUnmapFunc(device_ptr_)); - halMemReleaseFunc(alloc_handle); + SHMEM_CHECK(halMemReleaseFunc(alloc_handle)); - halMemAddressFreeFunc(device_ptr); + SHMEM_CHECK(halMemAddressFreeFunc(device_ptr_)); if (hal_handle != nullptr) dlclose(hal_handle); @@ -87,5 +86,5 @@ global_state_reigister::~global_state_reigister() void *global_state_reigister::get_ptr() { - return device_ptr; + return device_ptr_; } diff --git a/src/host/mem/shmemi_global_state.h b/src/host/mem/shmemi_global_state.h index 2119fbf9..646cc03f 100644 --- a/src/host/mem/shmemi_global_state.h +++ b/src/host/mem/shmemi_global_state.h @@ -29,7 +29,7 @@ public: void *get_ptr(); private: - void *device_ptr = nullptr; + void *device_ptr_ = nullptr; drv_mem_handle_t *alloc_handle; diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index 7aa9eb10..721dbc87 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -1,4 +1,7 @@ #include "shmemi_heap.h" +#include "host/shmem_host_def.h" +#include "common/shmemi_host_types.h" +#include "common/shmemi_logger.h" shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size): mype(pe_id), npes(pe_size) { @@ -16,7 +19,6 @@ shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size): mype(pe_id), int shmem_symmetric_heap::reserve_heap(size_t size) { - int status = 0; peer_heap_base_p2p_ = (void **)std::calloc(npes, sizeof(void *)); // reserve virtual ptrs @@ -25,37 +27,32 @@ int shmem_symmetric_heap::reserve_heap(size_t size) } // reserve local heap_base_ - aclrtReserveMemAddress(&(peer_heap_base_p2p_[mype]), size, 0, nullptr, 1); + SHMEM_CHECK_RET(aclrtReserveMemAddress(&(peer_heap_base_p2p_[mype]), size, 0, nullptr, 1)); heap_base_ = peer_heap_base_p2p_[mype]; // alloc local physical memory - aclrtMallocPhysical(&local_handle, size, &memprop, 0); + SHMEM_CHECK_RET(aclrtMallocPhysical(&local_handle, size, &memprop, 0)); alloc_size = size; - return status; + return SHMEM_SUCCESS; } int shmem_symmetric_heap::export_memory() { - int status = 0; // Get share_handle - status = aclrtMemExportToShareableHandle(local_handle, memprop.handleType, 0, &share_handle); - return status; + SHMEM_CHECK_RET(aclrtMemExportToShareableHandle(local_handle, memprop.handleType, 0, &share_handle)); + return SHMEM_SUCCESS; } int shmem_symmetric_heap::export_pid() { - int status = 0; - // Get local pid - status = aclrtDeviceGetBareTgid(&my_pid); - return status; + SHMEM_CHECK_RET(aclrtDeviceGetBareTgid(&my_pid)); + return SHMEM_SUCCESS; } int shmem_symmetric_heap::import_pid() { - int status = 0; - // Get all pids g_boot_handle.allgather(&my_pid, pid_list.data(), 1 * sizeof(int), &g_boot_handle); @@ -65,32 +62,37 @@ int shmem_symmetric_heap::import_pid() if (i == mype) { continue; } + // Check if p2p connected + if (peer_heap_base_p2p_[i] == NULL) { + continue; + } share_pid.push_back(pid_list[i]); } - status = aclrtMemSetPidToShareableHandle(share_handle, share_pid.data(), npes - 1); - return status; + SHMEM_CHECK_RET(aclrtMemSetPidToShareableHandle(share_handle, share_pid.data(), npes - 1)); + return SHMEM_SUCCESS; } int shmem_symmetric_heap::import_memory() { - int status = 0; g_boot_handle.allgather(&share_handle, share_handle_list.data(), 1 * sizeof(uint64_t), &g_boot_handle); for (int i = 0; i < npes; i++) { if (i == mype) { physical_handle_list[i] = local_handle; continue; } - status = aclrtMemImportFromShareableHandle(share_handle_list[i], mype, &(physical_handle_list[i])); + // Check if p2p connected + if (peer_heap_base_p2p_[i] == NULL) { + continue; + } + SHMEM_CHECK_RET(aclrtMemImportFromShareableHandle(share_handle_list[i], mype, &(physical_handle_list[i]))); } - return status; + return SHMEM_SUCCESS; } int shmem_symmetric_heap::setup_heap(int *transport_map) { - int status = 0; - // MTE p2p_heap_base_ reserve int local_offset = mype * npes; for (int i = 0; i < npes; i++) { @@ -98,41 +100,44 @@ int shmem_symmetric_heap::setup_heap(int *transport_map) continue; if (transport_map[local_offset + i] == 1) { - aclrtReserveMemAddress(&(peer_heap_base_p2p_[i]), alloc_size, 0, nullptr, 1); + SHMEM_CHECK_RET(aclrtReserveMemAddress(&(peer_heap_base_p2p_[i]), alloc_size, 0, nullptr, 1)); } } - status = export_memory(); - status = export_pid(); - status = import_pid(); - status = import_memory(); + SHMEM_CHECK_RET(export_memory()); + SHMEM_CHECK_RET(export_pid()); + SHMEM_CHECK_RET(import_pid()); + SHMEM_CHECK_RET(import_memory()); // Shareable Handle Map for (int i = 0; i < npes; i++) { - status = aclrtMapMem(peer_heap_base_p2p_[i], alloc_size, 0, physical_handle_list[i], 0); + // Check if p2p connected + if (peer_heap_base_p2p_[i] != NULL) { + SHMEM_CHECK_RET(aclrtMapMem(peer_heap_base_p2p_[i], alloc_size, 0, physical_handle_list[i], 0)); + } } - return status; + return SHMEM_SUCCESS; } int shmem_symmetric_heap::remove_heap() { - int status = 0; for (int i = 0; i < npes; i++) { - status = aclrtUnmapMem(peer_heap_base_p2p_[i]); + if (peer_heap_base_p2p_[i] != NULL) { + SHMEM_CHECK_RET(aclrtUnmapMem(peer_heap_base_p2p_[i])); + } } - - return status; + return SHMEM_SUCCESS; } int shmem_symmetric_heap::unreserve_heap() { - int status = 0; for (int i = 0; i < npes; i++) { - status = aclrtReleaseMemAddress(peer_heap_base_p2p_[i]); + if (peer_heap_base_p2p_[i] != NULL) { + SHMEM_CHECK_RET(aclrtReleaseMemAddress(peer_heap_base_p2p_[i])); + } } - - status = aclrtFreePhysical(local_handle); - return status; + SHMEM_CHECK_RET(aclrtFreePhysical(local_handle)); + return SHMEM_SUCCESS; } void *shmem_symmetric_heap::get_heap_base() diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index b5908ae9..4cee15f1 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -3,8 +3,6 @@ #include "transport/shmemi_transport.h" -// extern shmemi_device_host_state_t g_state; - #define TRANSPORT_MODULE_MTE "shmem_transport_mte.so" static void *mte_plugin_hdl = nullptr; @@ -42,8 +40,6 @@ void shmemi_transport_unload() int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { - int status = 0; - uint32_t num_choosen_transport = 0; mte_plugin_name = TRANSPORT_MODULE_MTE; @@ -63,14 +59,12 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { host_hash_list = (uint64_t *)calloc(g_state.npes, sizeof(uint64_t)); g_boot_handle.allgather(&g_state.host_hash, host_hash_list, 1 * sizeof(uint64_t), &g_boot_handle); - status = shmemi_mte_init(host_hash_list, g_state.mype, g_state.npes); + SHMEM_CHECK_RET(shmemi_mte_init(host_hash_list, g_state.mype, g_state.npes)); - return status; + return SHMEM_SUCCESS; } int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_t &g_state) { - int status = SHMEM_SUCCESS; - int *local_map = NULL; local_map = (int *)calloc(g_state.npes, sizeof(int)); @@ -83,7 +77,7 @@ int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_ for (int j = 0; j < num_choosen_transport; j++) { int reach = 0; // Judge mte peer access - status = shmemi_mte_can_access_peer(&reach, i); + SHMEM_CHECK_RET(shmemi_mte_can_access_peer(&reach, i)); if (reach) { int m = 1 << j; @@ -95,7 +89,7 @@ int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_ g_boot_handle.allgather(local_map, transport_map, g_state.npes * sizeof(int), &g_boot_handle); if (local_map) free(local_map); - return status; + return SHMEM_SUCCESS; } int32_t shmemi_transport_setup_connections(int *transport_map, shmemi_device_host_state_t &g_state) { diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index 97257bfe..945be534 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -14,6 +14,7 @@ #include #include "host/shmem_host_def.h" +#include "common/shmemi_logger.h" #include "internal/host_device/shmemi_types.h" #include "transport/shmemi_transport.h" @@ -49,7 +50,7 @@ int shmemi_mte_connect_peers(int *selected_dev_ids, int num_selected_devs) { // EnablePeerAccess for (int i = 0; i < num_selected_devs; i++) { - aclrtDeviceEnablePeerAccess(selected_dev_ids[i], 0); + SHMEM_CHECK_RET(aclrtDeviceEnablePeerAccess(selected_dev_ids[i], 0)); } return 0; -- Gitee From f7393c280395382af6cf58ca4a4d2b90818ad790 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Mon, 20 Oct 2025 10:58:56 +0800 Subject: [PATCH 40/74] Support ASCEND_RT_VISIBLE_DEVICES && Improve Interfaces Check --- .../default/shmemi_init_default.cpp | 19 ++--- src/host/mem/shmemi_global_state.cpp | 69 +++++++++++++++---- src/host/mem/shmemi_global_state.h | 4 ++ 3 files changed, 69 insertions(+), 23 deletions(-) diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index 1a4ba83b..17eed1a2 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -20,6 +20,9 @@ shmemi_init_default::~shmemi_init_default() int shmemi_init_default::init_device_state() { global_state_d = new global_state_reigister(mype); + if (global_state_d->get_init_status() != 0) { + SHM_LOG_ERROR("global_state reigister error"); + } return SHMEM_SUCCESS; } @@ -47,7 +50,7 @@ int shmemi_init_default::reserve_heap(shmemi_device_host_state_t &g_state) int shmemi_init_default::setup_heap(shmemi_device_host_state_t &g_state) { - heap_obj->setup_heap(transport_map); + SHMEM_CHECK_RET(heap_obj->setup_heap(transport_map)); for (int32_t i = 0; i < g_state.npes; i++) { g_state.p2p_heap_base[i] = heap_obj->get_peer_heap_base_p2p(i); @@ -59,28 +62,26 @@ int shmemi_init_default::setup_heap(shmemi_device_host_state_t &g_state) int shmemi_init_default::remove_heap() { - heap_obj->remove_heap(); - + SHMEM_CHECK_RET(heap_obj->remove_heap()); return SHMEM_SUCCESS; } int shmemi_init_default::release_heap() { - heap_obj->unreserve_heap(); - + SHMEM_CHECK_RET(heap_obj->unreserve_heap()); return SHMEM_SUCCESS; } int shmemi_init_default::transport_init(shmemi_device_host_state_t &g_state) { - shmemi_transport_init(g_state); // mte init && rdma init - shmemi_build_transport_map(transport_map, g_state); // returns transport_map - shmemi_transport_setup_connections(transport_map, g_state); // connect_endpoints by transpost_map + SHMEM_CHECK_RET(shmemi_transport_init(g_state)); // mte init && rdma init + SHMEM_CHECK_RET(shmemi_build_transport_map(transport_map, g_state)); // returns transport_map + SHMEM_CHECK_RET(shmemi_transport_setup_connections(transport_map, g_state)); // connect_endpoints by transpost_map return SHMEM_SUCCESS; } int shmemi_init_default::transport_finalize() { - shmemi_transport_finalize(); + SHMEM_CHECK_RET(shmemi_transport_finalize()); return SHMEM_SUCCESS; } \ No newline at end of file diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp index 7828f5f4..2a60e298 100644 --- a/src/host/mem/shmemi_global_state.cpp +++ b/src/host/mem/shmemi_global_state.cpp @@ -5,7 +5,7 @@ #include "common/shmemi_host_types.h" #include "common/shmemi_logger.h" -#define HAL_LOAD_SYM(TARGET_FUNC, FILE_HANDLE, SYMBOL_NAME) \ +#define LOAD_SYM(TARGET_FUNC, FILE_HANDLE, SYMBOL_NAME) \ dlerror(); \ *((void **)&TARGET_FUNC) = dlsym(FILE_HANDLE, SYMBOL_NAME); \ error = dlerror(); \ @@ -14,6 +14,12 @@ dlclose(hal_handle); \ } +std::mutex g_mutex; + +bool g_hal_loaded = false; +static void *hal_handle; +const char *g_hal_lib_name = "libascend_hal.so"; + int (*halMemAddressReserveFunc)(void **ptr, size_t size, size_t alignment, void *addr, uint64_t flag); int (*halMemAddressFreeFunc)(void *ptr); int (*halMemCreateFunc)(drv_mem_handle_t **handle, size_t size, const struct drv_mem_prop *prop, uint64_t flag); @@ -21,16 +27,17 @@ int (*halMemReleaseFunc)(drv_mem_handle_t *handle); int (*halMemMapFunc)(void *ptr, size_t size, size_t offset, drv_mem_handle_t *handle, uint64_t flag); int (*halMemUnmapFunc)(void *ptr); -std::mutex g_mutex; -bool g_loaded = false; -static void *hal_handle; -const char *g_hal_lib_name = "libascend_hal.so"; +bool g_rt_loaded = false; +static void *rt_handle; +const char *g_rt_lib_name = "libascendcl.so"; + +int (*rtGetLogicDevIdByUserDevIdFunc)(const int32_t, int32_t *const); int32_t load_hal_library() { char *error; std::lock_guard guard(g_mutex); - if (g_loaded) { + if (g_hal_loaded) { return 0; } @@ -42,26 +49,52 @@ int32_t load_hal_library() return 1; } - HAL_LOAD_SYM(halMemAddressReserveFunc, hal_handle, "halMemAddressReserve"); - HAL_LOAD_SYM(halMemAddressFreeFunc, hal_handle, "halMemAddressFree"); - HAL_LOAD_SYM(halMemCreateFunc, hal_handle, "halMemCreate"); - HAL_LOAD_SYM(halMemReleaseFunc, hal_handle, "halMemRelease"); - HAL_LOAD_SYM(halMemMapFunc, hal_handle, "halMemMap"); - HAL_LOAD_SYM(halMemUnmapFunc, hal_handle, "halMemUnmap"); + LOAD_SYM(halMemAddressReserveFunc, hal_handle, "halMemAddressReserve"); + LOAD_SYM(halMemAddressFreeFunc, hal_handle, "halMemAddressFree"); + LOAD_SYM(halMemCreateFunc, hal_handle, "halMemCreate"); + LOAD_SYM(halMemReleaseFunc, hal_handle, "halMemRelease"); + LOAD_SYM(halMemMapFunc, hal_handle, "halMemMap"); + LOAD_SYM(halMemUnmapFunc, hal_handle, "halMemUnmap"); + + g_hal_loaded = true; + return 0; +} + +int32_t load_rt_library() +{ + char *error; + std::lock_guard guard(g_mutex); + if (g_rt_loaded) { + return 0; + } + + dlerror(); + + rt_handle = dlopen(g_rt_lib_name, RTLD_NOW); + if (!rt_handle) { + fprintf(stderr, "dlopen failed: %s\n", dlerror()); + return 1; + } + + LOAD_SYM(rtGetLogicDevIdByUserDevIdFunc, rt_handle, "rtGetLogicDevIdByUserDevId"); - g_loaded = true; + g_rt_loaded = true; return 0; } global_state_reigister::global_state_reigister(int device_id): device_id_{device_id} { SHMEM_CHECK(load_hal_library()); + SHMEM_CHECK(load_rt_library()); SHMEM_CHECK(halMemAddressReserveFunc(&device_ptr_, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1)); + int32_t logicDeviceId = -1; + SHMEM_CHECK(rtGetLogicDevIdByUserDevIdFunc(device_id_, &logicDeviceId)); + drv_mem_prop memprop; memprop.side = 1; - memprop.devid = device_id_; + memprop.devid = logicDeviceId; memprop.module_id = 0; memprop.pg_type = 0; memprop.mem_type = 0; @@ -70,6 +103,9 @@ global_state_reigister::global_state_reigister(int device_id): device_id_{device SHMEM_CHECK(halMemCreateFunc(&alloc_handle, GLOBAL_STATE_SIZE, &memprop, 0)); SHMEM_CHECK(halMemMapFunc(device_ptr_, GLOBAL_STATE_SIZE, 0, alloc_handle, 0)); + + // init success + init_status_ = 0; } global_state_reigister::~global_state_reigister() @@ -88,3 +124,8 @@ void *global_state_reigister::get_ptr() { return device_ptr_; } + +int global_state_reigister::get_init_status() +{ + return init_status_; +} diff --git a/src/host/mem/shmemi_global_state.h b/src/host/mem/shmemi_global_state.h index 646cc03f..9c4d517c 100644 --- a/src/host/mem/shmemi_global_state.h +++ b/src/host/mem/shmemi_global_state.h @@ -28,12 +28,16 @@ public: ~global_state_reigister(); void *get_ptr(); + int get_init_status(); private: void *device_ptr_ = nullptr; drv_mem_handle_t *alloc_handle; int device_id_; + + // 1 means no-init + int init_status_ = 1; }; -- Gitee From 7650653a754340e4fc38e85380bc74da3b73d660 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Thu, 23 Oct 2025 21:09:31 +0800 Subject: [PATCH 41/74] Fix MTE Transport Init Flow --- src/host/common/shmemi_host_types.h | 15 ++-- .../default/shmemi_init_default.cpp | 10 ++- .../default/shmemi_init_default.h | 3 - src/host/mem/shmemi_heap.cpp | 4 +- src/host/mem/shmemi_heap.h | 2 +- src/host/transport/shmemi_transport.cpp | 72 +++++++++++-------- src/host/transport/shmemi_transport.h | 6 +- src/modules/transport/shmemi_mte.cpp | 35 ++++----- 8 files changed, 78 insertions(+), 69 deletions(-) diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index fa784722..78a4b96b 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -47,14 +47,12 @@ typedef struct shmemi_bootstrap_uid_options { } shmemi_bootstrap_uid_options_t; typedef struct shmemi_transport_pe_info { - int32_t mype; - uint32_t host_id; - uint32_t dev_id; + int32_t pe; + int32_t dev_id; + uint64_t host_hash; } shmemi_transport_pe_info_t; typedef struct shmemi_transport { - shmemi_bootstrap_handle_t *boot_handle; - // control plane int (*can_access_peer)(int *access, shmemi_transport_pe_info_t *peer, struct shmemi_transport *t); int (*connect_peers)(struct shmemi_transport *t, int *selected_dev_ids, int num_selected_devs); @@ -85,12 +83,17 @@ typedef struct { int8_t default_event_id; uint32_t default_block_num; + // topo + int32_t *transport_map; /* npes * npes, 2D-Array, point-to-connectivity. */ + shmemi_transport_pe_info *pe_info; /* All pe's host info, need to build transports. */ + shmemi_options_t options; shmemi_bootstrap_handle_t *boot_handle; shmemi_transport_t choosen_transports[SHMEM_MAX_TRANSPORT_NUM]; int32_t num_choosen_transport; } shmemi_host_state_t; + extern shmemi_bootstrap_handle_t g_boot_handle; -// extern shmemi_host_state_t g_host_state; +extern shmemi_host_state_t g_host_state; #endif // SHMEMI_HOST_TYPES_H \ No newline at end of file diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index 17eed1a2..d119846f 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -5,8 +5,6 @@ shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr) { mype = attr->my_rank; npes = attr->n_ranks; - - transport_map = (int *)calloc(npes * npes, sizeof(int)); } shmemi_init_default::~shmemi_init_default() @@ -50,7 +48,7 @@ int shmemi_init_default::reserve_heap(shmemi_device_host_state_t &g_state) int shmemi_init_default::setup_heap(shmemi_device_host_state_t &g_state) { - SHMEM_CHECK_RET(heap_obj->setup_heap(transport_map)); + SHMEM_CHECK_RET(heap_obj->setup_heap()); for (int32_t i = 0; i < g_state.npes; i++) { g_state.p2p_heap_base[i] = heap_obj->get_peer_heap_base_p2p(i); @@ -74,9 +72,9 @@ int shmemi_init_default::release_heap() int shmemi_init_default::transport_init(shmemi_device_host_state_t &g_state) { - SHMEM_CHECK_RET(shmemi_transport_init(g_state)); // mte init && rdma init - SHMEM_CHECK_RET(shmemi_build_transport_map(transport_map, g_state)); // returns transport_map - SHMEM_CHECK_RET(shmemi_transport_setup_connections(transport_map, g_state)); // connect_endpoints by transpost_map + SHMEM_CHECK_RET(shmemi_transport_init(g_state)); // mte init && rdma init + SHMEM_CHECK_RET(shmemi_build_transport_map(g_state)); // build transport_map + SHMEM_CHECK_RET(shmemi_transport_setup_connections(g_state)); // connect_endpoints by transpost_map return SHMEM_SUCCESS; } diff --git a/src/host/init/init_backends/default/shmemi_init_default.h b/src/host/init/init_backends/default/shmemi_init_default.h index 826e15a1..f9f67c81 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.h +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -41,9 +41,6 @@ private: // heap_obj shmem_symmetric_heap *heap_obj = nullptr; - - // transport_map - int *transport_map = NULL; }; #endif // SHMEMI_INIT_NORMAL_H \ No newline at end of file diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index 721dbc87..7b81ced1 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -91,7 +91,7 @@ int shmem_symmetric_heap::import_memory() return SHMEM_SUCCESS; } -int shmem_symmetric_heap::setup_heap(int *transport_map) +int shmem_symmetric_heap::setup_heap() { // MTE p2p_heap_base_ reserve int local_offset = mype * npes; @@ -99,7 +99,7 @@ int shmem_symmetric_heap::setup_heap(int *transport_map) if (i == mype) continue; - if (transport_map[local_offset + i] == 1) { + if (g_host_state.transport_map[local_offset + i] == 1) { SHMEM_CHECK_RET(aclrtReserveMemAddress(&(peer_heap_base_p2p_[i]), alloc_size, 0, nullptr, 1)); } } diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h index e5ebe537..40d953c8 100644 --- a/src/host/mem/shmemi_heap.h +++ b/src/host/mem/shmemi_heap.h @@ -19,7 +19,7 @@ public: int reserve_heap(size_t size); // aclrtReserveMemAddress && aclrtMallocPhysical int unreserve_heap(); // halMemAddressFree && aclrtFreePhysical - int setup_heap(int *transport_map); // export && import p2p memories && aclrtMapMem + int setup_heap(); // export && import p2p memories && aclrtMapMem int remove_heap(); // aclrtUnmapMem void *get_heap_base(); // return heap_base_ diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index 4cee15f1..9e807bf9 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -8,10 +8,7 @@ static void *mte_plugin_hdl = nullptr; static char *mte_plugin_name = nullptr; -int (*shmemi_mte_init)(uint64_t *hash_list, int pe_id, int pe_size); -int (*shmemi_mte_can_access_peer)(int *access, int pe_id); -int (*shmemi_mte_connect_peers)(int *selected_dev_ids, int num_selected_devs); -int (*shmemi_mte_finalize)(); +shmemi_host_state_t g_host_state; uint64_t *host_hash_list; @@ -38,9 +35,10 @@ void shmemi_transport_unload() } } - int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { - uint32_t num_choosen_transport = 0; + g_host_state.num_choosen_transport = 1; // now only support mte; + g_host_state.transport_map = (int *)calloc(g_state.npes * g_state.npes, sizeof(int)); + g_host_state.pe_info = (shmemi_transport_pe_info *)calloc(g_state.npes, sizeof(shmemi_transport_pe_info)); mte_plugin_name = TRANSPORT_MODULE_MTE; shmemi_transport_load(); @@ -51,33 +49,40 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { return SHMEM_INVALID_VALUE; } - *((void **)&shmemi_mte_init) = dlsym(mte_plugin_hdl, "shmemi_mte_init"); - *((void **)&shmemi_mte_can_access_peer) = dlsym(mte_plugin_hdl, "shmemi_mte_can_access_peer"); - *((void **)&shmemi_mte_connect_peers) = dlsym(mte_plugin_hdl, "shmemi_mte_connect_peers"); - *((void **)&shmemi_mte_finalize) = dlsym(mte_plugin_hdl, "shmemi_mte_finalize"); + transport_init_func init_fn; + init_fn = (transport_init_func)dlsym(mte_plugin_hdl, "shmemi_mte_init"); + + // Package my_info + shmemi_transport_pe_info_t my_info; + my_info.pe = g_state.mype; + my_info.host_hash = g_state.host_hash; + + int32_t device_id; + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); + my_info.dev_id = device_id; - host_hash_list = (uint64_t *)calloc(g_state.npes, sizeof(uint64_t)); - g_boot_handle.allgather(&g_state.host_hash, host_hash_list, 1 * sizeof(uint64_t), &g_boot_handle); + // AllGather All pe's host info + g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); - SHMEM_CHECK_RET(shmemi_mte_init(host_hash_list, g_state.mype, g_state.npes)); + SHMEM_CHECK_RET(init_fn(&g_host_state.choosen_transports[0], my_info)); return SHMEM_SUCCESS; } -int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_t &g_state) { +int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { int *local_map = NULL; local_map = (int *)calloc(g_state.npes, sizeof(int)); - // Every selected transport must be access to all pe. - // If any can_reach_peer returns false, build_map should return failed. - for (int i = 0; i < g_state.npes; i++) { - int num_choosen_transport = 1; // now only mte. - - // Loop can_access_peer, j = 0 means MTE, j = 1 means RDMA ... - for (int j = 0; j < num_choosen_transport; j++) { + shmemi_transport_t t; + + // Loop can_access_peer, j = 0 means MTE, j = 1 means RDMA ... + for (int j = 0; j < g_host_state.num_choosen_transport; j++) { + t = g_host_state.choosen_transports[j]; + + for (int i = 0; i < g_state.npes; i++) { int reach = 0; - // Judge mte peer access - SHMEM_CHECK_RET(shmemi_mte_can_access_peer(&reach, i)); + + SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, &t)); if (reach) { int m = 1 << j; @@ -86,14 +91,18 @@ int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_ } } - g_boot_handle.allgather(local_map, transport_map, g_state.npes * sizeof(int), &g_boot_handle); + g_boot_handle.allgather(local_map, g_host_state.transport_map, g_state.npes * sizeof(int), &g_boot_handle); if (local_map) free(local_map); return SHMEM_SUCCESS; } -int32_t shmemi_transport_setup_connections(int *transport_map, shmemi_device_host_state_t &g_state) { - +int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) { + shmemi_transport_t t; + + // MTE + t = g_host_state.choosen_transports[0]; + int *mte_peer_list; int mte_peer_num = 0; mte_peer_list = (int *)calloc(g_state.npes, sizeof(int)); @@ -102,18 +111,23 @@ int32_t shmemi_transport_setup_connections(int *transport_map, shmemi_device_hos for (int i = 0; i < g_state.npes; i++) { if (i == g_state.mype) continue; - if (transport_map[i] == 1) { + if (g_host_state.transport_map[local_offset + i] == 1) { mte_peer_list[mte_peer_num] = i; ++mte_peer_num; } } - - shmemi_mte_connect_peers(mte_peer_list, mte_peer_num); + + t.connect_peers(&t, mte_peer_list, mte_peer_num); return 0; } int32_t shmemi_transport_finalize() { + shmemi_transport_t t; + // MTE + t = g_host_state.choosen_transports[0]; + t.finalize(&t); + dlclose(mte_plugin_hdl); return 0; } diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h index 0fb09e70..40978094 100644 --- a/src/host/transport/shmemi_transport.h +++ b/src/host/transport/shmemi_transport.h @@ -1,11 +1,13 @@ #ifndef SHMEMI_TRANSPORT_H #define SHMEMI_TRANSPORT_H +typedef int(*transport_init_func)(shmemi_transport_t *transport, shmemi_transport_pe_info_t my_info); + int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state); -int32_t shmemi_build_transport_map(int *transport_map, shmemi_device_host_state_t &g_state); +int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state); -int32_t shmemi_transport_setup_connections(int *transport_map, shmemi_device_host_state_t &g_state); +int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state); int32_t shmemi_transport_finalize(); diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index 945be534..4ff13808 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -13,8 +13,7 @@ #include #include -#include "host/shmem_host_def.h" -#include "common/shmemi_logger.h" +#include "shmemi_host_common.h" #include "internal/host_device/shmemi_types.h" #include "transport/shmemi_transport.h" @@ -22,23 +21,11 @@ extern "C" { #endif -static uint64_t *host_hash_list; -static int mype; -static int npes; +static uint64_t my_host_hash; -// control plane -int shmemi_mte_init(uint64_t *hash_list, int pe_id, int pe_size) { - - host_hash_list = hash_list; - mype = pe_id; - npes = pe_size; - - return 0; -} - -int shmemi_mte_can_access_peer(int *access, int pe_id) { +int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport *t) { // host_id same return 1, otherwise 0 - if (host_hash_list[mype] == host_hash_list[pe_id]) { + if (my_host_hash == peer_info->host_hash) { *access = 1; } else { *access = 0; @@ -46,8 +33,7 @@ int shmemi_mte_can_access_peer(int *access, int pe_id) { return 0; } -int shmemi_mte_connect_peers(int *selected_dev_ids, int num_selected_devs) { - +int shmemi_mte_connect_peers(shmemi_transport *t, int *selected_dev_ids, int num_selected_devs) { // EnablePeerAccess for (int i = 0; i < num_selected_devs; i++) { SHMEM_CHECK_RET(aclrtDeviceEnablePeerAccess(selected_dev_ids[i], 0)); @@ -56,8 +42,17 @@ int shmemi_mte_connect_peers(int *selected_dev_ids, int num_selected_devs) { return 0; } -int shmemi_mte_finalize() { +int shmemi_mte_finalize(shmemi_transport *t) { + return 0; +} + +// control plane +int shmemi_mte_init(shmemi_transport_t *t, shmemi_transport_pe_info_t my_info) { + t->can_access_peer = shmemi_mte_can_access_peer; + t->connect_peers = shmemi_mte_connect_peers; + t->finalize = shmemi_mte_finalize; + my_host_hash = my_info.host_hash; return 0; } -- Gitee From bb101ceff1e3a019224b8a5c9c299aef13246382 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Fri, 24 Oct 2025 14:55:22 +0800 Subject: [PATCH 42/74] Transport Refactor 2.0, Add g_state into all transport funcs --- src/host/common/shmemi_host_types.h | 11 +++- .../default/shmemi_init_default.cpp | 3 +- src/host/mem/shmemi_global_state.cpp | 3 + src/host/transport/shmemi_transport.cpp | 59 ++++++------------- src/host/transport/shmemi_transport.h | 2 +- src/modules/transport/shmemi_mte.cpp | 14 ++--- 6 files changed, 38 insertions(+), 54 deletions(-) diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index 78a4b96b..5039a326 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -12,6 +12,8 @@ #define SHMEM_MAX_TRANSPORT_NUM 16 +#include "internal/host_device/shmemi_types.h" + typedef struct shmemi_bootstrap_attr { shmemi_bootstrap_attr() : initialize_mf(0), mpi_comm(NULL), uid_args(NULL) {} @@ -54,9 +56,12 @@ typedef struct shmemi_transport_pe_info { typedef struct shmemi_transport { // control plane - int (*can_access_peer)(int *access, shmemi_transport_pe_info_t *peer, struct shmemi_transport *t); - int (*connect_peers)(struct shmemi_transport *t, int *selected_dev_ids, int num_selected_devs); - int (*finalize)(struct shmemi_transport *t); + int (*can_access_peer)(int *access, shmemi_transport_pe_info_t *peer, + struct shmemi_transport *t, shmemi_device_host_state_t *g_state); + int (*connect_peers)(struct shmemi_transport *t, int *selected_dev_ids, + int num_selected_devs, shmemi_device_host_state_t *g_state); + int (*finalize)(struct shmemi_transport *t, + shmemi_device_host_state_t *g_state); // data plane, TBD void (*rma)(struct shmemi_transport *t, int32_t type, void *dst, void *src, size_t size, int32_t pe); diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index d119846f..adcee7cb 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -40,7 +40,8 @@ int shmemi_init_default::reserve_heap(shmemi_device_host_state_t &g_state) { heap_obj = new shmem_symmetric_heap(mype, npes); - heap_obj->reserve_heap(g_state.heap_size); + SHMEM_CHECK_RET(heap_obj->reserve_heap(g_state.heap_size)); + g_state.heap_base = heap_obj->get_heap_base(); return SHMEM_SUCCESS; diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp index 2a60e298..6d9486de 100644 --- a/src/host/mem/shmemi_global_state.cpp +++ b/src/host/mem/shmemi_global_state.cpp @@ -118,6 +118,9 @@ global_state_reigister::~global_state_reigister() if (hal_handle != nullptr) dlclose(hal_handle); + + if (rt_handle != nullptr) + dlclose(rt_handle); } void *global_state_reigister::get_ptr() diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index 9e807bf9..897d6441 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -5,52 +5,29 @@ #define TRANSPORT_MODULE_MTE "shmem_transport_mte.so" -static void *mte_plugin_hdl = nullptr; -static char *mte_plugin_name = nullptr; +static void *transport_mte_lib = NULL; shmemi_host_state_t g_host_state; -uint64_t *host_hash_list; - -void shmemi_transport_load() -{ - dlerror(); - if (mte_plugin_hdl == nullptr) { - - mte_plugin_hdl = dlopen(mte_plugin_name, RTLD_NOW); - } - dlerror(); -} - -void shmemi_transport_unload() -{ - if (mte_plugin_hdl != nullptr) { - dlclose(mte_plugin_hdl); - mte_plugin_hdl = nullptr; - } - - if (mte_plugin_name != nullptr) { - free(mte_plugin_name); - mte_plugin_name = nullptr; - } -} - int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { g_host_state.num_choosen_transport = 1; // now only support mte; g_host_state.transport_map = (int *)calloc(g_state.npes * g_state.npes, sizeof(int)); g_host_state.pe_info = (shmemi_transport_pe_info *)calloc(g_state.npes, sizeof(shmemi_transport_pe_info)); - mte_plugin_name = TRANSPORT_MODULE_MTE; - shmemi_transport_load(); - - if (!mte_plugin_hdl) { - SHM_LOG_ERROR("Transport unable to load " << mte_plugin_name << ", err is: " << stderr); - shmemi_transport_unload(); + transport_mte_lib = dlopen("shmem_transport_mte.so", RTLD_NOW); + if (!transport_mte_lib) { + SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_mte.so" << ", err is: " << stderr); return SHMEM_INVALID_VALUE; } transport_init_func init_fn; - init_fn = (transport_init_func)dlsym(mte_plugin_hdl, "shmemi_mte_init"); + init_fn = (transport_init_func)dlsym(transport_mte_lib, "shmemi_mte_init"); + if (!init_fn) { + dlclose(transport_mte_lib); + transport_mte_lib = NULL; + SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_mte.so" << "."); + return SHMEM_INVALID_VALUE; + } // Package my_info shmemi_transport_pe_info_t my_info; @@ -64,7 +41,7 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { // AllGather All pe's host info g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); - SHMEM_CHECK_RET(init_fn(&g_host_state.choosen_transports[0], my_info)); + SHMEM_CHECK_RET(init_fn(&g_host_state.choosen_transports[0], &g_state)); return SHMEM_SUCCESS; } @@ -82,7 +59,7 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { for (int i = 0; i < g_state.npes; i++) { int reach = 0; - SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, &t)); + SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, &t, &g_state)); if (reach) { int m = 1 << j; @@ -99,7 +76,6 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) { shmemi_transport_t t; - // MTE t = g_host_state.choosen_transports[0]; @@ -117,7 +93,7 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) } } - t.connect_peers(&t, mte_peer_list, mte_peer_num); + t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); return 0; } @@ -126,8 +102,11 @@ int32_t shmemi_transport_finalize() { shmemi_transport_t t; // MTE t = g_host_state.choosen_transports[0]; - t.finalize(&t); + t.finalize(&t, &g_state); - dlclose(mte_plugin_hdl); + if (transport_mte_lib != NULL) { + dlclose(transport_mte_lib); + transport_mte_lib = NULL; + } return 0; } diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h index 40978094..e0b331d9 100644 --- a/src/host/transport/shmemi_transport.h +++ b/src/host/transport/shmemi_transport.h @@ -1,7 +1,7 @@ #ifndef SHMEMI_TRANSPORT_H #define SHMEMI_TRANSPORT_H -typedef int(*transport_init_func)(shmemi_transport_t *transport, shmemi_transport_pe_info_t my_info); +typedef int(*transport_init_func)(shmemi_transport_t *transport, shmemi_device_host_state_t *g_state); int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state); diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index 4ff13808..539f7250 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -21,11 +21,9 @@ extern "C" { #endif -static uint64_t my_host_hash; - -int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport *t) { +int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport *t, shmemi_device_host_state_t *g_state) { // host_id same return 1, otherwise 0 - if (my_host_hash == peer_info->host_hash) { + if (g_state->host_hash == peer_info->host_hash) { *access = 1; } else { *access = 0; @@ -33,26 +31,24 @@ int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_inf return 0; } -int shmemi_mte_connect_peers(shmemi_transport *t, int *selected_dev_ids, int num_selected_devs) { +int shmemi_mte_connect_peers(shmemi_transport *t, int *selected_dev_ids, int num_selected_devs, shmemi_device_host_state_t *g_state) { // EnablePeerAccess for (int i = 0; i < num_selected_devs; i++) { SHMEM_CHECK_RET(aclrtDeviceEnablePeerAccess(selected_dev_ids[i], 0)); } - return 0; } -int shmemi_mte_finalize(shmemi_transport *t) { +int shmemi_mte_finalize(shmemi_transport *t, shmemi_device_host_state_t *g_state) { return 0; } // control plane -int shmemi_mte_init(shmemi_transport_t *t, shmemi_transport_pe_info_t my_info) { +int shmemi_mte_init(shmemi_transport_t *t, shmemi_device_host_state_t *g_state) { t->can_access_peer = shmemi_mte_can_access_peer; t->connect_peers = shmemi_mte_connect_peers; t->finalize = shmemi_mte_finalize; - my_host_hash = my_info.host_hash; return 0; } -- Gitee From c7df4873302d1a8dd4ec77cf6f1c7b2df35406b9 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Fri, 24 Oct 2025 15:25:47 +0800 Subject: [PATCH 43/74] Move shmemi_base_copy_api.h into shmem_device_low_level_rma.h --- .../low_level/shmem_device_low_level_rma.h | 155 ++++++++++++++++++ .../internal/device/shmemi_base_copy_api.h | 112 ------------- .../internal/device/shmemi_device_common.h | 1 - 3 files changed, 155 insertions(+), 113 deletions(-) delete mode 100644 include/internal/device/shmemi_base_copy_api.h diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index a3f37671..319ad424 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -39,6 +39,161 @@ SHMEM_DEVICE __gm__ void *shmem_ptr(__gm__ void *ptr, int pe) return reinterpret_cast<__gm__ void *>(remote_ptr); } +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] Pointer on Symmetric memory of the destination data. + * @param srcUb [in] Pointer on local UB of the source data. + * @param elem_size [in] Byte Size of data in the destination and source arrays. + * @param toL2Cache [in] Enable L2Cache or not. False means disable L2Cache. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, + uint32_t size, bool toL2Cache = true) +{ + ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); + + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); + if (!toL2Cache) { + gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + AscendC::DataCopyPad(gmTensor, ubTensor, dataCopyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] GlobalTensor on Symmetric memory of the destination data. + * @param srcUb [in] LocalTensor on local UB of the source data. + * @param elem_size [in] Byte Size of data in the destination and source arrays. + * @param toL2Cache [in] Enable L2Cache or not. False means disable L2Cache. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, + const AscendC::LocalTensor &srcUb, uint32_t size) +{ + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + AscendC::DataCopyPad(dstGva, srcUb, dataCopyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] Pointer on Symmetric memory of the destination data. + * @param srcUb [in] Pointer on local UB of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, + AscendC::DataCopyExtParams ©Params) +{ + ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); + + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); + AscendC::DataCopyPad(gmTensor, ubTensor, copyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] GlobalTensor on Symmetric memory of the destination data. + * @param srcUb [in] LocalTensor on local UB of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, + const AscendC::LocalTensor &srcUb, AscendC::DataCopyExtParams ©Params) +{ + AscendC::DataCopyPad(dstGva, srcUb, copyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on symmetric memory from the specified PE to local UB. + * + * @param dstUb [in] Pointer on local UB of the destination data. + * @param srcGva [in] Pointer on Symmetric memory of the source data. + * @param size [in] Byte Size of data in the destination and source arrays. + * @param toL2Cache [in] Enable L2Cache or not. False means disable L2Cache. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, + uint32_t size, bool toL2Cache = true) +{ + ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); + if (!toL2Cache) { + gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on symmetric memory from the specified PE to local UB. + * + * @param dstUb [in] LocalTensor on local UB of the destination data. + * @param srcGva [in] GlobalTensor on Symmetric memory of the source data. + * @param size [in] Byte Size of data in the destination and source arrays. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, + const AscendC::GlobalTensor &srcGva, uint32_t size) +{ + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(dstUb, srcGva, dataCopyParams, padParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstUb [in] Pointer on local UB of the destination data. + * @param srcGva [in] Pointer on Symmetric memory of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, + AscendC::DataCopyExtParams ©Params) +{ + ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(ubTensor, gmTensor, copyParams, padParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstUb [in] LocalTensor on local UB of the destination data. + * @param srcGva [in] GlobalTensor on Symmetric memory of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, + const AscendC::GlobalTensor &srcGva, AscendC::DataCopyExtParams ©Params) +{ + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(dstUb, srcGva, copyParams, padParams); +} + /** * @brief Asynchronous interface. Copy contiguous data on symmetric memory from the specified PE to address on the local device. * diff --git a/include/internal/device/shmemi_base_copy_api.h b/include/internal/device/shmemi_base_copy_api.h deleted file mode 100644 index 5303a64c..00000000 --- a/include/internal/device/shmemi_base_copy_api.h +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#ifndef __SHMEMI_BASE_COPY_H__ -#define __SHMEMI_BASE_COPY_H__ - -#include "kernel_operator.h" -#include "host_device/shmem_types.h" - -template -SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, - uint32_t size, bool toL2Cache = true) -{ - ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); - - AscendC::LocalTensor ubTensor; - AscendC::GlobalTensor gmTensor; - AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); - ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); - ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); - gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); - if (!toL2Cache) { - gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); - } - AscendC::DataCopyPad(gmTensor, ubTensor, dataCopyParams); -} - -template -SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, - const AscendC::LocalTensor &srcUb, uint32_t size) -{ - AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); - AscendC::DataCopyPad(dstGva, srcUb, dataCopyParams); -} - -template -SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, - AscendC::DataCopyExtParams ©Params) -{ - ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); - - AscendC::LocalTensor ubTensor; - AscendC::GlobalTensor gmTensor; - ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); - ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); - gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); - AscendC::DataCopyPad(gmTensor, ubTensor, copyParams); -} - -template -SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, - const AscendC::LocalTensor &srcUb, AscendC::DataCopyExtParams ©Params) -{ - AscendC::DataCopyPad(dstGva, srcUb, copyParams); -} - -template -SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, - uint32_t size, bool toL2Cache = true) -{ - ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); - AscendC::LocalTensor ubTensor; - AscendC::GlobalTensor gmTensor; - AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); - ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); - ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); - gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); - if (!toL2Cache) { - gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); - } - AscendC::DataCopyPadExtParams padParams; - AscendC::DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); -} - -template -SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, - const AscendC::GlobalTensor &srcGva, uint32_t size) -{ - AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); - AscendC::DataCopyPadExtParams padParams; - AscendC::DataCopyPad(dstUb, srcGva, dataCopyParams, padParams); -} - -template -SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, - AscendC::DataCopyExtParams ©Params) -{ - ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); - AscendC::LocalTensor ubTensor; - AscendC::GlobalTensor gmTensor; - ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); - ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); - gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); - AscendC::DataCopyPadExtParams padParams; - AscendC::DataCopyPad(ubTensor, gmTensor, copyParams, padParams); -} - -template -SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, - const AscendC::GlobalTensor &srcGva, AscendC::DataCopyExtParams ©Params) -{ - AscendC::DataCopyPadExtParams padParams; - AscendC::DataCopyPad(dstUb, srcGva, copyParams, padParams); -} - -#endif // __SHMEMI_BASE_COPY_H__ \ No newline at end of file diff --git a/include/internal/device/shmemi_device_common.h b/include/internal/device/shmemi_device_common.h index b16fb28d..ff64d491 100644 --- a/include/internal/device/shmemi_device_common.h +++ b/include/internal/device/shmemi_device_common.h @@ -12,7 +12,6 @@ #include "shmemi_device_arch.h" #include "shmemi_device_def.h" -#include "shmemi_base_copy_api.h" constexpr int ub_limit = 192 * 1024; -- Gitee From fcfa95d59853a8e39e68eeeb0c0bd0361217c317 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Fri, 24 Oct 2025 15:56:17 +0800 Subject: [PATCH 44/74] Fix some mistake device_id to mype cases --- .../init/init_backends/default/shmemi_init_default.cpp | 8 ++++++-- src/host/init/init_backends/default/shmemi_init_default.h | 1 + src/host/mem/shmemi_heap.cpp | 6 +++--- src/host/mem/shmemi_heap.h | 3 ++- src/host/transport/shmemi_transport.cpp | 3 ++- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index adcee7cb..420f053d 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -5,6 +5,10 @@ shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr) { mype = attr->my_rank; npes = attr->n_ranks; + auto status = aclrtGetDevice(&device_id); + if (status != 0) { + SHM_LOG_ERROR("Get Device_id error"); + } } shmemi_init_default::~shmemi_init_default() @@ -17,7 +21,7 @@ shmemi_init_default::~shmemi_init_default() int shmemi_init_default::init_device_state() { - global_state_d = new global_state_reigister(mype); + global_state_d = new global_state_reigister(device_id); if (global_state_d->get_init_status() != 0) { SHM_LOG_ERROR("global_state reigister error"); } @@ -38,7 +42,7 @@ int shmemi_init_default::update_device_state(void* host_ptr, size_t size) int shmemi_init_default::reserve_heap(shmemi_device_host_state_t &g_state) { - heap_obj = new shmem_symmetric_heap(mype, npes); + heap_obj = new shmem_symmetric_heap(mype, npes, device_id); SHMEM_CHECK_RET(heap_obj->reserve_heap(g_state.heap_size)); diff --git a/src/host/init/init_backends/default/shmemi_init_default.h b/src/host/init/init_backends/default/shmemi_init_default.h index f9f67c81..ced2d86c 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.h +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -35,6 +35,7 @@ public: private: int mype; int npes; + int device_id; // global_state global_state_reigister *global_state_d = nullptr; diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index 7b81ced1..985d9b1f 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -3,7 +3,7 @@ #include "common/shmemi_host_types.h" #include "common/shmemi_logger.h" -shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size): mype(pe_id), npes(pe_size) +shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size, int dev_id): mype(pe_id), npes(pe_size), device_id(dev_id) { physical_handle_list.resize(pe_size); share_handle_list.resize(pe_size); @@ -13,7 +13,7 @@ shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size): mype(pe_id), memprop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; memprop.memAttr = ACL_HBM_MEM_HUGE; memprop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; - memprop.location.id = pe_id; + memprop.location.id = dev_id; memprop.reserve = 0; } @@ -85,7 +85,7 @@ int shmem_symmetric_heap::import_memory() if (peer_heap_base_p2p_[i] == NULL) { continue; } - SHMEM_CHECK_RET(aclrtMemImportFromShareableHandle(share_handle_list[i], mype, &(physical_handle_list[i]))); + SHMEM_CHECK_RET(aclrtMemImportFromShareableHandle(share_handle_list[i], device_id, &(physical_handle_list[i]))); } return SHMEM_SUCCESS; diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h index 40d953c8..eabd297f 100644 --- a/src/host/mem/shmemi_heap.h +++ b/src/host/mem/shmemi_heap.h @@ -13,7 +13,7 @@ class shmem_symmetric_heap { public: shmem_symmetric_heap() {} - shmem_symmetric_heap(int pe_id, int pe_size); + shmem_symmetric_heap(int pe_id, int pe_size, int dev_id); ~shmem_symmetric_heap() {}; int reserve_heap(size_t size); // aclrtReserveMemAddress && aclrtMallocPhysical @@ -34,6 +34,7 @@ private: int32_t mype; int32_t npes; + int32_t device_id; uint64_t alloc_size; diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index 897d6441..e96b89ed 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -88,7 +88,8 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) if (i == g_state.mype) continue; if (g_host_state.transport_map[local_offset + i] == 1) { - mte_peer_list[mte_peer_num] = i; + shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); + mte_peer_list[mte_peer_num] = peer_info->dev_id; ++mte_peer_num; } } -- Gitee From c14eea0fd30ce1b078128822e0051d048d21d7d6 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Sat, 25 Oct 2025 15:39:07 +0800 Subject: [PATCH 45/74] Add Copyright --- include/device/low_level/shmem_device_low_level_rma.h | 3 +-- .../init_backends/default/shmemi_init_default.cpp | 9 +++++++++ .../init/init_backends/default/shmemi_init_default.h | 9 +++++++++ src/host/init/init_backends/mf/shmemi_init_mf.cpp | 9 +++++++++ src/host/init/init_backends/mf/shmemi_init_mf.h | 9 +++++++++ src/host/init/init_backends/shmemi_init_base.h | 9 +++++++++ src/host/mem/shmemi_global_state.cpp | 9 +++++++++ src/host/mem/shmemi_global_state.h | 9 +++++++++ src/host/mem/shmemi_heap.cpp | 9 +++++++++ src/host/mem/shmemi_heap.h | 9 +++++++++ src/host/transport/shmemi_transport.cpp | 11 +++++++++-- src/host/transport/shmemi_transport.h | 9 +++++++++ 12 files changed, 100 insertions(+), 4 deletions(-) diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index 319ad424..879b8a8c 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -70,8 +70,7 @@ SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, * * @param dstGva [in] GlobalTensor on Symmetric memory of the destination data. * @param srcUb [in] LocalTensor on local UB of the source data. - * @param elem_size [in] Byte Size of data in the destination and source arrays. - * @param toL2Cache [in] Enable L2Cache or not. False means disable L2Cache. + * @param size [in] Byte Size of data in the destination and source arrays. */ template SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index 420f053d..908f0713 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include "shmemi_init_default.h" #include "common/shmemi_logger.h" diff --git a/src/host/init/init_backends/default/shmemi_init_default.h b/src/host/init/init_backends/default/shmemi_init_default.h index ced2d86c..ed6295b5 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.h +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -1,4 +1,13 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifndef SHMEMI_INIT_NORMAL_H #define SHMEMI_INIT_NORMAL_H diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index 00909dbe..fd5a7749 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include "shmemi_init_mf.h" #ifdef BACKEND_MF diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index bd4e1acb..b2c97dec 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -1,4 +1,13 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifndef SHMEMI_INIT_MF_H #define SHMEMI_INIT_MF_H diff --git a/src/host/init/init_backends/shmemi_init_base.h b/src/host/init/init_backends/shmemi_init_base.h index 3f1aa394..4b8cdbb7 100644 --- a/src/host/init/init_backends/shmemi_init_base.h +++ b/src/host/init/init_backends/shmemi_init_base.h @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifndef SHMEMI_INIT_BASE_H #define SHMEMI_INIT_BASE_H diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp index 6d9486de..cc350e75 100644 --- a/src/host/mem/shmemi_global_state.cpp +++ b/src/host/mem/shmemi_global_state.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include #include #include "shmemi_global_state.h" diff --git a/src/host/mem/shmemi_global_state.h b/src/host/mem/shmemi_global_state.h index 9c4d517c..1fb56cfc 100644 --- a/src/host/mem/shmemi_global_state.h +++ b/src/host/mem/shmemi_global_state.h @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifndef SHMEMI_GLOBAL_STATE_H #define SHMEMI_GLOBAL_STATE_H diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index 985d9b1f..de8abf87 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include "shmemi_heap.h" #include "host/shmem_host_def.h" #include "common/shmemi_host_types.h" diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h index eabd297f..68385d0a 100644 --- a/src/host/mem/shmemi_heap.h +++ b/src/host/mem/shmemi_heap.h @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifndef SHMEMI_HEAP_H #define SHMEMI_HEAP_H diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index e96b89ed..05401036 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -1,10 +1,17 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #include "shmemi_host_common.h" #include "dlfcn.h" #include "transport/shmemi_transport.h" -#define TRANSPORT_MODULE_MTE "shmem_transport_mte.so" - static void *transport_mte_lib = NULL; shmemi_host_state_t g_host_state; diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h index e0b331d9..7ddb4da5 100644 --- a/src/host/transport/shmemi_transport.h +++ b/src/host/transport/shmemi_transport.h @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ #ifndef SHMEMI_TRANSPORT_H #define SHMEMI_TRANSPORT_H -- Gitee From 632615e1811f225d7cbd8d51f69d0ead276900a4 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Thu, 23 Oct 2025 16:42:35 +0800 Subject: [PATCH 46/74] Support RDMA with shared library. --- examples/CMakeLists.txt | 1 + examples/rdma_test/CMakeLists.txt | 9 + examples/rdma_test/main.cpp | 212 ++++++ examples/rdma_test/rdma_test_kernel.cpp | 69 ++ .../low_level/shmem_device_low_level_roce.h | 25 +- include/internal/host_device/shmemi_types.h | 1 + src/CMakeLists.txt | 23 + src/host/init/shmem_init.cpp | 1 + src/host/mem/shmemi_heap.cpp | 4 +- src/host/transport/shmemi_transport.cpp | 40 ++ .../transport/rdma/device_qp_manager.cpp | 680 ++++++++++++++++++ .../transport/rdma/device_qp_manager.h | 93 +++ src/modules/transport/rdma/dl_hccp_api.cpp | 167 +++++ src/modules/transport/rdma/dl_hccp_api.h | 248 +++++++ src/modules/transport/rdma/dl_hccp_def.h | 647 +++++++++++++++++ src/modules/transport/rdma/rdma_manager.h | 419 +++++++++++ src/modules/transport/shmemi_rdma.cpp | 75 +- 17 files changed, 2674 insertions(+), 40 deletions(-) create mode 100644 examples/rdma_test/CMakeLists.txt create mode 100644 examples/rdma_test/main.cpp create mode 100644 examples/rdma_test/rdma_test_kernel.cpp create mode 100644 src/modules/transport/rdma/device_qp_manager.cpp create mode 100644 src/modules/transport/rdma/device_qp_manager.h create mode 100644 src/modules/transport/rdma/dl_hccp_api.cpp create mode 100644 src/modules/transport/rdma/dl_hccp_api.h create mode 100644 src/modules/transport/rdma/dl_hccp_def.h create mode 100644 src/modules/transport/rdma/rdma_manager.h diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f09b1614..83389522 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -54,6 +54,7 @@ endfunction() foreach(EXAMPLE allgather + rdma_test # matmul_allreduce # rdma_perftest # rdma_demo diff --git a/examples/rdma_test/CMakeLists.txt b/examples/rdma_test/CMakeLists.txt new file mode 100644 index 00000000..5593c77a --- /dev/null +++ b/examples/rdma_test/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +shmem_add_collective_example(rdma_test main.cpp) \ No newline at end of file diff --git a/examples/rdma_test/main.cpp b/examples/rdma_test/main.cpp new file mode 100644 index 00000000..d67b8ba0 --- /dev/null +++ b/examples/rdma_test/main.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "shmem_api.h" +#include "shmemi_host_common.h" + +int g_npus = 8; +const char *ipport; +int f_rank = 0; +int f_npu = 0; + +extern void qpinfo_demo(uint32_t block_dim, void* stream, uint8_t* gva, uint32_t destRankId, uint32_t qpIdx); +extern void shm_rdma_write_test_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva); +extern void shm_rdma_write_test_poll_cq_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva); + +int test_shmem_rdma(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + // 初始化ACL和SHMEM + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + + void* dev_ptr; + aclrtMalloc(&dev_ptr, 120 * 8, ACL_MEM_MALLOC_HUGE_FIRST); + uint64_t *xHost; + size_t totalSize = 120; + size_t elementCount = totalSize / sizeof(uint64_t); + aclrtMallocHost((void **)(&xHost), totalSize); + std::fill(xHost, xHost + elementCount, 0); + + for (uint32_t curRank = 0; curRank < n_ranks; curRank++) { + if (curRank == rank_id) { + continue; + } + qpinfo_demo(1, stream, (uint8_t*)dev_ptr + rank_id * totalSize, curRank, 0); + aclrtSynchronizeStream(stream); + sleep(1); + + aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr + rank_id * totalSize, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); + for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { + printf("GetQPInfo srcRank = %d, destRank = %d, index = %d, value = %lu\n", rank_id, curRank, i, xHost[i]); + } + } + + aclrtFreeHost(xHost); + aclrtFree(dev_ptr); + + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int test_shmem_rdma_put(int rank_id, int n_ranks, uint64_t local_mem_size, uint64_t remote_gva) +{ + // 初始化ACL和SHMEM + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + + void* dev_ptr; + aclrtMalloc(&dev_ptr, 64 * 16, ACL_MEM_MALLOC_HUGE_FIRST); + uint32_t *xHost; + size_t messageSize = 64; + size_t totalSize = messageSize * n_ranks; + aclrtMallocHost((void **)(&xHost), totalSize); + for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { + xHost[i] = rank_id; + } + + if (rank_id == 0) { + aclrtMemcpy((uint8_t*)dev_ptr + rank_id * messageSize + rank_id * messageSize, + messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE); + shm_rdma_write_test_do(stream, (uint8_t*)dev_ptr, messageSize, (uint8_t*)remote_gva); + if (aclrtSynchronizeStream(stream) != 0) { + std::cout << "[ERROR] aclrtSynchronizeStream failed." << std::endl; + } + } + sleep(1); + + aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr + rank_id * messageSize, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); + if (rank_id == 1) { + for (uint32_t i = 0; i < n_ranks; i++) { + if (xHost[i * messageSize / sizeof(uint32_t)] != i) { + std::cout << "[ERROR] Put result check error at " << i << std::endl; + } + } + } + + aclrtFreeHost(xHost); + aclrtFree(dev_ptr); + + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int test_shmem_rdma_put_poll_cq(int rank_id, int n_ranks, uint64_t local_mem_size, uint64_t remote_gva) +{ + // 初始化ACL和SHMEM + int32_t device_id = rank_id % g_npus + f_npu; + int status = 0; + aclrtStream stream = nullptr; + + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + + void* dev_ptr = shmem_malloc(1024); + std::cout << "gva address = " << dev_ptr << std::endl; + uint32_t *xHost; + size_t messageSize = 64; + size_t totalSize = messageSize * n_ranks; + aclrtMallocHost((void **)(&xHost), totalSize); + for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { + xHost[i] = rank_id + 10; + } + + if (rank_id == 0) { + aclrtMemcpy((uint8_t*)dev_ptr + 128, + messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE); + shm_rdma_write_test_poll_cq_do(stream, (uint8_t*)dev_ptr, messageSize, (uint8_t*)dev_ptr); + if (aclrtSynchronizeStream(stream) != 0) { + std::cout << "[ERROR] aclrtSynchronizeStream failed." << std::endl; + } + } + sleep(1); + + if (rank_id == 1) { + aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); + for (uint32_t i = 0; i < n_ranks; i++) { + if (xHost[i * messageSize / sizeof(uint32_t)] != i + 10) { + std::cout << "[ERROR] Put result check error at " << i << std::endl; + } + } + } else { + aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); + for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { + printf("GetQPInfo srcRank = %d, index = %d, value = %lu\n", rank_id, i, ((uint64_t*)xHost)[i]); + } + } + + aclrtFreeHost(xHost); + aclrtFree(dev_ptr); + + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + return 0; +} + +int main(int argc, char *argv[]) +{ + int status = 0; + // 初始化MPI环境 + MPI_Init(&argc, &argv); + + // 获取当前进程的编号(rank) + int n_ranks; + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + int rank_id; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + ipport = argv[1]; + g_npus = atoi(argv[2]); + f_rank = atoi(argv[3]); + f_npu = atoi(argv[4]); + uint64_t remote_gva = atol(argv[5]); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + // status = test_shmem_rdma_put(rank_id, n_ranks, local_mem_size, remote_gva); + status = test_shmem_rdma_put_poll_cq(rank_id, n_ranks, local_mem_size, remote_gva); + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); + + return 0; +} \ No newline at end of file diff --git a/examples/rdma_test/rdma_test_kernel.cpp b/examples/rdma_test/rdma_test_kernel.cpp new file mode 100644 index 00000000..9247d05b --- /dev/null +++ b/examples/rdma_test/rdma_test_kernel.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef _RDMA_TEST_KERNEL_ +#define _RDMA_TEST_KERNEL_ + +#include "kernel_operator.h" +#include "shmem_api.h" + +constexpr uint32_t MESSAGE_SIZE = 64; + +extern "C" __global__ __aicore__ void shm_rdma_write_qpinfo_test(GM_ADDR gva, uint32_t destRankId, uint32_t qpIdx) +{ + shmemi_roce_qpinfo_test(gva, destRankId, qpIdx); +} + +void qpinfo_demo(uint32_t block_dim, void* stream, uint8_t* gva, uint32_t destRankId, uint32_t qpIdx) +{ + shm_rdma_write_qpinfo_test<<>>(gva, destRankId, qpIdx); +} + +extern "C" __global__ __aicore__ void shm_rdma_write_test(GM_ADDR gva, uint64_t heap_size, GM_ADDR remote_gva) +{ + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + auto myRank = 0; + auto totalRank = 2; + for (int i = 0; i < totalRank; i++) { + if (i == myRank) { + continue; + } + shmemi_roce_write(gva + myRank * heap_size + myRank * MESSAGE_SIZE, + remote_gva + i * heap_size + myRank * MESSAGE_SIZE, i, 0, MESSAGE_SIZE, ubLocal64, ubLocal32); + } +} + +void shm_rdma_write_test_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva) +{ + shm_rdma_write_test<<<1, nullptr, stream>>>(gva, heap_size, remote_gva); +} + +extern "C" __global__ __aicore__ void shm_rdma_write_test_poll_cq(GM_ADDR gva, uint64_t heap_size, GM_ADDR remote_gva) +{ + AscendC::TPipe pipe; + AscendC::TBuf buf; + pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); + AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + auto myRank = 0; + auto totalRank = 2; + shmemi_roce_pollcq_test(gva + 128, remote_gva, 1, 0, MESSAGE_SIZE, ubLocal64, ubLocal32, gva); +} + +void shm_rdma_write_test_poll_cq_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva) +{ + shm_rdma_write_test_poll_cq<<<1, nullptr, stream>>>(gva, heap_size, remote_gva); +} + +#endif // _RDMA_DEMO_KERNEL_ \ No newline at end of file diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h index d8997dff..60c6fab4 100644 --- a/include/device/low_level/shmem_device_low_level_roce.h +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -116,9 +116,8 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); auto cqBaseAddr = cqCtxEntry->bufAddr; @@ -212,9 +211,8 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); auto SHMEMmemInfoTable = RDMAInfo->memPtr; @@ -345,9 +343,8 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); auto curHardwareHeadAddr = qpCtxEntry->headAddr; @@ -358,9 +355,8 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, SHMEM_DEVICE void shmemi_roce_qpinfo_test(__gm__ uint8_t* gva, uint32_t destRankId, uint32_t qpIdx) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); *(__gm__ uint64_t*)(gva) = (uint64_t)RDMAInfo; uint32_t qpNum = RDMAInfo->qpNum; *(__gm__ uint64_t*)(gva + 8) = (uint64_t)qpNum; @@ -409,9 +405,8 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm shmemi_rdma_post_send(destDmaAddr, srcDmaAddr, destRankId, qpIdx, SHMEMAIVOPCODE::OP_RDMA_WRITE, messageLen, ubLocal64, ubLocal32); uint32_t idx = 1; - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); *(__gm__ uint64_t*)(gva) = (uint64_t)cqCtxEntry; diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index a82ac8be..cc60482b 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -96,6 +96,7 @@ typedef struct { bool is_shmem_created; shmemi_mte_config_t mte_config; + uint64_t qp_info; } shmemi_device_host_state_t; #ifdef __cplusplus diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3b992d65..136c47d3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -86,6 +86,29 @@ if(SHMEM_MPI_SUPPORT) endif() +set(SHMEM_RDMA_SUPPORT ON) +if(SHMEM_RDMA_SUPPORT) + add_library( + shmem_transport_rdma SHARED + ) + target_sources(shmem_transport_rdma PRIVATE + modules/transport/shmemi_rdma.cpp + modules/transport/rdma/device_qp_manager.cpp + modules/transport/rdma/dl_hccp_api.cpp + ) + target_link_libraries(shmem_transport_rdma PRIVATE MPI::MPI_CXX) + target_include_directories(shmem_transport_rdma + PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/host + ${PROJECT_SOURCE_DIR}/src/modules + ) + set_target_properties(shmem_transport_rdma PROPERTIES PREFIX "") + install(TARGETS shmem_transport_rdma + LIBRARY DESTINATION lib + ) +endif() + # 安装配置 install(TARGETS shmem LIBRARY DESTINATION lib diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index bd0f0805..d2686fee 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -51,6 +51,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; false, /* shmem_is_shmem_initialized */ \ false, /* shmem_is_shmem_created */ \ {0, 16 * 1024, 0}, /* shmem_mte_config */ \ + 0, /* qp_info */ \ } shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index de8abf87..d8f955b4 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -43,6 +43,8 @@ int shmem_symmetric_heap::reserve_heap(size_t size) SHMEM_CHECK_RET(aclrtMallocPhysical(&local_handle, size, &memprop, 0)); alloc_size = size; + SHMEM_CHECK_RET(aclrtMapMem(peer_heap_base_p2p_[mype], alloc_size, 0, local_handle, 0)); + return SHMEM_SUCCESS; } @@ -121,7 +123,7 @@ int shmem_symmetric_heap::setup_heap() // Shareable Handle Map for (int i = 0; i < npes; i++) { // Check if p2p connected - if (peer_heap_base_p2p_[i] != NULL) { + if (i != mype && peer_heap_base_p2p_[i] != NULL) { SHMEM_CHECK_RET(aclrtMapMem(peer_heap_base_p2p_[i], alloc_size, 0, physical_handle_list[i], 0)); } } diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index 05401036..ef3af22a 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -13,6 +13,30 @@ #include "transport/shmemi_transport.h" static void *transport_mte_lib = NULL; +static void *rdma_plugin_hdl = nullptr; +static char *rdma_plugin_name = nullptr; + +int (*shmemi_rdma_init)(shmemi_device_host_state_t *state, shmemi_transport_t *t); + +uint64_t *host_hash_list; + +void shmemi_transport_load() +{ + dlerror(); + if (rdma_plugin_hdl == nullptr) { + + rdma_plugin_hdl = dlopen(rdma_plugin_name, RTLD_NOW); + } + dlerror(); +} + +void shmemi_transport_unload() +{ + if (rdma_plugin_hdl != nullptr) { + dlclose(rdma_plugin_hdl); + rdma_plugin_hdl = nullptr; + } +} shmemi_host_state_t g_host_state; @@ -26,6 +50,8 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_mte.so" << ", err is: " << stderr); return SHMEM_INVALID_VALUE; } + rdma_plugin_name = TRANSPORT_MODULE_RDMA; + shmemi_transport_load(); transport_init_func init_fn; init_fn = (transport_init_func)dlsym(transport_mte_lib, "shmemi_mte_init"); @@ -50,6 +76,20 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { SHMEM_CHECK_RET(init_fn(&g_host_state.choosen_transports[0], &g_state)); + if (!rdma_plugin_hdl) { + SHM_LOG_ERROR("Bootstrap unable to load " << rdma_plugin_name << ", err is: " << stderr); + shmemi_transport_unload(); + return SHMEM_INVALID_VALUE; + } + + *((void **)&shmemi_rdma_init) = dlsym(rdma_plugin_hdl, "shmemi_rdma_init"); + if (!shmemi_rdma_init) { + SHM_LOG_ERROR("Bootstrap plugin init func dlsym failed"); + shmemi_transport_unload(); + return SHMEM_INNER_ERROR; + } + SHMEM_CHECK_RET(shmemi_rdma_init(&g_state, nullptr)); + return SHMEM_SUCCESS; } diff --git a/src/modules/transport/rdma/device_qp_manager.cpp b/src/modules/transport/rdma/device_qp_manager.cpp new file mode 100644 index 00000000..d080c714 --- /dev/null +++ b/src/modules/transport/rdma/device_qp_manager.cpp @@ -0,0 +1,680 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "dl_hccp_api.h" +#include "device_qp_manager.h" +using namespace shm; + +DeviceQpManager::DeviceQpManager(uint32_t deviceId, uint32_t rankId, uint32_t rankCount, sockaddr_in devNet, + hybm_role_type role) noexcept + : deviceId_{deviceId}, + rankId_{rankId}, + rankCount_{rankCount}, + deviceAddress_{devNet}, + rankRole_{role} +{ +} + +void *DeviceQpManager::CreateLocalSocket() noexcept +{ + void *socketHandle = nullptr; + HccpRdev rdev; + rdev.phyId = deviceId_; + rdev.family = AF_INET; + rdev.localIp.addr = deviceAddress_.sin_addr; + auto ret = DlHccpApi::RaSocketInit(HccpNetworkMode::NETWORK_OFFLINE, rdev, socketHandle); + if (ret != 0) { + SHM_LOG_ERROR("initialize socket handle failed: " << ret); + return nullptr; + } + + return socketHandle; +} + +int DeviceQpManager::CreateServerSocket() noexcept +{ + if (serverSocketHandle_ != nullptr) { + return SHMEM_SUCCESS; + } + + auto socketHandle = CreateLocalSocket(); + if (socketHandle == nullptr) { + SHM_LOG_ERROR("create local socket handle failed."); + return SHMEM_INNER_ERROR; + } + + HccpSocketListenInfo listenInfo{}; + listenInfo.handle = socketHandle; + listenInfo.port = deviceAddress_.sin_port; + bool successListen = false; + while (listenInfo.port <= std::numeric_limits::max()) { + auto ret = DlHccpApi::RaSocketListenStart(&listenInfo, 1); + if (ret == 0) { + deviceAddress_.sin_port = listenInfo.port; + successListen = true; + break; + } + listenInfo.port++; + } + if (!successListen) { + SHM_LOG_ERROR("start to listen server socket failed."); + DlHccpApi::RaSocketDeinit(socketHandle); + return SHMEM_INNER_ERROR; + } + + SHM_LOG_INFO("start to listen on port: " << listenInfo.port << " success."); + serverSocketHandle_ = socketHandle; + return SHMEM_SUCCESS; +} + +void DeviceQpManager::DestroyServerSocket() noexcept +{ + if (serverSocketHandle_ == nullptr) { + return; + } + + HccpSocketListenInfo listenInfo{}; + listenInfo.handle = serverSocketHandle_; + listenInfo.port = deviceAddress_.sin_port; + auto ret = DlHccpApi::RaSocketListenStop(&listenInfo, 1); + if (ret != 0) { + SHM_LOG_INFO("stop to listen on port: " << listenInfo.port << " return: " << ret); + } + + ret = DlHccpApi::RaSocketDeinit(serverSocketHandle_); + if (ret != 0) { + SHM_LOG_INFO("deinit server socket return: " << ret); + } + serverSocketHandle_ = nullptr; +} + +static constexpr uint32_t SEND_CQ_DEPTH = 8192; +static constexpr uint32_t RECV_CQ_DEPTH = 128; +static constexpr uint32_t MAX_SEND_WR = 8192; +static constexpr uint32_t MAX_RECV_WR = 128; +static constexpr uint32_t QP_MODE = 2; + +DeviceQpManager::~DeviceQpManager() noexcept +{ + CloseServices(); +} + +int DeviceQpManager::SetRemoteRankInfo(const std::unordered_map &ranks) noexcept +{ + if (started_) { + SHM_LOG_ERROR("fixed ranks not support update ranks info after startup"); + return SHMEM_INNER_ERROR; + } + + currentRanksInfo_ = ranks; + return SHMEM_SUCCESS; +} + +int DeviceQpManager::SetLocalMemories(const MemoryRegionMap &mrs) noexcept +{ + if (started_) { + SHM_LOG_INFO("fixed ranks not support update register MRs after startup"); + return SHMEM_SUCCESS; + } + + currentLocalMrs_ = mrs; + return SHMEM_SUCCESS; +} + +int DeviceQpManager::Startup(void *rdma) noexcept +{ + if (rdma == nullptr) { + SHM_LOG_ERROR("input rdma is null"); + return SHMEM_INVALID_PARAM; + } + + if (started_) { + SHM_LOG_ERROR("already started."); + return SHMEM_INNER_ERROR; + } + + rdmaHandle_ = rdma; + if (!ReserveQpInfoSpace()) { + SHM_LOG_ERROR("reserve qp info space failed."); + return SHMEM_INNER_ERROR; + } + + if (currentRanksInfo_.size() != rankCount_) { + SHM_LOG_ERROR("set rank count = " << currentRanksInfo_.size() << ", but rank_size = " << rankCount_); + return SHMEM_INVALID_PARAM; + } + + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + if (it->first >= rankCount_) { + SHM_LOG_ERROR("input options of nics contains rankId:" << it->first << ", rank count: " << rankCount_); + return SHMEM_INVALID_PARAM; + } + } + + auto ret = StartServerSide(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("start server side failed: " << ret); + return ret; + } + + ret = StartClientSide(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("start client side failed: " << ret); + return ret; + } + + started_ = true; + return SHMEM_SUCCESS; +} + +void DeviceQpManager::Shutdown() noexcept +{ + CloseServices(); +} + +int DeviceQpManager::WaitingConnectionReady() noexcept +{ + if (serverConnectResult == SHMEM_SUCCESS && clientConnectResult == SHMEM_SUCCESS) { + SHM_LOG_INFO("client & server connections ready."); + return SHMEM_SUCCESS; + } + + SHM_LOG_ERROR("background connection thread not started."); + return SHMEM_INNER_ERROR; +} + +void *DeviceQpManager::GetQpInfoAddress() const noexcept +{ + return qpInfo_; +} + +void *DeviceQpManager::GetQpHandleWithRankId(uint32_t rankId) const noexcept +{ + auto connections = rankId < rankId_ ? &clientConnections_ : &serverConnections_; + auto pos = connections->find(rankId); + if (pos == connections->end()) { + return nullptr; + } + + return pos->second.qpHandles[CONN_QP_STARS]; +} + +bool DeviceQpManager::ReserveQpInfoSpace() noexcept +{ + if (qpInfo_ != nullptr) { + return true; + } + + void *ptr = nullptr; + auto oneQpSize = 2U * (sizeof(AiQpRMAWQ) + sizeof(AiQpRMACQ)) + sizeof(RdmaMemRegionInfo); + qpInfoSize_ = sizeof(AiQpRMAQueueInfo) + oneQpSize * rankCount_; + auto ret = aclrtMalloc(&ptr, qpInfoSize_, ACL_MEM_MALLOC_HUGE_FIRST); + if (ret != 0) { + SHM_LOG_ERROR("allocate device size: " << qpInfoSize_ << ", failed: " << ret); + return false; + } + + qpInfo_ = (AiQpRMAQueueInfo *)ptr; + return true; +} + +int DeviceQpManager::StartServerSide() noexcept +{ + if (rankId_ + 1U == rankCount_) { + serverConnectResult = 0; + return SHMEM_SUCCESS; + } + + auto ret = CreateServerSocket(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("create server socket failed: " << ret); + return ret; + } + + ret = GenerateWhiteList(); + if (ret != 0) { + SHM_LOG_ERROR("generate white list failed: " << ret); + return SHMEM_INNER_ERROR; + } + + aclrtSetDevice(deviceId_); + ret = WaitConnectionsReady(serverConnections_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait connection ready failed: " << ret); + serverConnectResult = ret; + return SHMEM_INNER_ERROR; + } + ret = CreateQpWaitingReady(serverConnections_, CONN_QP_AI_CORE); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait connection AI qp ready failed: " << ret); + serverConnectResult = ret; + } + + ret = CreateQpWaitingReady(serverConnections_, CONN_QP_STARS); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait connection STARS qp ready failed: " << ret); + serverConnectResult = ret; + } + + serverConnectResult = SHMEM_SUCCESS; + + return SHMEM_SUCCESS; +} + +int DeviceQpManager::StartClientSide() noexcept +{ + if (rankId_ == 0U) { + SHM_LOG_INFO("rankId: " << rankId_ << " need not connect to others."); + clientConnectResult = SHMEM_SUCCESS; + return SHMEM_SUCCESS; + } + + std::vector connectInfos; + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + if (it->first >= rankId_) { + continue; // client connect to small ranks. + } + + auto socketHandle = CreateLocalSocket(); + if (socketHandle == nullptr) { + SHM_LOG_ERROR("create local socket handle failed"); + CloseClientConnections(); + return SHMEM_INNER_ERROR; + } + + clientConnections_.emplace(it->first, ConnectionChannel{it->second.network.sin_addr, socketHandle}); + HccpSocketConnectInfo connectInfo; + connectInfo.handle = socketHandle; + connectInfo.remoteIp.addr = it->second.network.sin_addr; + connectInfo.port = it->second.network.sin_port; + bzero(connectInfo.tag, sizeof(connectInfo.tag)); + SHM_LOG_DEBUG("add connecting server " << connectInfo); + connectInfos.emplace_back(connectInfo); + } + + auto ret = DlHccpApi::RaSocketBatchConnect(connectInfos.data(), connectInfos.size()); + if (ret != 0) { + SHM_LOG_ERROR("connect to all servers failed: " << ret << ", servers count = " << connectInfos.size()); + CloseClientConnections(); + return SHMEM_INNER_ERROR; + } + + aclrtSetDevice(deviceId_); + ret = WaitConnectionsReady(clientConnections_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("client wait connections failed: " << ret); + CloseClientConnections(); + return ret; + } + + ret = CreateQpWaitingReady(clientConnections_, CONN_QP_AI_CORE); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("client create qp for AI CORE failed: " << ret); + CloseClientConnections(); + return ret; + } + + ret = CreateQpWaitingReady(clientConnections_, CONN_QP_STARS); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("client create qp for STARS failed: " << ret); + CloseClientConnections(); + return ret; + } + clientConnectResult = SHMEM_SUCCESS; + return SHMEM_SUCCESS; +} + +int DeviceQpManager::GenerateWhiteList() noexcept +{ + std::vector whitelist; + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + if (it->first <= rankId_) { + continue; // small id as server, large id as client + } + HccpSocketWhiteListInfo info{}; + info.remoteIp.addr = it->second.network.sin_addr; + info.connLimit = rankCount_; + bzero(info.tag, sizeof(info.tag)); + whitelist.emplace_back(info); + serverConnections_.emplace(it->first, ConnectionChannel{info.remoteIp.addr, serverSocketHandle_}); + } + + if (whitelist.empty()) { + return SHMEM_SUCCESS; + } + + auto ret = DlHccpApi::RaSocketWhiteListAdd(serverSocketHandle_, whitelist.data(), whitelist.size()); + if (ret != 0) { + SHM_LOG_ERROR("socket handle add white list failed: " << ret); + return SHMEM_INNER_ERROR; + } + + return SHMEM_SUCCESS; +} + +int DeviceQpManager::WaitConnectionsReady(std::unordered_map &connections) noexcept +{ + uint32_t totalSuccessCount = 0; + auto start = std::chrono::steady_clock::now(); + auto timeout = start + std::chrono::minutes(2); + while (totalSuccessCount < connections.size()) { + if (std::chrono::steady_clock::now() >= timeout) { + SHM_LOG_ERROR("waiting connection ready timeout."); + return SHMEM_INNER_ERROR; + } + + uint32_t successCount = 0; + std::vector socketInfos; + std::unordered_map addr2index; + for (auto it = connections.begin(); it != connections.end(); ++it) { + if (it->second.socketFd != nullptr) { + continue; + } + + HccpSocketInfo info{}; + info.handle = it->second.socketHandle; + info.fd = nullptr; + info.remoteIp.addr = it->second.remoteIp; + info.status = 0; + bzero(info.tag, sizeof(info.tag)); + socketInfos.push_back(info); + addr2index.emplace(it->second.remoteIp.s_addr, it->first); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + auto role = (&connections == &clientConnections_) ? 1 : 0; + auto ret = DlHccpApi::RaGetSockets(role, socketInfos.data(), socketInfos.size(), successCount); + if (ret != 0) { + SHM_LOG_ERROR("role(" << role << ") side get sockets failed: " << ret); + return SHMEM_INNER_ERROR; + } + + for (auto i = 0U; i < successCount; i++) { + auto socketInfoPos = addr2index.find(socketInfos[i].remoteIp.addr.s_addr); + if (socketInfoPos == addr2index.end()) { + SHM_LOG_ERROR("socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) << ") should not exist."); + return SHMEM_INNER_ERROR; + } + + auto rankId = socketInfoPos->second; + auto pos = connections.find(rankId); + if (pos == connections.end()) { + SHM_LOG_ERROR("socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) << ") should not exist."); + return SHMEM_INNER_ERROR; + } + + if (pos->second.socketFd != nullptr) { + SHM_LOG_ERROR("get socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) << ") already get socket fd."); + return SHMEM_INNER_ERROR; + } + + if (pos->second.socketHandle != socketInfos[i].handle) { + SHM_LOG_ERROR("get socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) + << ") socket handle not match."); + return SHMEM_INNER_ERROR; + } + + pos->second.socketFd = socketInfos[i].fd; + SHM_LOG_INFO("connect to (" << rankId << ") ready."); + } + + totalSuccessCount += successCount; + } + + return SHMEM_SUCCESS; +} + +int DeviceQpManager::CreateQpWaitingReady(std::unordered_map &connections, + ConnQpType qpType) noexcept +{ + for (auto it = connections.begin(); it != connections.end(); ++it) { + auto ret = CreateOneQp(qpType, it->second); + if (ret != 0) { + SHM_LOG_ERROR("create QP type:" << qpType << " to " << it->first << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + + for (auto pos = currentLocalMrs_.begin(); pos != currentLocalMrs_.end(); ++pos) { + HccpMrInfo info{}; + info.addr = (void *)(ptrdiff_t)pos->second.address; + info.size = pos->second.size; + info.access = 7; + ret = DlHccpApi::RaMrReg(it->second.qpHandles[qpType], info); + if (ret != 0) { + SHM_LOG_ERROR("register MR failed: " << ret); + return SHMEM_INNER_ERROR; + } + } + + ret = DlHccpApi::RaQpConnectAsync(it->second.qpHandles[qpType], it->second.socketFd); + if (ret != 0) { + SHM_LOG_ERROR("connect AI QP to " << it->first << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + } + + auto start = std::chrono::steady_clock::now(); + auto timeout = start + std::chrono::minutes(1); + while (std::chrono::steady_clock::now() < timeout) { + int connectingCount = 0; + for (auto it = connections.begin(); it != connections.end(); ++it) { + int status = 0; + auto ret = DlHccpApi::RaGetQpStatus(it->second.qpHandles[qpType], status); + if (ret != 0) { + SHM_LOG_ERROR("get AI QP status to " << it->first << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + if (status != 1) { + connectingCount++; + } + } + if (connectingCount == 0) { + return FillQpInfo(qpType); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + return SHMEM_INNER_ERROR; +} + +int DeviceQpManager::CreateOneQp(ConnQpType qpType, ConnectionChannel &channel) noexcept +{ + int ret; + if (qpType == CONN_QP_AI_CORE) { + HccpQpExtAttrs attr{}; + attr.qpMode = NETWORK_OFFLINE; + attr.version = 1; + attr.cqAttr.sendCqDepth = SEND_CQ_DEPTH; + attr.cqAttr.recvDqDepth = RECV_CQ_DEPTH; + attr.qp_attr.cap.max_recv_sge = 1; + attr.qp_attr.cap.max_recv_wr = MAX_RECV_WR; + attr.qp_attr.cap.max_recv_sge = 1; + attr.qp_attr.qp_type = IBV_QPT_RC; + attr.qp_attr.cap.max_send_wr = MAX_SEND_WR; + attr.data_plane_flag.bs.cq_cstm = 1; + ret = DlHccpApi::RaQpAiCreate(rdmaHandle_, attr, channel.aiQpInfo, channel.qpHandles[qpType]); + } else { + ret = DlHccpApi::RaQpCreate(rdmaHandle_, 0, QP_MODE, channel.qpHandles[qpType]); + } + return ret; +} + +int DeviceQpManager::FillQpInfo(ConnQpType qpType) noexcept +{ + if (qpType != CONN_QP_AI_CORE) { + return SHMEM_SUCCESS; + } + + const uint32_t slevel = 4; + std::vector qpInfoBuffer(qpInfoSize_); + auto copyInfo = (AiQpRMAQueueInfo *)(void *)qpInfoBuffer.data(); + copyInfo->count = 1; + copyInfo->sq = (AiQpRMAWQ *)(void *)(copyInfo + 1); + copyInfo->rq = (AiQpRMAWQ *)(void *)(copyInfo->sq + rankCount_); + copyInfo->scq = (AiQpRMACQ *)(void *)(copyInfo->rq + rankCount_); + copyInfo->rcq = (AiQpRMACQ *)(void *)(copyInfo->scq + rankCount_); + copyInfo->mr = (RdmaMemRegionInfo *)(void *)(copyInfo->rcq + rankCount_); + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + copyInfo->mr[it->first].size = it->second.mr.size; + copyInfo->mr[it->first].addr = it->second.mr.address; + copyInfo->mr[it->first].lkey = it->second.mr.lkey; + copyInfo->mr[it->first].rkey = it->second.mr.rkey; + if (it->first == rankId_) { + continue; + } + + std::unordered_map *connections; + if (it->first < rankId_) { + connections = &clientConnections_; + } else { + connections = &serverConnections_; + } + + auto pos = connections->find(it->first); + if (pos == connections->end()) { + SHM_LOG_ERROR("missing for remote: " << it->first); + return SHMEM_INNER_ERROR; + } + + CopyAiWQInfo(copyInfo->sq[it->first], pos->second.aiQpInfo.data_plane_info.sq, DBMode::HW_DB, slevel); + CopyAiWQInfo(copyInfo->rq[it->first], pos->second.aiQpInfo.data_plane_info.rq, DBMode::SW_DB, slevel); + CopyAiCQInfo(copyInfo->scq[it->first], pos->second.aiQpInfo.data_plane_info.scq, DBMode::HW_DB); + CopyAiCQInfo(copyInfo->rcq[it->first], pos->second.aiQpInfo.data_plane_info.rcq, DBMode::SW_DB); + } + + auto pointer = (ptrdiff_t)(void *)(qpInfo_); + pointer += sizeof(AiQpRMAQueueInfo); + copyInfo->sq = (AiQpRMAWQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMAWQ) * rankCount_; + copyInfo->rq = (AiQpRMAWQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMAWQ) * rankCount_; + copyInfo->scq = (AiQpRMACQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMACQ) * rankCount_; + copyInfo->rcq = (AiQpRMACQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMACQ) * rankCount_; + copyInfo->mr = (RdmaMemRegionInfo *)(void *)pointer; + + auto ret = aclrtMemcpy(qpInfo_, qpInfoSize_, copyInfo, qpInfoSize_, ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != 0) { + SHM_LOG_ERROR("copy qp info to device failed: " << ret); + return SHMEM_INNER_ERROR; + } + SHM_LOG_INFO("copy qp info success"); + + return SHMEM_SUCCESS; +} + +void DeviceQpManager::CopyAiWQInfo(struct AiQpRMAWQ &dest, const struct ai_data_plane_wq &src, DBMode dbMode, + uint32_t sl) noexcept +{ + dest.wqn = src.wqn; + dest.bufAddr = src.buf_addr; + dest.wqeSize = src.wqebb_size; + dest.depth = src.depth; + dest.headAddr = src.head_addr; + dest.tailAddr = src.tail_addr; + dest.dbMode = dbMode; + if (dbMode == DBMode::SW_DB) { + dest.dbAddr = src.swdb_addr; + } else if (dbMode == DBMode::HW_DB) { + dest.dbAddr = src.db_reg; + } + dest.sl = sl; + SHM_LOG_INFO("CopyAiWQInfo: wqn = " << dest.wqn << ", bufAddr = " << dest.bufAddr << ", wqeSize = " + << dest.wqeSize << ", depth = " << dest.depth << ", headAddr = " << dest.headAddr + << ", tailAddr = " << dest.tailAddr << ", dbAddr = " << dest.dbAddr + << ", sl = " << dest.sl); +} + +void DeviceQpManager::CopyAiCQInfo(struct AiQpRMACQ &dest, const ai_data_plane_cq &source, DBMode dbMode) noexcept +{ + dest.cqn = source.cqn; + dest.bufAddr = source.buf_addr; + dest.cqeSize = source.cqe_size; + dest.depth = source.depth; + dest.headAddr = source.head_addr; + dest.tailAddr = source.tail_addr; + dest.dbMode = dbMode; + if (dbMode == DBMode::SW_DB) { + dest.dbAddr = source.swdb_addr; + } else if (dbMode == DBMode::HW_DB) { + dest.dbAddr = source.db_reg; + } + SHM_LOG_INFO("CopyAiCQInfo: cqn = " << dest.cqn << ", bufAddr = " << dest.bufAddr << ", cqeSize = " + << dest.cqeSize << ", depth = " << dest.depth << ", headAddr = " << dest.headAddr + << ", tailAddr = " << dest.tailAddr << ", dbAddr = " << dest.dbAddr); +} + +void DeviceQpManager::CloseServices() noexcept +{ + CloseServerConnections(); + CloseClientConnections(); +} + +void DeviceQpManager::CloseClientConnections() noexcept +{ + CloseConnections(clientConnections_); +} + +void DeviceQpManager::CloseServerConnections() noexcept +{ + DestroyServerSocket(); + CloseConnections(serverConnections_); +} + +void DeviceQpManager::CloseConnections(std::unordered_map &connections) noexcept +{ + std::vector socketCloseInfos; + for (auto it = connections.begin(); it != connections.end(); ++it) { + if (it->second.qpHandles[CONN_QP_AI_CORE] != nullptr) { + auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_AI_CORE]); + if (ret != 0) { + SHM_LOG_WARN("destroy AI QP to server: " << it->first << " failed: " << ret); + } + it->second.qpHandles[CONN_QP_AI_CORE] = nullptr; + } + + if (it->second.qpHandles[CONN_QP_STARS] != nullptr) { + auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_STARS]); + if (ret != 0) { + SHM_LOG_WARN("destroy stars QP to server: " << it->first << " failed: " << ret); + } + it->second.qpHandles[CONN_QP_STARS] = nullptr; + } + + if (it->second.socketFd != nullptr) { + HccpSocketCloseInfo info; + info.handle = it->second.socketHandle; + info.fd = it->second.socketFd; + info.linger = 0; + socketCloseInfos.push_back(info); + it->second.socketFd = nullptr; + } + } + + if (!socketCloseInfos.empty()) { + auto ret = DlHccpApi::RaSocketBatchClose(socketCloseInfos.data(), socketCloseInfos.size()); + if (ret != 0) { + SHM_LOG_INFO("close sockets return: " << ret); + } + } + + for (auto it = connections.begin(); it != connections.end(); ++it) { + auto ret = DlHccpApi::RaSocketDeinit(it->second.socketHandle); + if (ret != 0) { + SHM_LOG_INFO("deinit socket to server: " << it->first << " return: " << ret); + } + } + + connections.clear(); +} \ No newline at end of file diff --git a/src/modules/transport/rdma/device_qp_manager.h b/src/modules/transport/rdma/device_qp_manager.h new file mode 100644 index 00000000..1aacb566 --- /dev/null +++ b/src/modules/transport/rdma/device_qp_manager.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef MF_HYBRID_DEVICE_QP_MANAGER_H +#define MF_HYBRID_DEVICE_QP_MANAGER_H + +#include +#include +#include +#include "dl_hccp_api.h" +#include + +class DeviceQpManager { +public: + DeviceQpManager(uint32_t deviceId, uint32_t rankId, uint32_t rankCount, sockaddr_in devNet, + hybm_role_type role) noexcept; + ~DeviceQpManager() noexcept; + + int SetRemoteRankInfo(const std::unordered_map &ranks) noexcept; + int SetLocalMemories(const MemoryRegionMap &mrs) noexcept; + int Startup(void *rdma) noexcept; + void Shutdown() noexcept; + int WaitingConnectionReady() noexcept; + void *GetQpInfoAddress() const noexcept; + void *GetQpHandleWithRankId(uint32_t rankId) const noexcept; + +protected: + void *CreateLocalSocket() noexcept; + int CreateServerSocket() noexcept; + void DestroyServerSocket() noexcept; + +protected: + const uint32_t deviceId_; + const uint32_t rankId_; + const uint32_t rankCount_; + const hybm_role_type rankRole_; + sockaddr_in deviceAddress_; + void *serverSocketHandle_{nullptr}; + +private: + enum ConnQpType : uint32_t { + CONN_QP_AI_CORE, // AI core使用的QP + CONN_QP_STARS, // Host侧使用STARS驱动的QP + CONN_QP_COUNT + }; + + struct ConnectionChannel { + in_addr remoteIp; + void *socketHandle; + void *socketFd{nullptr}; + void *qpHandles[CONN_QP_COUNT]{}; + HccpAiQpInfo aiQpInfo{}; + int qpStatus{-1}; + + explicit ConnectionChannel(const in_addr ip) : ConnectionChannel{ip, nullptr} {} + ConnectionChannel(in_addr ip, void *sock) : remoteIp{ip}, socketHandle{sock} {} + }; + + bool ReserveQpInfoSpace() noexcept; + int StartServerSide() noexcept; + int StartClientSide() noexcept; + int GenerateWhiteList() noexcept; + int WaitConnectionsReady(std::unordered_map &connections) noexcept; + int CreateQpWaitingReady(std::unordered_map &connections, ConnQpType qpType) noexcept; + int CreateOneQp(ConnQpType qpType, ConnectionChannel &channel) noexcept; + int FillQpInfo(ConnQpType qpType) noexcept; + void CopyAiWQInfo(struct AiQpRMAWQ &dest, const struct ai_data_plane_wq &src, DBMode dbMode, uint32_t sl) noexcept; + void CopyAiCQInfo(struct AiQpRMACQ &dest, const ai_data_plane_cq &source, DBMode dbMode) noexcept; + void CloseServices() noexcept; + void CloseClientConnections() noexcept; + void CloseServerConnections() noexcept; + void CloseConnections(std::unordered_map &connections) noexcept; + + bool started_{false}; + int serverConnectResult{-1}; + int clientConnectResult{-1}; + uint32_t qpInfoSize_{0}; + void *rdmaHandle_{nullptr}; + std::unordered_map currentRanksInfo_; + MemoryRegionMap currentLocalMrs_; + AiQpRMAQueueInfo *qpInfo_{nullptr}; + std::unordered_map clientConnections_; + std::unordered_map serverConnections_; +}; + +#endif // MF_HYBRID_DEVICE_QP_MANAGER_H \ No newline at end of file diff --git a/src/modules/transport/rdma/dl_hccp_api.cpp b/src/modules/transport/rdma/dl_hccp_api.cpp new file mode 100644 index 00000000..dbcc92be --- /dev/null +++ b/src/modules/transport/rdma/dl_hccp_api.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include "dl_hccp_api.h" + +bool DlHccpApi::gLoaded = false; +std::mutex DlHccpApi::gMutex; +void *DlHccpApi::raHandle; +void *DlHccpApi::tsdHandle; + +const char *DlHccpApi::gRaLibName = "libra.so"; +const char *DlHccpApi::gTsdLibName = "libtsdclient.so"; + +raRdevGetHandleFunc DlHccpApi::gRaRdevGetHandle; + +raInitFunc DlHccpApi::gRaInit; +raGetInterfaceVersionFunc DlHccpApi::gRaGetInterfaceVersion; +raSocketInitFunc DlHccpApi::gRaSocketInit; +raSocketDeinitFunc DlHccpApi::gRaSocketDeinit; +raRdevInitV2Func DlHccpApi::gRaRdevInitV2; +raSocketBatchConnectFunc DlHccpApi::gRaSocketBatchConnect; +raSocketBatchCloseFunc DlHccpApi::gRaSocketBatchClose; +raSocketBatchAbortFunc DlHccpApi::gRaSocketBatchAbort; +raSocketListenStartFunc DlHccpApi::gRaSocketListenStart; +raSocketListenStopFunc DlHccpApi::gRaSocketListenStop; +raGetSocketsFunc DlHccpApi::gRaGetSockets; +raSocketSendFunc DlHccpApi::gRaSocketSend; +raSocketRecvFunc DlHccpApi::gRaSocketRecv; +raGetIfNumFunc DlHccpApi::gRaGetIfNum; +raGetIfAddrsFunc DlHccpApi::gRaGetIfAddrs; +raSocketWhiteListAddFunc DlHccpApi::gRaSocketWhiteListAdd; +raSocketWhiteListDelFunc DlHccpApi::gRaSocketWhiteListDel; +raQpCreateFunc DlHccpApi::gRaQpCreate; +raQpAiCreateFunc DlHccpApi::gRaQpAiCreate; +raQpDestroyFunc DlHccpApi::gRaQpDestroy; +raGetQpStatusFunc DlHccpApi::gRaGetQpStatus; +raQpConnectAsyncFunc DlHccpApi::gRaQpConnectAsync; +raRegisterMrFunc DlHccpApi::gRaRegisterMR; +raDeregisterMrFunc DlHccpApi::gRaDeregisterMR; +raMrRegFunc DlHccpApi::gRaMrReg; +raMrDeregFunc DlHccpApi::gRaMrDereg; +raSendWrFunc DlHccpApi::gRaSendWr; +raPollCqFunc DlHccpApi::gRaPollCq; + +tsdOpenFunc DlHccpApi::gTsdOpen; + +Result DlHccpApi::LoadLibrary() +{ + std::lock_guard guard(gMutex); + if (gLoaded) { + return 0; + } + + raHandle = dlopen(gRaLibName, RTLD_NOW); + if (raHandle == nullptr) { + std::cout << "Failed to open library [" + << gRaLibName + << "], please source ascend-toolkit set_env.sh, or add ascend driver lib path into LD_LIBRARY_PATH," + << " error: " << dlerror() << std::endl; + return -1; + } + + tsdHandle = dlopen(gTsdLibName, RTLD_NOW); + if (tsdHandle == nullptr) { + std::cout << "Failed to open library [" + << gTsdLibName + << "], please source ascend-toolkit set_env.sh, or add ascend driver lib path into LD_LIBRARY_PATH," + << " error: " << dlerror() << std::endl; + dlclose(raHandle); + raHandle = nullptr; + return -1; + } + + /* load sym */ + DL_LOAD_SYM(gRaGetInterfaceVersion, raGetInterfaceVersionFunc, raHandle, "ra_get_interface_version"); + DL_LOAD_SYM(gRaSocketInit, raSocketInitFunc, raHandle, "ra_socket_init"); + DL_LOAD_SYM(gRaInit, raInitFunc, raHandle, "ra_init"); + DL_LOAD_SYM(gRaSocketDeinit, raSocketDeinitFunc, raHandle, "ra_socket_deinit"); + DL_LOAD_SYM(gRaRdevInitV2, raRdevInitV2Func, raHandle, "ra_rdev_init_v2"); + DL_LOAD_SYM(gRaRdevGetHandle, raRdevGetHandleFunc, raHandle, "ra_rdev_get_handle"); + DL_LOAD_SYM(gRaSocketBatchConnect, raSocketBatchConnectFunc, raHandle, "ra_socket_batch_connect"); + DL_LOAD_SYM(gRaSocketBatchClose, raSocketBatchCloseFunc, raHandle, "ra_socket_batch_close"); + DL_LOAD_SYM(gRaSocketBatchAbort, raSocketBatchAbortFunc, raHandle, "ra_socket_batch_abort"); + DL_LOAD_SYM(gRaSocketListenStart, raSocketListenStartFunc, raHandle, "ra_socket_listen_start"); + DL_LOAD_SYM(gRaSocketListenStop, raSocketListenStopFunc, raHandle, "ra_socket_listen_stop"); + DL_LOAD_SYM(gRaGetSockets, raGetSocketsFunc, raHandle, "ra_get_sockets"); + DL_LOAD_SYM(gRaSocketSend, raSocketSendFunc, raHandle, "ra_socket_send"); + DL_LOAD_SYM(gRaSocketRecv, raSocketRecvFunc, raHandle, "ra_socket_recv"); + DL_LOAD_SYM(gRaGetIfNum, raGetIfNumFunc, raHandle, "ra_get_ifnum"); + DL_LOAD_SYM(gRaGetIfAddrs, raGetIfAddrsFunc, raHandle, "ra_get_ifaddrs"); + DL_LOAD_SYM(gRaSocketWhiteListAdd, raSocketWhiteListAddFunc, raHandle, "ra_socket_white_list_add"); + DL_LOAD_SYM(gRaSocketWhiteListDel, raSocketWhiteListDelFunc, raHandle, "ra_socket_white_list_del"); + DL_LOAD_SYM(gRaQpCreate, raQpCreateFunc, raHandle, "ra_qp_create"); + DL_LOAD_SYM(gRaQpAiCreate, raQpAiCreateFunc, raHandle, "ra_ai_qp_create"); + DL_LOAD_SYM(gRaQpDestroy, raQpDestroyFunc, raHandle, "ra_qp_destroy"); + DL_LOAD_SYM(gRaGetQpStatus, raGetQpStatusFunc, raHandle, "ra_get_qp_status"); + DL_LOAD_SYM(gRaQpConnectAsync, raQpConnectAsyncFunc, raHandle, "ra_qp_connect_async"); + DL_LOAD_SYM(gRaRegisterMR, raRegisterMrFunc, raHandle, "ra_register_mr"); + DL_LOAD_SYM(gRaDeregisterMR, raDeregisterMrFunc, raHandle, "ra_deregister_mr"); + DL_LOAD_SYM(gRaMrReg, raMrRegFunc, raHandle, "ra_mr_reg"); + DL_LOAD_SYM(gRaMrDereg, raMrDeregFunc, raHandle, "ra_mr_dereg"); + DL_LOAD_SYM(gRaSendWr, raSendWrFunc, raHandle, "ra_send_wr"); + DL_LOAD_SYM(gRaPollCq, raPollCqFunc, raHandle, "ra_poll_cq"); + + DL_LOAD_SYM(gTsdOpen, tsdOpenFunc, tsdHandle, "TsdOpen"); + SHM_LOG_INFO("LoadLibrary for DlHccpApi success"); + gLoaded = true; + return 0; +} + +void DlHccpApi::CleanupLibrary() +{ + std::lock_guard guard(gMutex); + if (!gLoaded) { + return; + } + + gRaRdevGetHandle = nullptr; + gRaInit = nullptr; + gRaGetInterfaceVersion = nullptr; + gRaSocketInit = nullptr; + gRaSocketDeinit = nullptr; + gRaRdevInitV2 = nullptr; + gRaSocketBatchConnect = nullptr; + gRaSocketBatchClose = nullptr; + gRaSocketBatchAbort = nullptr; + gRaSocketListenStart = nullptr; + gRaSocketListenStop = nullptr; + gRaGetSockets = nullptr; + gRaSocketSend = nullptr; + gRaSocketRecv = nullptr; + gRaGetIfNum = nullptr; + gRaGetIfAddrs = nullptr; + gRaSocketWhiteListAdd = nullptr; + gRaSocketWhiteListDel = nullptr; + gRaQpCreate = nullptr; + gRaQpAiCreate = nullptr; + gRaQpDestroy = nullptr; + gRaGetQpStatus = nullptr; + gRaQpConnectAsync = nullptr; + gRaRegisterMR = nullptr; + gRaDeregisterMR = nullptr; + gRaMrReg = nullptr; + gRaMrDereg = nullptr; + gTsdOpen = nullptr; + gRaSendWr = nullptr; + gRaPollCq = nullptr; + + if (raHandle != nullptr) { + dlclose(raHandle); + raHandle = nullptr; + } + + if (tsdHandle != nullptr) { + dlclose(tsdHandle); + tsdHandle = nullptr; + } + gLoaded = false; +} \ No newline at end of file diff --git a/src/modules/transport/rdma/dl_hccp_api.h b/src/modules/transport/rdma/dl_hccp_api.h new file mode 100644 index 00000000..b8722bca --- /dev/null +++ b/src/modules/transport/rdma/dl_hccp_api.h @@ -0,0 +1,248 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef DL_HCCP_API_H +#define DL_HCCP_API_H + +#include +#include "dl_hccp_def.h" + +using namespace shm; + +using raRdevGetHandleFunc = int (*)(uint32_t, void **); + +using raGetInterfaceVersionFunc = int (*)(uint32_t, uint32_t, uint32_t *); +using raInitFunc = int (*)(const HccpRaInitConfig *); +using raSocketInitFunc = int (*)(HccpNetworkMode, HccpRdev, void **); +using raSocketDeinitFunc = int (*)(void *); +using raRdevInitV2Func = int (*)(HccpRdevInitInfo, HccpRdev, void **); +using raSocketBatchConnectFunc = int (*)(HccpSocketConnectInfo[], uint32_t); +using raSocketBatchCloseFunc = int (*)(HccpSocketCloseInfo[], uint32_t); +using raSocketBatchAbortFunc = int (*)(HccpSocketConnectInfo[], uint32_t); +using raSocketListenStartFunc = int (*)(HccpSocketListenInfo[], uint32_t); +using raSocketListenStopFunc = int (*)(HccpSocketListenInfo[], uint32_t); +using raGetSocketsFunc = int (*)(uint32_t, HccpSocketInfo[], uint32_t, uint32_t *); +using raSocketSendFunc = int (*)(const void *, const void *, uint64_t, uint64_t *); +using raSocketRecvFunc = int (*)(const void *, void *, uint64_t, uint64_t *); +using raGetIfNumFunc = int (*)(const HccpRaGetIfAttr *, uint32_t *); +using raGetIfAddrsFunc = int (*)(const HccpRaGetIfAttr *, HccpInterfaceInfo[], uint32_t *); +using raSocketWhiteListAddFunc = int (*)(void *, const HccpSocketWhiteListInfo[], uint32_t num); +using raSocketWhiteListDelFunc = int (*)(void *, const HccpSocketWhiteListInfo[], uint32_t num); +using raQpCreateFunc = int (*)(void *, int, int, void **); +using raQpAiCreateFunc = int (*)(void *, const HccpQpExtAttrs *, HccpAiQpInfo *, void **); +using raQpDestroyFunc = int (*)(void *); +using raGetQpStatusFunc = int (*)(void *, int *); +using raQpConnectAsyncFunc = int (*)(void *, const void *); +using raRegisterMrFunc = int (*)(const void *, HccpMrInfo *, void **); +using raDeregisterMrFunc = int (*)(const void *, void *); +using raMrRegFunc = int (*)(void *, HccpMrInfo *); +using raMrDeregFunc = int (*)(void *, HccpMrInfo *); +using raSendWrFunc = int (*)(void *, send_wr *, send_wr_rsp *); +using tsdOpenFunc = uint32_t (*)(uint32_t, uint32_t); +using raPollCqFunc = int (*)(void *, bool, uint32_t, void *); + +class DlHccpApi { +public: + static Result LoadLibrary(); + static void CleanupLibrary(); + + static inline int RaGetInterfaceVersion(uint32_t deviceId, uint32_t opcode, uint32_t &version) + { + return gRaGetInterfaceVersion(deviceId, opcode, &version); + } + + static inline int RaSocketInit(HccpNetworkMode mode, const HccpRdev &rdev, void *&socketHandle) + { + return gRaSocketInit(mode, rdev, &socketHandle); + } + + static inline int RaInit(const HccpRaInitConfig &config) + { + return gRaInit(&config); + } + + static inline int RaSocketDeinit(void *socketHandle) + { + return gRaSocketDeinit(socketHandle); + } + + static inline int RaRdevInitV2(const HccpRdevInitInfo &info, const HccpRdev &rdev, void *&rdmaHandle) + { + return gRaRdevInitV2(info, rdev, &rdmaHandle); + } + + static inline int RaRdevGetHandle(uint32_t deviceId, void *&rdmaHandle) + { + return gRaRdevGetHandle(deviceId, &rdmaHandle); + } + + static inline int RaSocketBatchConnect(HccpSocketConnectInfo conn[], uint32_t num) + { + return gRaSocketBatchConnect(conn, num); + } + + static inline int RaSocketBatchClose(HccpSocketCloseInfo conn[], uint32_t num) + { + return gRaSocketBatchClose(conn, num); + } + + static inline int RaSocketBatchAbort(HccpSocketConnectInfo conn[], uint32_t num) + { + return gRaSocketBatchAbort(conn, num); + } + + static inline int RaSocketListenStart(HccpSocketListenInfo conn[], uint32_t num) + { + return gRaSocketListenStart(conn, num); + } + + static inline int RaSocketListenStop(HccpSocketListenInfo conn[], uint32_t num) + { + return gRaSocketListenStop(conn, num); + } + + static inline int RaGetSockets(uint32_t role, HccpSocketInfo conn[], uint32_t num, uint32_t &connectedNum) + { + return gRaGetSockets(role, conn, num, &connectedNum); + } + + static inline int RaSocketSend(const void *fd, const void *data, uint64_t size, uint64_t &sent) + { + return gRaSocketSend(fd, data, size, &sent); + } + + static inline int RaSocketRecv(const void *fd, void *data, uint64_t size, uint64_t &received) + { + return gRaSocketRecv(fd, data, size, &received); + } + + static inline int RaGetIfNum(const HccpRaGetIfAttr &config, uint32_t &num) + { + return gRaGetIfNum(&config, &num); + } + + static inline int RaGetIfAddrs(const HccpRaGetIfAttr &config, HccpInterfaceInfo infos[], uint32_t &num) + { + return gRaGetIfAddrs(&config, infos, &num); + } + + static inline int RaSocketWhiteListAdd(void *socket, const HccpSocketWhiteListInfo list[], uint32_t num) + { + return gRaSocketWhiteListAdd(socket, list, num); + } + + static inline int RaSocketWhiteListDel(void *socket, const HccpSocketWhiteListInfo list[], uint32_t num) + { + return gRaSocketWhiteListAdd(socket, list, num); + } + + static inline int RaQpCreate(void *rdmaHandle, int flag, int qpMode, void *&qpHandle) + { + return gRaQpCreate(rdmaHandle, flag, qpMode, &qpHandle); + } + + static inline int RaQpAiCreate(void *rdmaHandle, const HccpQpExtAttrs &attrs, HccpAiQpInfo &info, void *&qpHandle) + { + return gRaQpAiCreate(rdmaHandle, &attrs, &info, &qpHandle); + } + + static inline int RaQpDestroy(void *qpHandle) + { + return gRaQpDestroy(qpHandle); + } + + static inline int RaGetQpStatus(void *qpHandle, int &status) + { + return gRaGetQpStatus(qpHandle, &status); + } + + static inline int RaQpConnectAsync(void *qp, const void *socketFd) + { + return gRaQpConnectAsync(qp, socketFd); + } + + static inline int RaRegisterMR(const void *rdmaHandle, HccpMrInfo *info, void *&mrHandle) + { + return gRaRegisterMR(rdmaHandle, info, &mrHandle); + } + + static inline int RaDeregisterMR(const void *rdmaHandle, void *mrHandle) + { + return gRaDeregisterMR(rdmaHandle, mrHandle); + } + + static inline int RaMrReg(void *qpHandle, HccpMrInfo &info) + { + return gRaMrReg(qpHandle, &info); + } + + static inline int RaMrDereg(void *qpHandle, HccpMrInfo &info) + { + return gRaMrDereg(qpHandle, &info); + } + + static inline int RaSendWr(void *qp_handle, struct send_wr *wr, struct send_wr_rsp *op_rsp) + { + return gRaSendWr(qp_handle, wr, op_rsp); + } + + static inline int RaPollCq(void *qp_handle, bool is_send_cq, unsigned int num_entries, void *wc) + { + return gRaPollCq(qp_handle, is_send_cq, num_entries, wc); + } + + static inline uint32_t TsdOpen(uint32_t deviceId, uint32_t rankSize) + { + return gTsdOpen(deviceId, rankSize); + } + +private: + static std::mutex gMutex; + static bool gLoaded; + static void *raHandle; + static void *tsdHandle; + static const char *gRaLibName; + static const char *gTsdLibName; + + static raRdevGetHandleFunc gRaRdevGetHandle; + + static raGetInterfaceVersionFunc gRaGetInterfaceVersion; + static raInitFunc gRaInit; + static raSocketInitFunc gRaSocketInit; + static raSocketDeinitFunc gRaSocketDeinit; + static raRdevInitV2Func gRaRdevInitV2; + static raSocketBatchConnectFunc gRaSocketBatchConnect; + static raSocketBatchCloseFunc gRaSocketBatchClose; + static raSocketBatchAbortFunc gRaSocketBatchAbort; + static raSocketListenStartFunc gRaSocketListenStart; + static raSocketListenStopFunc gRaSocketListenStop; + static raGetSocketsFunc gRaGetSockets; + static raSocketSendFunc gRaSocketSend; + static raSocketRecvFunc gRaSocketRecv; + static raGetIfNumFunc gRaGetIfNum; + static raGetIfAddrsFunc gRaGetIfAddrs; + static raSocketWhiteListAddFunc gRaSocketWhiteListAdd; + static raSocketWhiteListDelFunc gRaSocketWhiteListDel; + static raQpCreateFunc gRaQpCreate; + static raQpAiCreateFunc gRaQpAiCreate; + static raQpDestroyFunc gRaQpDestroy; + static raGetQpStatusFunc gRaGetQpStatus; + static raQpConnectAsyncFunc gRaQpConnectAsync; + static raRegisterMrFunc gRaRegisterMR; + static raDeregisterMrFunc gRaDeregisterMR; + static raMrRegFunc gRaMrReg; + static raMrDeregFunc gRaMrDereg; + static raSendWrFunc gRaSendWr; + static raPollCqFunc gRaPollCq; + + static tsdOpenFunc gTsdOpen; +}; + +#endif // DL_HCCP_API_H \ No newline at end of file diff --git a/src/modules/transport/rdma/dl_hccp_def.h b/src/modules/transport/rdma/dl_hccp_def.h new file mode 100644 index 00000000..e149a145 --- /dev/null +++ b/src/modules/transport/rdma/dl_hccp_def.h @@ -0,0 +1,647 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef DL_HCCP_DEF_H +#define DL_HCCP_DEF_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "shmem_api.h" +#include "common/shmemi_functions.h" +#include "common/shmemi_host_types.h" + +using Result = int32_t; + +constexpr uint32_t HCCL_ROOT_INFO_BYTES = 256; // 4108: root info length +constexpr uint32_t HCCP_SOCK_CONN_TAG_SIZE = 192; +constexpr uint32_t HCCP_MAX_INTERFACE_NAME_LEN = 256; + +constexpr uint64_t EXPORT_INFO_MAGIC = 0xAABB1234FFFFEEEEUL; +constexpr uint64_t EXPORT_SLICE_MAGIC = 0xAABB1234FFFFBBBBUL; +constexpr uint64_t EXPORT_INFO_VERSION = 0x1UL; + +struct HybmDeviceGlobalMeta { + uint64_t entityCount; + uint64_t reserved[15]; // total 128B, equal HYBM_DEVICE_PRE_META_SIZE +}; + +struct HybmDeviceMeta { + uint32_t entityId; + uint32_t rankId; + uint32_t rankSize; + uint32_t extraContextSize; + uint64_t symmetricSize; + uint64_t qpInfoAddress; + uint64_t reserved[12]; // total 128B, equal HYBM_DEVICE_PRE_META_SIZE +}; + + +/** + * @brief HCCL root info + */ +struct HcclRootInfo { + char internal[HCCL_ROOT_INFO_BYTES]; +}; + +struct HccpRaInitConfig { + uint32_t phyId; /**< physical device id */ + uint32_t nicPosition; /**< reference to HccpNetworkMode */ + int hdcType; /**< reference to drvHdcServiceType */ +}; + +/** + * @ingroup libinit + * ip address + */ +union HccpIpAddr { + struct in_addr addr; + struct in6_addr addr6; +}; + +struct HccpRdevInitInfo { + int mode; + uint32_t notifyType; + bool enabled910aLite; /**< true will enable 910A lite, invalid if enabled_2mb_lite is false; default is false */ + bool disabledLiteThread; /**< true will not start lite thread, flag invalid if enabled_910a/2mb_lite is false */ + bool enabled2mbLite; /**< true will enable 2MB lite(include 910A & 910B), default is false */ +}; + +/** + * @ingroup libinit + * hccp operating environment + */ +enum HccpNetworkMode { + NETWORK_PEER_ONLINE = 0, /**< Third-party online mode */ + NETWORK_OFFLINE, /**< offline mode */ + NETWORK_ONLINE, /**< online mode */ +}; + +/** + * @ingroup librdma + * Flag of mr access + */ +enum HccpMrAccessFlags { + RA_ACCESS_LOCAL_WRITE = 1, /**< mr local write access */ + RA_ACCESS_REMOTE_WRITE = (1 << 1), /**< mr remote write access */ + RA_ACCESS_REMOTE_READ = (1 << 2), /**< mr remote read access */ + RA_ACCESS_REDUCE = (1 << 8), +}; + +enum HccpNotifyType { + NO_USE = 0, + NOTIFY = 1, + EVENTID = 2, +}; + +/** + * @ingroup libsocket + * struct of the client socket + */ +struct HccpSocketConnectInfo { + void *handle; /**< socket handle */ + HccpIpAddr remoteIp; /**< IP address of remote socket, [0-7] is reserved for vnic */ + uint16_t port; /**< Socket listening port number */ + char tag[HCCP_SOCK_CONN_TAG_SIZE]; /**< tag must ended by '\0' */ +}; + +inline std::ostream &operator<<(std::ostream &output, const HccpSocketConnectInfo &info) +{ + output << "HccpSocketConnectInfo(socketHandle=" << info.handle << ", remoteIp=" << inet_ntoa(info.remoteIp.addr) + << ", port=" << info.port << ")"; + return output; +} + +/** + * @ingroup libsocket + * Details about socket after socket is linked + */ +struct HccpSocketCloseInfo { + void *handle; /**< socket handle */ + void *fd; /**< fd handle */ + int linger; /**< 0:use(default l_linger is RS_CLOSE_TIMEOUT), others:disuse */ +}; + +/** + * @ingroup libsocket + * struct of the listen info + */ +struct HccpSocketListenInfo { + void *handle; /**< socket handle */ + unsigned int port; /**< Socket listening port number */ + unsigned int phase; /**< refer to enum listen_phase */ + unsigned int err; /**< errno */ +}; + +/** + * @ingroup libsocket + * Details about socket after socket is linked + */ +struct HccpSocketInfo { + void *handle; /**< socket handle */ + void *fd; /**< fd handle */ + HccpIpAddr remoteIp; /**< IP address of remote socket */ + int status; /**< socket status:0 not connected 1:connected 2:connect timeout 3:connecting */ + char tag[HCCP_SOCK_CONN_TAG_SIZE]; /**< tag must ended by '\0' */ +}; + +/** + * @ingroup libinit + * hccp init info + */ +struct HccpRdev { + uint32_t phyId; /**< physical device id */ + int family; /**< AF_INET(ipv4) or AF_INET6(ipv6) */ + HccpIpAddr localIp; +}; + +struct HccpRaGetIfAttr { + uint32_t phyId; /**< physical device id */ + uint32_t nicPosition; /**< reference to network_mode */ + bool isAll; /**< valid when nic_position is NETWORK_OFFLINE. false: get specific rnic ip, true: get all rnic ip */ +}; + +struct HccpIfaddrInfo { + HccpIpAddr ip; /* Address of interface */ + struct in_addr mask; /* Netmask of interface */ +}; + +struct HccpInterfaceInfo { + int family; + int scopeId; + HccpIfaddrInfo ifaddr; /* Address and netmask of interface */ + char ifname[HCCP_MAX_INTERFACE_NAME_LEN]; /* Name of interface */ +}; + +struct HccpSocketWhiteListInfo { + HccpIpAddr remoteIp; /**< IP address of remote */ + uint32_t connLimit; /**< limit of whilte list */ + char tag[HCCP_SOCK_CONN_TAG_SIZE]; /**< tag used for whitelist must ended by '\0' */ +}; + +struct HccpMrInfo { + void *addr; /**< starting address of mr */ + unsigned long long size; /**< size of mr */ + int access; /**< access of mr, reference to HccpMrAccessFlags */ + unsigned int lkey; /**< local addr access key */ + unsigned int rkey; /**< remote addr access key */ +}; + +struct HccpCqExtAttr { + int sendCqDepth; + int recvDqDepth; + int sendCqCompVector; + int recvCqCompVector; +}; + +enum ibv_qp_type { + IBV_QPT_RC = 2, + IBV_QPT_UC, + IBV_QPT_UD, + IBV_QPT_RAW_PACKET = 8, + IBV_QPT_XRC_SEND = 9, + IBV_QPT_XRC_RECV, + IBV_QPT_DRIVER = 0xff, +}; + +enum ibv_wc_status { + IBV_WC_SUCCESS, + IBV_WC_LOC_LEN_ERR, + IBV_WC_LOC_QP_OP_ERR, + IBV_WC_LOC_EEC_OP_ERR, + IBV_WC_LOC_PROT_ERR, + IBV_WC_WR_FLUSH_ERR, + IBV_WC_MW_BIND_ERR, + IBV_WC_BAD_RESP_ERR, + IBV_WC_LOC_ACCESS_ERR, + IBV_WC_REM_INV_REQ_ERR, + IBV_WC_REM_ACCESS_ERR, + IBV_WC_REM_OP_ERR, + IBV_WC_RETRY_EXC_ERR, + IBV_WC_RNR_RETRY_EXC_ERR, + IBV_WC_LOC_RDD_VIOL_ERR, + IBV_WC_REM_INV_RD_REQ_ERR, + IBV_WC_REM_ABORT_ERR, + IBV_WC_INV_EECN_ERR, + IBV_WC_INV_EEC_STATE_ERR, + IBV_WC_FATAL_ERR, + IBV_WC_RESP_TIMEOUT_ERR, + IBV_WC_GENERAL_ERR +}; + +enum ibv_wc_opcode { + IBV_WC_SEND, + IBV_WC_RDMA_WRITE, + IBV_WC_RDMA_READ, + IBV_WC_COMP_SWAP, + IBV_WC_FETCH_ADD, + IBV_WC_BIND_MW, + /* + * Set value of IBV_WC_RECV so consumers can test if a completion is a + * receive by testing (opcode & IBV_WC_RECV). + */ + IBV_WC_RECV = 1 << 7, + IBV_WC_RECV_RDMA_WITH_IMM +}; + +struct ibv_wc { + uint64_t wr_id; + enum ibv_wc_status status; + enum ibv_wc_opcode opcode; + uint32_t vendor_err; + uint32_t byte_len; + uint32_t imm_data; /* in network byte order */ + uint32_t qp_num; + uint32_t src_qp; + int wc_flags; + uint16_t pkey_index; + uint16_t slid; + uint8_t sl; + uint8_t dlid_path_bits; +}; + +struct ibv_qp_cap { + uint32_t max_send_wr; + uint32_t max_recv_wr; + uint32_t max_send_sge; + uint32_t max_recv_sge; + uint32_t max_inline_data; +}; + +struct ibv_qp_init_attr { + void *qp_context; + struct ibv_cq *send_cq; + struct ibv_cq *recv_cq; + struct ibv_srq *srq; + struct ibv_qp_cap cap; + enum ibv_qp_type qp_type; + int sq_sig_all; +}; + +union ai_data_plane_cstm_flag { + struct { + uint32_t cq_cstm : 1; // 0: hccp poll cq; 1: caller poll cq + uint32_t reserved : 31; + } bs; + uint32_t value; +}; + +struct HccpQpExtAttrs { + int qpMode; + // cq attr + HccpCqExtAttr cqAttr; + // qp attr + struct ibv_qp_init_attr qp_attr; + // version control and reserved + int version; + int mem_align; // 0,1:4KB, 2:2MB + uint32_t udp_sport; + union ai_data_plane_cstm_flag data_plane_flag; // only valid in ra_ai_qp_create + uint32_t reserved[29]; +}; + +struct ai_data_plane_wq { + unsigned wqn; + unsigned long long buf_addr; + unsigned int wqebb_size; + unsigned int depth; + unsigned long long head_addr; + unsigned long long tail_addr; + unsigned long long swdb_addr; + unsigned long long db_reg; + unsigned int reserved[8U]; +}; + +struct ai_data_plane_cq { + unsigned int cqn; + unsigned long long buf_addr; + unsigned int cqe_size; + unsigned int depth; + unsigned long long head_addr; + unsigned long long tail_addr; + unsigned long long swdb_addr; + unsigned long long db_reg; + unsigned int reserved[2U]; +}; + +struct ai_data_plane_info { + struct ai_data_plane_wq sq; + struct ai_data_plane_wq rq; + struct ai_data_plane_cq scq; + struct ai_data_plane_cq rcq; + unsigned int reserved[8U]; +}; + +struct HccpAiQpInfo { + unsigned long long aiQpAddr; // refer to struct ibv_qp * + unsigned int sqIndex; // index of sq + unsigned int dbIndex; // index of db + + // below cq related info valid when data_plane_flag.bs.cq_cstm was 1 + unsigned long long ai_scq_addr; // refer to struct ibv_cq *scq + unsigned long long ai_rcq_addr; // refer to struct ibv_cq *rcq + struct ai_data_plane_info data_plane_info; +}; + +enum class DBMode : int32_t { INVALID_DB = -1, HW_DB = 0, SW_DB }; + +struct AiQpRMAWQ { + uint32_t wqn{0}; + uint64_t bufAddr{0}; + uint32_t wqeSize{0}; + uint32_t depth{0}; + uint64_t headAddr{0}; + uint64_t tailAddr{0}; + DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw + uint64_t dbAddr{0}; + uint32_t sl{0}; +}; + +struct AiQpRMACQ { + uint32_t cqn{0}; + uint64_t bufAddr{0}; + uint32_t cqeSize{0}; + uint32_t depth{0}; + uint64_t headAddr{0}; + uint64_t tailAddr{0}; + DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw + uint64_t dbAddr{0}; +}; + +struct RdmaMemRegionInfo { + uint64_t size{0}; // size of the memory region + uint64_t addr{0}; // start address of the memory region + uint32_t lkey{0}; + uint32_t rkey{0}; // key of the memory region +}; + +struct AiQpRMAQueueInfo { + uint32_t count; + struct AiQpRMAWQ *sq; + struct AiQpRMAWQ *rq; + struct AiQpRMACQ *scq; + struct AiQpRMACQ *rcq; + RdmaMemRegionInfo *mr; +}; + +/** + * @ingroup librdma + * Scatter and gather element + */ +struct sg_list { + uint64_t addr; /**< address of buf */ + uint32_t len; /**< len of buf */ + uint32_t lkey; /**< local addr access key */ +}; + +/** + * @ingroup librdma + * RDMA work request + */ +struct send_wr { + struct sg_list *buf_list; /**< list of sg */ + uint16_t buf_num; /**< num of buf_list */ + uint64_t dst_addr; /**< destination address */ + uint32_t rkey; /**< remote address access key */ + uint32_t op; /**< operations of RDMA supported:RDMA_WRITE:0 */ + int send_flag; /**< reference to ra_send_flags */ +}; + +/** + * @ingroup librdma + * wqe template info + */ +struct wqe_info { + unsigned int sq_index; /**< index of sq */ + unsigned int wqe_index; /**< index of wqe */ +}; + +enum ra_send_flags { + RA_SEND_FENCE = 1 << 0, /**< RDMA operation with fence */ + RA_SEND_SIGNALED = 1 << 1, /**< RDMA operation with signaled */ + RA_SEND_SOLICITED = 1 << 2, /**< RDMA operation with solicited */ + RA_SEND_INLINE = 1 << 3, /**< RDMA operation with inline */ +}; +/** + * @ingroup librdma + * doorbell info + */ +struct db_info { + unsigned int db_index; /**< index of db */ + unsigned long db_info; /**< db content */ +}; + +/** + * @ingroup librdma + * respond of sending work request + */ +struct send_wr_rsp { + union { + struct wqe_info wqe_tmp; /**< wqe template info */ + struct db_info db; /**< doorbell info */ + }; +}; +/** + * @brief handle to HCCL communicator + */ +typedef void *HcclComm; + +// macro for gcc optimization for prediction of if/else +#ifndef LIKELY +#define LIKELY(x) (__builtin_expect(!!(x), 1) != 0) +#endif + +#ifndef UNLIKELY +#define UNLIKELY(x) (__builtin_expect(!!(x), 0) != 0) +#endif + +#define HYBM_API __attribute__((visibility("default"))) + +#define DL_LOAD_SYM(TARGET_FUNC_VAR, TARGET_FUNC_TYPE, FILE_HANDLE, SYMBOL_NAME) \ + do { \ + TARGET_FUNC_VAR = (TARGET_FUNC_TYPE)dlsym(FILE_HANDLE, SYMBOL_NAME); \ + if ((TARGET_FUNC_VAR) == nullptr) { \ + std::cout << "Failed to call dlsym to load symbol" << SYMBOL_NAME << std::endl; \ + dlclose(FILE_HANDLE); \ + return -1; \ + } \ + } while (0) + + +enum HybmGvaVersion : uint32_t { + HYBM_GVA_V1 = 0, + HYBM_GVA_V2 = 1, + HYBM_GVA_V3 = 2, + HYBM_GVA_UNKNOWN +}; + +inline std::ostream &operator<<(std::ostream &output, const HccpRaInitConfig &config) +{ + output << "HccpRaInitConfig(phyId=" << config.phyId << ", nicPosition=" << config.nicPosition + << ", hdcType=" << config.hdcType << ")"; + return output; +} + +inline std::ostream &operator<<(std::ostream &output, const HccpRdevInitInfo &info) +{ + output << "HccpRdevInitInfo(mode=" << info.mode << ", notify=" << info.notifyType + << ", enabled910aLite=" << info.enabled910aLite << ", disabledLiteThread=" << info.disabledLiteThread + << ", enabled2mbLite=" << info.enabled2mbLite << ")"; + return output; +} + +inline std::ostream &operator<<(std::ostream &output, const HccpRdev &rdev) +{ + output << "HccpRdev(phyId=" << rdev.phyId << ", family=" << rdev.family + << ", rdev.ip=" << inet_ntoa(rdev.localIp.addr) << ")"; + return output; +} + +struct RegMemResult { + uint32_t reserved{0}; + uint64_t address{0}; + uint64_t size{0}; + void *mrHandle{nullptr}; + uint32_t lkey{0}; + uint32_t rkey{0}; + + RegMemResult() = default; + + RegMemResult(uint64_t addr, uint64_t sz, void *hd, uint32_t lk, uint32_t rk) + : address(addr), + size(sz), + mrHandle(hd), + lkey(lk), + rkey(rk) + { + } +}; + +inline std::ostream &operator<<(std::ostream &output, const RegMemResult &mr) +{ + output << "RegMemResult(address = " << mr.address << ", size = " << mr.size + << ", lkey = " << mr.lkey << ", rkey = " << mr.rkey << ")"; + return output; +} + +constexpr int32_t REG_MR_ACCESS_FLAG_LOCAL_WRITE = 0x1; +constexpr int32_t REG_MR_ACCESS_FLAG_REMOTE_WRITE = 0x2; +constexpr int32_t REG_MR_ACCESS_FLAG_REMOTE_READ = 0x4; +constexpr int32_t REG_MR_ACCESS_FLAG_BOTH_READ_WRITE = 0x7; + +typedef enum { + HYBM_ROLE_PEER = 0, + HYBM_ROLE_SENDER, + HYBM_ROLE_RECEIVER, + HYBM_ROLE_BUTT +} hybm_role_type; + +struct TransportOptions { + uint32_t rankId; + uint32_t rankCount; + uint32_t protocol; + hybm_role_type role; + int nic; +}; + +struct TransportMemoryRegion { + uint64_t addr = 0; /* virtual address of memory could be hbm or host dram */ + uint64_t size = 0; /* size of memory to be registered */ + int32_t access = REG_MR_ACCESS_FLAG_BOTH_READ_WRITE; /* access right by local and remote */ + uint32_t flags = 0; /* optional flags: 加一个flag标识是DRAM还是HBM */ + + friend std::ostream &operator<<(std::ostream &output, const TransportMemoryRegion &mr) + { + output << "MemoryRegion address size=" << mr.size << ", access=" << mr.access + << ", flags=" << mr.flags << ")"; + return output; + } +}; + +using MemoryRegionMap = std::map>; + +struct TransportMemoryKey { + uint32_t keys[16]; + + friend std::ostream &operator<<(std::ostream &output, const TransportMemoryKey &key) + { + output << "MemoryKey" << std::hex; + for (auto i = 0U; i < sizeof(key.keys) / sizeof(key.keys[0]); i++) { + output << "-" << key.keys[i]; + } + output << std::dec; + return output; + } +}; + +#define container_of(ptr, type, member) \ + ({ \ + const typeof(((const type *)0)->member) *__mptr = (ptr); \ + (const type *)(const void *)((const char *)__mptr - offsetof(type, member)); \ + }) + +union RegMemKeyUnion { + TransportMemoryKey commonKey; + RegMemResult deviceKey; +}; + +struct ConnectRankInfo { + hybm_role_type role; + sockaddr_in network; + RegMemResult mr; + + ConnectRankInfo(hybm_role_type r, sockaddr_in nw, RegMemResult memory_region) : role{r}, + network{std::move(nw)}, mr{memory_region} {} +}; + +struct TransportRankPrepareInfo { + std::string nic; + hybm_role_type role{HYBM_ROLE_PEER}; + RegMemResult mr; + + TransportRankPrepareInfo() {} + + TransportRankPrepareInfo(std::string n, RegMemResult k) + : nic{std::move(n)}, role{HYBM_ROLE_PEER}, mr{k} {} + + TransportRankPrepareInfo(std::string n, hybm_role_type r, RegMemResult k) + : nic{std::move(n)}, role{r}, mr{k} {} + + friend std::ostream &operator<<(std::ostream &output, const TransportRankPrepareInfo &info) + { + output << "PrepareInfo(nic=" << info.nic << ", role=" << info.role << ", mr=" << info.mr; + return output; + } +}; + +struct HybmTransPrepareOptions { + std::unordered_map options; + + friend std::ostream &operator<<(std::ostream &output, const HybmTransPrepareOptions &info) + { + output << "PrepareOptions("; + for (auto &op : info.options) { + output << op.first << " => " << op.second << ", "; + } + output << ")"; + return output; + } +}; +#endif // DL_HCCP_DEF_H \ No newline at end of file diff --git a/src/modules/transport/rdma/rdma_manager.h b/src/modules/transport/rdma/rdma_manager.h new file mode 100644 index 00000000..11d4783e --- /dev/null +++ b/src/modules/transport/rdma/rdma_manager.h @@ -0,0 +1,419 @@ +#include +#include +#include "dl_hccp_api.h" +#include "acl/acl.h" +#include "device_qp_manager.h" + +class rdma_manager { +public: + rdma_manager() {} + int OpenDevice(const TransportOptions &options) + { + int32_t deviceId = -1; + + // SHM_LOG_DEBUG("begin to open device with " << options); + auto ret = aclrtGetDevice(&deviceId); + if (ret != 0 || deviceId < 0) { + SHM_LOG_ERROR("AclrtGetDevice() return=" << ret << ", output deviceId=" << deviceId); + return -1; + } + deviceId_ = static_cast(deviceId); + rankId_ = options.rankId; + rankCount_ = options.rankCount; + role_ = options.role; + auto port = options.nic; + if (port < 0 || port > 65536) { + SHM_LOG_ERROR("Failed to parse nic info, nic = " << options.nic); + } + devicePort_ = static_cast(port); + DlHccpApi::LoadLibrary(); + + if (!PrepareOpenDevice(deviceId_, rankCount_, deviceIp_, rdmaHandle_)) { + SHM_LOG_ERROR("PrepareOpenDevice failed."); + return -1; + } + + SHM_LOG_INFO("ip = " << inet_ntoa(deviceIp_) << ", port = " << devicePort_); + + sockaddr_in deviceAddr; + deviceAddr.sin_family = AF_INET; + deviceAddr.sin_addr = deviceIp_; + deviceAddr.sin_port = devicePort_; + qpManager_ = new DeviceQpManager(deviceId_, rankId_, rankCount_, deviceAddr, HYBM_ROLE_PEER); + + return 0; + } + + void* GetQPInfoAddr() { + return qpManager_->GetQpInfoAddress(); + } + + in_addr GetDeviceIP() { + return deviceIp_; + } + + Result RegisterMemoryRegion(const TransportMemoryRegion &mr) + { + void *mrHandle = nullptr; + HccpMrInfo info{}; + info.addr = (void *)(ptrdiff_t)mr.addr; + info.size = mr.size; + info.access = mr.access; + auto ret = DlHccpApi::RaRegisterMR(rdmaHandle_, &info, mrHandle); + if (ret != 0) { + SHM_LOG_ERROR("register MR=" << mr << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + + RegMemResult result{mr.addr, mr.size, mrHandle, info.lkey, info.rkey}; + localMR_ = result; + SHM_LOG_DEBUG("register MR result=" << result); + + registerMRS_.emplace(mr.addr, result); + ret = qpManager_->SetLocalMemories(registerMRS_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("qp manager set mr failed: " << ret); + return ret; + } + return 0; + } + + Result UnregisterMemoryRegion(uint64_t addr) + { + auto pos = registerMRS_.find(addr); + if (pos == registerMRS_.end()) { + SHM_LOG_ERROR("input address not register!"); + return SHMEM_INVALID_PARAM; + } + + auto ret = DlHccpApi::RaDeregisterMR(rdmaHandle_, pos->second.mrHandle); + if (ret != 0) { + SHM_LOG_ERROR("Unregister MR addr failed: " << ret); + return SHMEM_INNER_ERROR; + } + + registerMRS_.erase(pos); + // ret = qpManager_->SetLocalMemories(registerMRS_); + // if (ret != SHMEM_SUCCESS) { + // SHM_LOG_ERROR("qp manager set mr failed: " << ret); + // return ret; + // } + return 0; + } + + RegMemResult GetLocalMR() { + return localMR_; + } + + Result Prepare(const HybmTransPrepareOptions &options) + { + SHM_LOG_DEBUG("RdmaTransportManager Prepare with : " << options); + int ret; + if ((ret = CheckPrepareOptions(options)) != 0) { + return ret; + } + + sockaddr_in deviceNetwork; + std::unordered_map rankInfo; + for (auto it = options.options.begin(); it != options.options.end(); ++it) { + ret = ipPortStringToSockaddr(it->second.nic, deviceNetwork); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("parse networks[" << it->first << "]=" << it->second.nic << " failed: " << ret); + return SHMEM_INVALID_PARAM; + } + + rankInfo.emplace(it->first, ConnectRankInfo{it->second.role, deviceNetwork, it->second.mr}); + } + + ret = qpManager_->SetRemoteRankInfo(rankInfo); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("qp manager set remote rank info failed: " << ret); + return ret; + } + + ret = qpManager_->Startup(rdmaHandle_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("qp manager startup failed: " << ret); + return ret; + } + + return SHMEM_SUCCESS; + } + + Result Connect() + { + auto ret = AsyncConnect(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("AsyncConnect() failed: " << ret); + return ret; + } + + ret = WaitForConnected(-1L); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("WaitForConnected(-1) failed: " << ret); + return ret; + } + + return SHMEM_SUCCESS; + } +private: + bool OpenTsd(uint32_t deviceId, uint32_t rankCount) + { + if (tsdOpened_) { + SHM_LOG_INFO("tsd already opened."); + return true; + } + + auto res = DlHccpApi::TsdOpen(deviceId, rankCount); + if (res != 0) { + SHM_LOG_ERROR("TsdOpen for (deviceId=" << deviceId << ", rankCount=" << rankCount << ") failed: " << res); + return false; + } + + SHM_LOG_DEBUG("open tsd for device id: " << deviceId << ", rank count: " << rankCount << " success."); + tsdOpened_ = true; + return true; + } + + bool RaInit(uint32_t deviceId) + { + if (raInitialized_) { + SHM_LOG_INFO("ra already initialized."); + return true; + } + + HccpRaInitConfig initConfig{}; + initConfig.phyId = deviceId; + initConfig.nicPosition = NETWORK_OFFLINE; + initConfig.hdcType = 6; // HDC_SERVICE_TYPE_RDMA = 6 + SHM_LOG_DEBUG("RaInit=" << initConfig); + auto ret = DlHccpApi::RaInit(initConfig); + if (ret != 0) { + SHM_LOG_ERROR("Hccp Init RA failed: " << ret); + return false; + } + + SHM_LOG_DEBUG("ra init for device id: " << deviceId << " success."); + raInitialized_ = true; + return true; + } + + bool RetireDeviceIp(uint32_t deviceId, in_addr &deviceIp) + { + static in_addr retiredIp{}; + + if (deviceIpRetired_) { + SHM_LOG_INFO("device ip already retired : " << inet_ntoa(retiredIp)); + deviceIp = retiredIp; + return true; + } + + uint32_t count = 0; + std::vector infos; + + HccpRaGetIfAttr config; + config.phyId = deviceId; + config.nicPosition = NETWORK_OFFLINE; + config.isAll = true; + + auto ret = DlHccpApi::RaGetIfNum(config, count); + if (ret != 0 || count == 0) { + SHM_LOG_ERROR("get interface count failed: " << ret << ", count: " << count); + return false; + } + + infos.resize(count); + ret = DlHccpApi::RaGetIfAddrs(config, infos.data(), count); + if (ret != 0) { + SHM_LOG_ERROR("get interface information failed: " << ret); + return false; + } + + for (auto &info : infos) { + if (info.family == AF_INET) { + deviceIp = retiredIp = info.ifaddr.ip.addr; + deviceIpRetired_ = true; + SHM_LOG_DEBUG("retire device ip success : " << inet_ntoa(deviceIp)); + return true; + } + } + + SHM_LOG_ERROR("not found network device of AF_INET on NPU."); + return false; + } + + bool RaRdevInit(uint32_t deviceId, in_addr deviceIp, void *&rdmaHandle) + { + if (storedRdmaHandle_ != nullptr) { + SHM_LOG_INFO("ra rdev already initialized."); + rdmaHandle = storedRdmaHandle_; + return true; + } + + HccpRdevInitInfo info{}; + HccpRdev rdev{}; + + info.mode = NETWORK_OFFLINE; + info.notifyType = NOTIFY; + info.enabled2mbLite = true; + rdev.phyId = deviceId; + rdev.family = AF_INET; + rdev.localIp.addr = deviceIp; + SHM_LOG_DEBUG("RaRdevInitV2, info=" << info << "rdev=" << rdev); + auto ret = DlHccpApi::RaRdevInitV2(info, rdev, rdmaHandle); + if (ret != 0) { + SHM_LOG_ERROR("Hccp Init RDev failed: " << ret); + return false; + } + + storedRdmaHandle_ = rdmaHandle; + SHM_LOG_INFO("initialize RDev success."); + return true; + } + + bool PrepareOpenDevice(uint32_t device, uint32_t rankCount, in_addr &deviceIp, void *&rdmaHandle) + { + // If can get rdmaHanle, maybe the device has beed opened, can try get rdmaHanle directly. + if (DlHccpApi::RaRdevGetHandle(device, rdmaHandle) == 0) { + if (rdmaHandle != nullptr) { + if (!RetireDeviceIp(device, deviceIp)) { + SHM_LOG_ERROR("RetireDeviceIp failed."); + return false; + } + SHM_LOG_DEBUG("Had prepared device and get rdmaHandle success."); + return true; + } + SHM_LOG_INFO("Had prepared device, but RdmaHadle is null, need init again."); + } + if (!OpenTsd(device, rankCount)) { + SHM_LOG_ERROR("open tsd failed."); + return false; + } + + if (!RaInit(device)) { + SHM_LOG_ERROR("RaInit failed."); + return false; + } + + if (!RetireDeviceIp(device, deviceIp)) { + SHM_LOG_ERROR("RetireDeviceIp failed."); + return false; + } + + if (!RaRdevInit(device, deviceIp, rdmaHandle)) { + SHM_LOG_ERROR("RaRdevInit failed."); + return false; + } + return true; + } + + Result AsyncConnect() + { + return SHMEM_SUCCESS; + } + + Result WaitForConnected(int64_t timeoutNs) + { + if (qpManager_ == nullptr) { + SHM_LOG_ERROR("server side not listen!"); + return SHMEM_INNER_ERROR; + } + + auto ret = qpManager_->WaitingConnectionReady(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait for server side connected on device failed: " << ret); + return ret; + } + + return SHMEM_SUCCESS; + } + + int CheckPrepareOptions(const HybmTransPrepareOptions &options) + { + if (role_ != HYBM_ROLE_PEER) { + SHM_LOG_INFO("transport role: " << role_ << " check options passed."); + return SHMEM_SUCCESS; + } + + if (options.options.size() > rankCount_) { + SHM_LOG_ERROR("options size():" << options.options.size() << " larger than rank count: " << rankCount_); + return SHMEM_INVALID_PARAM; + } + + if (options.options.find(rankId_) == options.options.end()) { + SHM_LOG_ERROR("options not contains self rankId: " << rankId_); + return SHMEM_INVALID_PARAM; + } + + for (auto it = options.options.begin(); it != options.options.end(); ++it) { + if (it->first >= rankCount_) { + SHM_LOG_ERROR("input options of nics contains rankId:" << it->first << ", rank count: " << rankCount_); + return SHMEM_INVALID_PARAM; + } + } + + return SHMEM_SUCCESS; + } + + Result ipPortStringToSockaddr(const std::string& ip_port_str, sockaddr_in& addr) { + std::memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + + size_t colon_pos = ip_port_str.find(':'); + if (colon_pos == std::string::npos || + colon_pos == 0 || + colon_pos == ip_port_str.length() - 1) { + SHM_LOG_ERROR("format mismatch"); + return SHMEM_INNER_ERROR; + } + + std::string ip_str = ip_port_str.substr(0, colon_pos); + std::string port_str = ip_port_str.substr(colon_pos + 1); + + if (port_str.empty()) { + SHM_LOG_ERROR("Port not available!"); + return SHMEM_INNER_ERROR; + } + + for (char c : port_str) { + if (!std::isdigit(static_cast(c))) { + SHM_LOG_ERROR("Port contains non-digit characters!"); + return SHMEM_INNER_ERROR; + } + } + + char* endptr; + unsigned long port = std::strtoul(port_str.c_str(), &endptr, 10); + + if (endptr == port_str.c_str() || *endptr != '\0' || + port == 0 || port > 65535) { + SHM_LOG_ERROR("Port out of range!"); + return SHMEM_INNER_ERROR; + } + + addr.sin_port = htons(static_cast(port)); + + // Transform IP address + if (inet_pton(AF_INET, ip_str.c_str(), &addr.sin_addr) != 1) { + SHM_LOG_ERROR("IP address invalid!"); + return SHMEM_INNER_ERROR; + } + + return SHMEM_SUCCESS; + } + + uint32_t rankId_{0}; + uint32_t rankCount_{1}; + uint32_t deviceId_{0}; + hybm_role_type role_{HYBM_ROLE_PEER}; + in_addr deviceIp_{0}; + uint16_t devicePort_{0}; + void *rdmaHandle_{nullptr}; + void *storedRdmaHandle_{nullptr}; + bool tsdOpened_{0}; + bool raInitialized_{0}; + bool deviceIpRetired_{0}; + DeviceQpManager* qpManager_; + RegMemResult localMR_; + MemoryRegionMap registerMRS_; +}; \ No newline at end of file diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index 2fa76509..fed9e41e 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -7,34 +7,61 @@ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ -#include "shmemi_transport.h" +#include +#include +#include +#include +#include "host/shmem_host_def.h" +#include "rdma/rdma_manager.h" +#include "common/shmemi_host_types.h" +#include "common/shmemi_logger.h" +#include "internal/host_device/shmemi_types.h" -typedef struct { +#ifdef __cplusplus +extern "C" { +#endif +static rdma_manager* manager; -} shmemi_rdmad_transport_state_t; +int shmemi_rdma_init(shmemi_device_host_state_t *state, shmemi_transport_t *t) { + manager = new rdma_manager; + TransportOptions options; + options.rankId = state->mype; + options.rankCount = state->npes; + options.protocol = 7; + options.nic = 10002; + manager->OpenDevice(options); + auto local_device_ip = manager->GetDeviceIP(); + SHM_LOG_INFO("local ip = " << inet_ntoa(local_device_ip)); + std::vector device_ips(state->npes); + g_boot_handle.allgather(&local_device_ip, device_ips.data(), sizeof(in_addr), &g_boot_handle); + g_boot_handle.barrier(&g_boot_handle); + for (int i = 0; i < state->npes; i++) { + SHM_LOG_INFO("get rank " << i << ", device ip = " << inet_ntoa(device_ips[i])); + } -static shmemi_rdmad_transport_state_t shmemi_rdmad_transport_state; + TransportMemoryRegion mr; + mr.addr = reinterpret_cast(state->heap_base); + mr.size = reinterpret_cast(state->heap_size); + manager->RegisterMemoryRegion(mr); + auto local_mr = manager->GetLocalMR(); + SHM_LOG_INFO("local mr = " << local_mr); + std::vector mrs(state->npes); + g_boot_handle.allgather(&local_mr, mrs.data(), sizeof(RegMemResult), &g_boot_handle); + for (int i = 0; i < state->npes; i++) { + SHM_LOG_INFO("get rank " << i << ", mr info = " << mrs[i]); + } -// control plane -int shmemi_rdmad_init(shmemi_host_state_t *state, shmemi_transport_t *t) { - -} - -int shmemi_rdmad_can_access_peer(int *access, shmemi_transport_pe_info_t *peer, shmemi_transport_t *t) { - // true + HybmTransPrepareOptions TransPrepareOp; + for (int i = 0; i < state->npes; i++) { + TransPrepareOp.options[i].nic = std::string(inet_ntoa(device_ips[i])) + ":4647"; + TransPrepareOp.options[i].mr = mrs[i]; + } + manager->Prepare(TransPrepareOp); + manager->Connect(); + state->qp_info = reinterpret_cast(manager->GetQPInfoAddr()); + return 0; } -int shmemi_rdmad_connect_peers(shmemi_transport_t *t, int *selected_dev_ids, int num_selected_devs) { - // 建立QP链接 —— 获取NIC ip,check_peer_access(所有),创建sockets,创建(多)qp并连接 +#ifdef __cplusplus } - -int shmemi_rdmad_finalize(shmemi_transport_t *t) { - -} - -shmemi_transport_t shmemi_rdmad_transport_state = { - .init = shmemi_rdmad_init, - .finalize = shmemi_rdmad_finalize, - .can_access_peer = shmemi_rdmad_can_access_peer, - .connect_peers = shmemi_rdmad_connect_peers, -} \ No newline at end of file +#endif \ No newline at end of file -- Gitee From e93449d3d7ec6b0cde5b70013d22c0be66c97714 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Thu, 23 Oct 2025 21:04:38 +0800 Subject: [PATCH 47/74] Fix rdma_demo. --- examples/CMakeLists.txt | 2 +- examples/rdma_demo/main.cpp | 15 ++++++---- examples/rdma_test/main.cpp | 4 +-- .../low_level/shmem_device_low_level_rma.h | 29 ++++++++++++++++--- include/device/shmem_device_rma.h | 8 ++--- include/internal/host_device/shmemi_types.h | 1 + src/host/init/shmem_init.cpp | 3 +- src/modules/transport/shmemi_rdma.cpp | 1 + 8 files changed, 46 insertions(+), 17 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 83389522..2288c277 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -57,7 +57,7 @@ foreach(EXAMPLE rdma_test # matmul_allreduce # rdma_perftest - # rdma_demo + rdma_demo ) add_subdirectory(${EXAMPLE}) endforeach() \ No newline at end of file diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index 6e0d222a..d519886b 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "acl/acl.h" #include "shmem_api.h" @@ -36,9 +37,7 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint8_t *ptr = (uint8_t*)shmem_malloc(1024); @@ -87,8 +86,13 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size int main(int argc, char *argv[]) { int status = 0; - int n_ranks = atoi(argv[1]); - int rank_id = atoi(argv[2]); + MPI_Init(&argc, &argv); + + // 获取当前进程的编号(rank) + int n_ranks; + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + int rank_id; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); ipport = argv[3]; g_npus = atoi(argv[4]); f_rank = atoi(argv[5]); @@ -96,6 +100,7 @@ int main(int argc, char *argv[]) uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } \ No newline at end of file diff --git a/examples/rdma_test/main.cpp b/examples/rdma_test/main.cpp index d67b8ba0..a176c8f3 100644 --- a/examples/rdma_test/main.cpp +++ b/examples/rdma_test/main.cpp @@ -165,7 +165,7 @@ int test_shmem_rdma_put_poll_cq(int rank_id, int n_ranks, uint64_t local_mem_siz if (rank_id == 1) { aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); for (uint32_t i = 0; i < n_ranks; i++) { - if (xHost[i * messageSize / sizeof(uint32_t)] != i + 10) { + if ((i != rank_id) && (xHost[i * messageSize / sizeof(uint32_t)] != i + 10)) { std::cout << "[ERROR] Put result check error at " << i << std::endl; } } @@ -177,7 +177,7 @@ int test_shmem_rdma_put_poll_cq(int rank_id, int n_ranks, uint64_t local_mem_siz } aclrtFreeHost(xHost); - aclrtFree(dev_ptr); + shmem_free(dev_ptr); status = shmem_finalize(); status = aclrtDestroyStream(stream); diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index 879b8a8c..e1db60ed 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -193,6 +193,27 @@ SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, AscendC::DataCopyPad(dstUb, srcGva, copyParams, padParams); } +/** + * @brief Translate an local symmetric address to remote symmetric address on the specified PE used by RDMA. + * + * @param ptr [in] Symmetric address on local PE. + * @param pe [in] The number of the remote PE. + * @return A remote symmetric address on the specified PE that can be accessed using memory loads and stores. + */ +SHMEM_DEVICE __gm__ void *shmem_roce_ptr(__gm__ void *ptr, int pe) +{ + // Get Global State + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + + // Back to root address + uint64_t offset = reinterpret_cast(ptr) - reinterpret_cast(device_state->heap_base); + + // Address translate + uint64_t remote_ptr = reinterpret_cast(device_state->rdma_heap_base[pe]) + offset; + + return reinterpret_cast<__gm__ void *>(remote_ptr); +} + /** * @brief Asynchronous interface. Copy contiguous data on symmetric memory from the specified PE to address on the local device. * @@ -248,7 +269,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T template SHMEM_DEVICE void shmem_roce_get_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ T* buf, uint32_t elem_size, int pe) { - auto ptr = shmem_ptr(src, pe); + auto ptr = shmem_roce_ptr(src, pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf); @@ -362,7 +383,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G template SHMEM_DEVICE void shmem_roce_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, AscendC::LocalTensor buf, uint32_t elem_size, int pe) { - auto ptr = shmem_ptr((__gm__ void *)src.GetPhyAddr(), pe); + auto ptr = shmem_roce_ptr((__gm__ void *)src.GetPhyAddr(), pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()); @@ -466,7 +487,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T template SHMEM_DEVICE void shmem_roce_put_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ T* buf, uint32_t elem_size, int pe) { - auto ptr = shmem_ptr(dst, pe); + auto ptr = shmem_roce_ptr(dst, pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf); @@ -579,7 +600,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G template SHMEM_DEVICE void shmem_roce_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, AscendC::LocalTensor buf, uint32_t elem_size, int pe, AscendC::TEventID EVENT_ID) { - auto ptr = shmem_ptr((__gm__ void *)dst.GetPhyAddr(), pe); + auto ptr = shmem_roce_ptr((__gm__ void *)dst.GetPhyAddr(), pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()); diff --git a/include/device/shmem_device_rma.h b/include/device/shmem_device_rma.h index cbab2a2d..cb85f12b 100644 --- a/include/device/shmem_device_rma.h +++ b/include/device/shmem_device_rma.h @@ -255,7 +255,7 @@ SHMEM_DEVICE void shmem_getmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr(src, pe); \ + auto ptr = shmem_roce_ptr(src, pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ @@ -330,7 +330,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_DETAILED_NBI); shmem_mte_get_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr((__gm__ void *)src.GetPhyAddr(), pe); \ + auto ptr = shmem_roce_ptr((__gm__ void *)src.GetPhyAddr(), pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ @@ -403,7 +403,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_DETAILED_NBI); copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr(dst, pe); \ + auto ptr = shmem_roce_ptr(dst, pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ @@ -477,7 +477,7 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_DETAILED_NBI); shmem_mte_put_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr((__gm__ void *)dst.GetPhyAddr(), pe); \ + auto ptr = shmem_roce_ptr((__gm__ void *)dst.GetPhyAddr(), pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index cc60482b..cc0389b4 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -77,6 +77,7 @@ typedef struct { int npes; void *heap_base; void *p2p_heap_base[SHMEM_MAX_RANKS]; + void *rdma_heap_base[SHMEM_MAX_RANKS]; void *sdma_heap_base[SHMEM_MAX_RANKS]; uint8_t topo_list[SHMEM_MAX_RANKS]; size_t heap_size; diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index d2686fee..b4a36cf2 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -39,6 +39,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; (DEFAULT_N_PES), /* npes */ \ NULL, /* heap_base */ \ {NULL}, /* p2p_heap_base */ \ + {NULL}, /* rdma_heap_base */ \ {NULL}, /* sdma_heap_base */ \ {}, /* topo_list */ \ SIZE_MAX, /* heap_size */ \ @@ -51,7 +52,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; false, /* shmem_is_shmem_initialized */ \ false, /* shmem_is_shmem_created */ \ {0, 16 * 1024, 0}, /* shmem_mte_config */ \ - 0, /* qp_info */ \ + 0, /* qp_info */ \ } shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index fed9e41e..64b25a57 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -48,6 +48,7 @@ int shmemi_rdma_init(shmemi_device_host_state_t *state, shmemi_transport_t *t) { std::vector mrs(state->npes); g_boot_handle.allgather(&local_mr, mrs.data(), sizeof(RegMemResult), &g_boot_handle); for (int i = 0; i < state->npes; i++) { + state->rdma_heap_base[i] = reinterpret_cast(mrs[i].address); SHM_LOG_INFO("get rank " << i << ", mr info = " << mrs[i]); } -- Gitee From 21833f9427251fa6cb067ba42c5d3c7435c97e6e Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 28 Oct 2025 20:20:05 +0800 Subject: [PATCH 48/74] RDMA transport refactor. --- examples/rdma_perftest/main.cpp | 26 ++++----- examples/rdma_test/main.cpp | 1 - src/host/transport/shmemi_transport.cpp | 66 ++++++++++------------- src/modules/transport/rdma/rdma_manager.h | 28 +++++++--- src/modules/transport/shmemi_rdma.cpp | 45 +++++++++++----- 5 files changed, 94 insertions(+), 72 deletions(-) diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp index d29537af..138df4d4 100644 --- a/examples/rdma_perftest/main.cpp +++ b/examples/rdma_perftest/main.cpp @@ -44,9 +44,7 @@ int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uin shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 6); @@ -91,9 +89,7 @@ int test_shmem_rdma_postsend_cost(int rank_id, int n_ranks, uint64_t local_mem_s shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 6); @@ -136,9 +132,7 @@ int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_me shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 6); @@ -179,9 +173,7 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); shmem_mte_set_ub_params(0, 128 * 1024, 0); uint64_t fftsConfig = shmemx_get_ffts_config(); @@ -234,12 +226,20 @@ int main(int argc, char *argv[]) } int status = 0; int n_ranks = atoi(argv[1]); + MPI_Init(&argc, &argv); + + // 获取当前进程的编号(rank) + int n_ranks; + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + int rank_id; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); if (n_ranks != 2) { std::cout << "[ERROR] Error number of ranks! Only support 2 ranks!" << std::endl; + return -1; } - int rank_id = atoi(argv[2]); if (rank_id >= 2) { std::cout << "[ERROR] Error rank ID! Only support 2 ranks!" << std::endl; + return -1; } ipport = argv[3]; g_npus = atoi(argv[4]); diff --git a/examples/rdma_test/main.cpp b/examples/rdma_test/main.cpp index a176c8f3..9467bd78 100644 --- a/examples/rdma_test/main.cpp +++ b/examples/rdma_test/main.cpp @@ -143,7 +143,6 @@ int test_shmem_rdma_put_poll_cq(int rank_id, int n_ranks, uint64_t local_mem_siz status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); void* dev_ptr = shmem_malloc(1024); - std::cout << "gva address = " << dev_ptr << std::endl; uint32_t *xHost; size_t messageSize = 64; size_t totalSize = messageSize * n_ranks; diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index ef3af22a..d542aeac 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -13,31 +13,10 @@ #include "transport/shmemi_transport.h" static void *transport_mte_lib = NULL; -static void *rdma_plugin_hdl = nullptr; -static char *rdma_plugin_name = nullptr; - -int (*shmemi_rdma_init)(shmemi_device_host_state_t *state, shmemi_transport_t *t); +static void *transport_rdma_lib = NULL; uint64_t *host_hash_list; -void shmemi_transport_load() -{ - dlerror(); - if (rdma_plugin_hdl == nullptr) { - - rdma_plugin_hdl = dlopen(rdma_plugin_name, RTLD_NOW); - } - dlerror(); -} - -void shmemi_transport_unload() -{ - if (rdma_plugin_hdl != nullptr) { - dlclose(rdma_plugin_hdl); - rdma_plugin_hdl = nullptr; - } -} - shmemi_host_state_t g_host_state; int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { @@ -50,12 +29,15 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_mte.so" << ", err is: " << stderr); return SHMEM_INVALID_VALUE; } - rdma_plugin_name = TRANSPORT_MODULE_RDMA; - shmemi_transport_load(); + transport_rdma_lib = dlopen("shmem_transport_rdma.so", RTLD_NOW); + if (!transport_rdma_lib) { + SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_rdma.so" << ", err is: " << stderr); + return SHMEM_INVALID_VALUE; + } - transport_init_func init_fn; - init_fn = (transport_init_func)dlsym(transport_mte_lib, "shmemi_mte_init"); - if (!init_fn) { + transport_init_func init_mte_fn; + init_mte_fn = (transport_init_func)dlsym(transport_mte_lib, "shmemi_mte_init"); + if (!init_mte_fn) { dlclose(transport_mte_lib); transport_mte_lib = NULL; SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_mte.so" << "."); @@ -74,21 +56,17 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { // AllGather All pe's host info g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); - SHMEM_CHECK_RET(init_fn(&g_host_state.choosen_transports[0], &g_state)); + SHMEM_CHECK_RET(init_mte_fn(&g_host_state.choosen_transports[0], &g_state)); - if (!rdma_plugin_hdl) { - SHM_LOG_ERROR("Bootstrap unable to load " << rdma_plugin_name << ", err is: " << stderr); - shmemi_transport_unload(); + transport_init_func init_rdma_fn; + init_rdma_fn = (transport_init_func)dlsym(transport_rdma_lib, "shmemi_rdma_init"); + if (!init_rdma_fn) { + dlclose(transport_rdma_lib); + transport_rdma_lib = NULL; + SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_rdma.so" << "."); return SHMEM_INVALID_VALUE; } - - *((void **)&shmemi_rdma_init) = dlsym(rdma_plugin_hdl, "shmemi_rdma_init"); - if (!shmemi_rdma_init) { - SHM_LOG_ERROR("Bootstrap plugin init func dlsym failed"); - shmemi_transport_unload(); - return SHMEM_INNER_ERROR; - } - SHMEM_CHECK_RET(shmemi_rdma_init(&g_state, nullptr)); + SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[1], &g_state)); return SHMEM_SUCCESS; } @@ -141,6 +119,8 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) } } + t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); + t = g_host_state.choosen_transports[1]; t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); return 0; @@ -156,5 +136,13 @@ int32_t shmemi_transport_finalize() { dlclose(transport_mte_lib); transport_mte_lib = NULL; } + + t = g_host_state.choosen_transports[1]; + t.finalize(&t, &g_state); + + if (transport_rdma_lib != NULL) { + dlclose(transport_rdma_lib); + transport_rdma_lib = NULL; + } return 0; } diff --git a/src/modules/transport/rdma/rdma_manager.h b/src/modules/transport/rdma/rdma_manager.h index 11d4783e..a94b027c 100644 --- a/src/modules/transport/rdma/rdma_manager.h +++ b/src/modules/transport/rdma/rdma_manager.h @@ -7,11 +7,30 @@ class rdma_manager { public: rdma_manager() {} + + ~rdma_manager() { + delete qpManager_; + ClearAllRegisterMRs(); + tsdOpened_ = false; + raInitialized_ = false; + deviceIpRetired_ = false; + storedRdmaHandle_ = nullptr; + } + + void ClearAllRegisterMRs() + { + for (auto it = registerMRS_.begin(); it != registerMRS_.end(); ++it) { + auto ret = DlHccpApi::RaDeregisterMR(rdmaHandle_, it->second.mrHandle); + if (ret != 0) { + SHM_LOG_ERROR("Unregister:" << (void *)(ptrdiff_t)it->first << " : " << it->second << " failed: " << ret); + } + } + registerMRS_.clear(); + } + int OpenDevice(const TransportOptions &options) { int32_t deviceId = -1; - - // SHM_LOG_DEBUG("begin to open device with " << options); auto ret = aclrtGetDevice(&deviceId); if (ret != 0 || deviceId < 0) { SHM_LOG_ERROR("AclrtGetDevice() return=" << ret << ", output deviceId=" << deviceId); @@ -93,11 +112,6 @@ public: } registerMRS_.erase(pos); - // ret = qpManager_->SetLocalMemories(registerMRS_); - // if (ret != SHMEM_SUCCESS) { - // SHM_LOG_ERROR("qp manager set mr failed: " << ret); - // return ret; - // } return 0; } diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index 64b25a57..c4057eba 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -22,14 +22,15 @@ extern "C" { #endif static rdma_manager* manager; -int shmemi_rdma_init(shmemi_device_host_state_t *state, shmemi_transport_t *t) { - manager = new rdma_manager; - TransportOptions options; - options.rankId = state->mype; - options.rankCount = state->npes; - options.protocol = 7; - options.nic = 10002; - manager->OpenDevice(options); +int shmemi_rdma_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport *t, shmemi_device_host_state_t *state) { + if (peer_info->pe == state->mype) { + *access = 0; + } else { + *access = 1; + } +} + +int shmemi_rdma_connect_peers(shmemi_transport *t, int *selected_dev_ids, int num_selected_devs, shmemi_device_host_state_t *state) { auto local_device_ip = manager->GetDeviceIP(); SHM_LOG_INFO("local ip = " << inet_ntoa(local_device_ip)); std::vector device_ips(state->npes); @@ -39,10 +40,6 @@ int shmemi_rdma_init(shmemi_device_host_state_t *state, shmemi_transport_t *t) { SHM_LOG_INFO("get rank " << i << ", device ip = " << inet_ntoa(device_ips[i])); } - TransportMemoryRegion mr; - mr.addr = reinterpret_cast(state->heap_base); - mr.size = reinterpret_cast(state->heap_size); - manager->RegisterMemoryRegion(mr); auto local_mr = manager->GetLocalMR(); SHM_LOG_INFO("local mr = " << local_mr); std::vector mrs(state->npes); @@ -63,6 +60,30 @@ int shmemi_rdma_init(shmemi_device_host_state_t *state, shmemi_transport_t *t) { return 0; } +int shmemi_rdma_finalize(shmemi_transport *t, shmemi_device_host_state_t *state) { + delete manager; + return 0; +} + +int shmemi_rdma_init(shmemi_transport *t, shmemi_device_host_state_t *state) { + manager = new rdma_manager; + TransportOptions options; + options.rankId = state->mype; + options.rankCount = state->npes; + options.protocol = 7; + options.nic = 10002; + manager->OpenDevice(options); + + TransportMemoryRegion mr; + mr.addr = reinterpret_cast(state->heap_base); + mr.size = reinterpret_cast(state->heap_size); + manager->RegisterMemoryRegion(mr); + t->can_access_peer = shmemi_rdma_can_access_peer; + t->connect_peers = shmemi_rdma_connect_peers; + t->finalize = shmemi_rdma_finalize; + return 0; +} + #ifdef __cplusplus } #endif \ No newline at end of file -- Gitee From e809931a9cd3c9a4fb2f35bc87e76efea2ddc881 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Mon, 3 Nov 2025 15:03:00 +0800 Subject: [PATCH 49/74] Support ASCEND_RT_VISIBLE_DEVICES. Fix RDMA examples. --- examples/CMakeLists.txt | 2 +- examples/rdma_demo/README.md | 8 ++-- examples/rdma_demo/main.cpp | 8 ++-- examples/rdma_perftest/README.md | 7 +-- examples/rdma_perftest/main.cpp | 17 +++---- .../rdma_perftest/rdma_perftest_kernel.cpp | 48 ++++++++++--------- src/modules/transport/rdma/rdma_manager.h | 43 ++++++++++++++--- 7 files changed, 82 insertions(+), 51 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 2288c277..49b00f6e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -56,7 +56,7 @@ foreach(EXAMPLE allgather rdma_test # matmul_allreduce - # rdma_perftest + rdma_perftest rdma_demo ) add_subdirectory(${EXAMPLE}) diff --git a/examples/rdma_demo/README.md b/examples/rdma_demo/README.md index 410ef0e1..bf585a7f 100644 --- a/examples/rdma_demo/README.md +++ b/examples/rdma_demo/README.md @@ -7,15 +7,13 @@ bash scripts/build.sh ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH -./build/bin/rdma_demo 2 0 tcp://127.0.0.1:8765 2 0 0 # rank 0 -./build/bin/rdma_demo 2 1 tcp://127.0.0.1:8765 2 0 0 # rank 1 +mpirun -np 2 ./build/bin/rdma_demo tcp://127.0.0.1:8765 2 0 0 ``` 3.命令行参数说明 - ./rdma_demo + mpirun -np ./rdma_demo -- n_ranks: 全局Rank数量,只支持2个Rank。 -- rank_id: 当前进程的Rank号。 +- n_ranks: 全局Rank数量。 - ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。如果执行跨机测试,需要讲IP设为rank0所在Host的IP。 - g_npus: 当前卡上启动的NPU数量。 - f_rank: 当前卡上使用的第一个Rank号。 diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index d519886b..fd14fd7f 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -93,10 +93,10 @@ int main(int argc, char *argv[]) MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); int rank_id; MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); - ipport = argv[3]; - g_npus = atoi(argv[4]); - f_rank = atoi(argv[5]); - f_npu = atoi(argv[6]); + ipport = argv[1]; + g_npus = atoi(argv[2]); + f_rank = atoi(argv[3]); + f_npu = atoi(argv[4]); uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; diff --git a/examples/rdma_perftest/README.md b/examples/rdma_perftest/README.md index 1c73d65e..cdf91185 100644 --- a/examples/rdma_perftest/README.md +++ b/examples/rdma_perftest/README.md @@ -7,15 +7,12 @@ bash scripts/build.sh ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH -./build/bin/rdma_perftest 2 0 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 # rank 0 -./build/bin/rdma_perftest 2 1 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 # rank 1 +mpirun -np 2 ./build/bin/rdma_perftest tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 ``` 3.命令行参数说明 - ./rdma_perftest + mpirun -np 2 ./rdma_perftest -- n_ranks: 全局Rank数量,只支持2个Rank。 -- rank_id: 当前进程的Rank号。 - ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。如果执行跨机测试,需要讲IP设为rank0所在Host的IP。 - g_npus: 当前卡上启动的NPU数量。 - f_rank: 当前卡上使用的第一个Rank号。 diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp index 138df4d4..446b22b6 100644 --- a/examples/rdma_perftest/main.cpp +++ b/examples/rdma_perftest/main.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include "acl/acl.h" #include "shmem_api.h" @@ -220,12 +221,11 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size int main(int argc, char *argv[]) { - if (argc != 9) { + if (argc != 7) { std::cout << "[ERROR] Paramater number mismatch." << std::endl; std::cout << "[USAGE] ./rdma_perftest . See README for more details." << std::endl; } int status = 0; - int n_ranks = atoi(argv[1]); MPI_Init(&argc, &argv); // 获取当前进程的编号(rank) @@ -241,12 +241,12 @@ int main(int argc, char *argv[]) std::cout << "[ERROR] Error rank ID! Only support 2 ranks!" << std::endl; return -1; } - ipport = argv[3]; - g_npus = atoi(argv[4]); - f_rank = atoi(argv[5]); - f_npu = atoi(argv[6]); - test_type = argv[7]; - int msg_len = atoi(argv[8]); + ipport = argv[1]; + g_npus = atoi(argv[2]); + f_rank = atoi(argv[3]); + f_npu = atoi(argv[4]); + test_type = argv[5]; + int msg_len = atoi(argv[6]); uint64_t local_mem_size = 1024UL * 1024UL * 64; if (std::string(test_type) == "highlevel_put_pingpong_latency") { test_shmem_rdma_highlevel_put_pingpong_latency(rank_id, n_ranks, local_mem_size, msg_len); @@ -259,6 +259,7 @@ int main(int argc, char *argv[]) } std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } \ No newline at end of file diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp index 733ba7dd..33ad60cd 100644 --- a/examples/rdma_perftest/rdma_perftest_kernel.cpp +++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp @@ -24,8 +24,9 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64 pipe.InitBuffer(buf, UB_ALIGN_SIZE); AscendC::LocalTensor ubLocalRead = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + int64_t rank = device_state->mype; + int64_t rank_size = device_state->npes; uint32_t peer; // Warm up @@ -34,13 +35,13 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64 peer = 1; shmem_put_uint8_mem_nbi(warm_addr, warm_addr, WARMUP_MESSAGE_LENGTH, peer); while (*(__gm__ uint32_t*)(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1)) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); + dcci_cachelines(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); AscendC::GetSystemCycle(); } } else { peer = 0; while (*(__gm__ uint32_t*)(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1)) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); + dcci_cachelines(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -55,7 +56,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64 int64_t start = AscendC::GetSystemCycle(); shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); while (*(__gm__ uint32_t*)(gva + message_length * 2 - 8) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + message_length * 2 - 8, 8); + dcci_cachelines(gva + message_length * 2 - 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -64,7 +65,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64 } else { peer = 0; while (*(__gm__ uint32_t*)(gva + message_length * 1 - 8) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + message_length * 1 - 8, 8); + dcci_cachelines(gva + message_length * 1 - 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -87,8 +88,9 @@ extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + int64_t rank = device_state->mype; + int64_t rank_size = device_state->npes; uint32_t peer; // Actual test @@ -96,7 +98,7 @@ extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM if (rank == 0) { peer = 1; - GM_ADDR dest_addr = (GM_ADDR)(shmem_ptr(src_addr, peer)); + GM_ADDR dest_addr = (GM_ADDR)(shmem_roce_ptr(src_addr, peer)); int64_t start = AscendC::GetSystemCycle(); for (uint32_t i = 0; i < 500; i++) { shmemi_roce_write(dest_addr, src_addr, peer, 0, message_length, ubLocal64, ubLocal32); @@ -122,8 +124,9 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + int64_t rank = device_state->mype; + int64_t rank_size = device_state->npes; uint32_t peer; // Actual test @@ -137,7 +140,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32); shmem_put_uint8_mem_nbi(gva + rank_size * message_length + 8, src_addr, sizeof(uint32_t), peer); while (*(__gm__ uint32_t*)(gva + message_length * rank_size + 16) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + message_length * rank_size + 16, 8); + dcci_cachelines(gva + message_length * rank_size + 16, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -146,7 +149,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, } else { peer = 0; while (*(__gm__ uint32_t*)(gva + rank_size * message_length + 8) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + rank_size * message_length + 8, 8); + dcci_cachelines(gva + rank_size * message_length + 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -169,8 +172,9 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD ubLocal64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE); ubLocal64.address_.dataLen = UB_ALIGN_SIZE; - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + int64_t rank = device_state->mype; + int64_t rank_size = device_state->npes; uint32_t peer; // Core 0, RDMA @@ -180,12 +184,12 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD peer = 1; int64_t start = AscendC::GetSystemCycle(); for (int i = 0; i < 10000; i++) { - shmemi_roce_write((GM_ADDR)shmem_ptr(src_addr, peer), src_addr, peer, 0, message_length, ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_roce_ptr(src_addr, peer), src_addr, peer, 0, message_length, ubLocal64, ubLocal32); } shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32); - shmemi_roce_write((GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_roce_ptr(gva + rank_size * message_length * 2 + 8, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + message_length * rank_size * 2 + 16, 8); + dcci_cachelines(gva + message_length * rank_size * 2 + 16, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -194,11 +198,11 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD } else { peer = 0; while (*(__gm__ int64_t*)(gva + rank_size * message_length * 2 + 8) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + rank_size * message_length * 2 + 8, 8); + dcci_cachelines(gva + rank_size * message_length * 2 + 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); - shmemi_roce_write((GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 16, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_roce_ptr(gva + rank_size * message_length * 2 + 16, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); } } else { // core 1, MTE GM_ADDR src_addr = gva + (rank + rank_size) * message_length; @@ -216,7 +220,7 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD AscendC::PipeBarrier(); shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 24, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); while (*(__gm__ uint32_t*)(gva + message_length * rank_size * 2 + 32) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + message_length * rank_size * 2 + 32, 8); + dcci_cachelines(gva + message_length * rank_size * 2 + 32, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -225,7 +229,7 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD } else { peer = 0; while (*(__gm__ uint32_t*)(gva + rank_size * message_length * 2 + 24) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + rank_size * message_length * 2 + 24, 8); + dcci_cachelines(gva + rank_size * message_length * 2 + 24, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); diff --git a/src/modules/transport/rdma/rdma_manager.h b/src/modules/transport/rdma/rdma_manager.h index a94b027c..6c2ccfa2 100644 --- a/src/modules/transport/rdma/rdma_manager.h +++ b/src/modules/transport/rdma/rdma_manager.h @@ -1,14 +1,22 @@ #include #include +#include #include "dl_hccp_api.h" #include "acl/acl.h" #include "device_qp_manager.h" +const char *g_rt_lib_name = "libascendcl.so"; +int (*rtGetLogicDevIdByUserDevIdFunc)(const int32_t, int32_t *const); + class rdma_manager { public: rdma_manager() {} ~rdma_manager() { + if (rt_handle_) { + dlclose(rt_handle_); + rt_handle_ = nullptr; + } delete qpManager_; ClearAllRegisterMRs(); tsdOpened_ = false; @@ -31,12 +39,15 @@ public: int OpenDevice(const TransportOptions &options) { int32_t deviceId = -1; + OpenRTLib(); auto ret = aclrtGetDevice(&deviceId); if (ret != 0 || deviceId < 0) { SHM_LOG_ERROR("AclrtGetDevice() return=" << ret << ", output deviceId=" << deviceId); return -1; } - deviceId_ = static_cast(deviceId); + int32_t logicDeviceId = -1; + SHMEM_CHECK_RET(rtGetLogicDevIdByUserDevIdFunc(deviceId, &logicDeviceId)); + deviceId_ = static_cast(logicDeviceId); rankId_ = options.rankId; rankCount_ = options.rankCount; role_ = options.role; @@ -47,7 +58,7 @@ public: devicePort_ = static_cast(port); DlHccpApi::LoadLibrary(); - if (!PrepareOpenDevice(deviceId_, rankCount_, deviceIp_, rdmaHandle_)) { + if (!PrepareOpenDevice(deviceId, rankCount_, deviceIp_, rdmaHandle_, deviceId_)) { SHM_LOG_ERROR("PrepareOpenDevice failed."); return -1; } @@ -171,6 +182,25 @@ public: return SHMEM_SUCCESS; } private: + bool OpenRTLib() { + if (rt_handle_) { + return true; + } + rt_handle_ = dlopen(g_rt_lib_name, RTLD_NOW); + if (!rt_handle_) { + SHM_LOG_ERROR("dlopen failed: " << dlerror()); + return false; + } + + *((void**)&rtGetLogicDevIdByUserDevIdFunc) = dlsym(rt_handle_, "rtGetLogicDevIdByUserDevId"); + if (!rtGetLogicDevIdByUserDevIdFunc) { + dlclose(rt_handle_); + rt_handle_ = nullptr; + SHM_LOG_ERROR("Unable to get info from " << "libascendcl.so" << "."); + return SHMEM_INVALID_VALUE; + } + return true; + } bool OpenTsd(uint32_t deviceId, uint32_t rankCount) { if (tsdOpened_) { @@ -285,7 +315,7 @@ private: return true; } - bool PrepareOpenDevice(uint32_t device, uint32_t rankCount, in_addr &deviceIp, void *&rdmaHandle) + bool PrepareOpenDevice(uint32_t device, uint32_t rankCount, in_addr &deviceIp, void *&rdmaHandle, uint32_t logicDeviceId) { // If can get rdmaHanle, maybe the device has beed opened, can try get rdmaHanle directly. if (DlHccpApi::RaRdevGetHandle(device, rdmaHandle) == 0) { @@ -304,17 +334,17 @@ private: return false; } - if (!RaInit(device)) { + if (!RaInit(logicDeviceId)) { SHM_LOG_ERROR("RaInit failed."); return false; } - if (!RetireDeviceIp(device, deviceIp)) { + if (!RetireDeviceIp(logicDeviceId, deviceIp)) { SHM_LOG_ERROR("RetireDeviceIp failed."); return false; } - if (!RaRdevInit(device, deviceIp, rdmaHandle)) { + if (!RaRdevInit(logicDeviceId, deviceIp, rdmaHandle)) { SHM_LOG_ERROR("RaRdevInit failed."); return false; } @@ -430,4 +460,5 @@ private: DeviceQpManager* qpManager_; RegMemResult localMR_; MemoryRegionMap registerMRS_; + void* rt_handle_{nullptr}; }; \ No newline at end of file -- Gitee From 61eab436742bbf0d805e5b8a3511404b0622332b Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 4 Nov 2025 15:04:56 +0800 Subject: [PATCH 50/74] Fix highlevel test. --- .../rdma_perftest/rdma_perftest_kernel.cpp | 20 ++++++++----------- src/host/mem/shmemi_heap.cpp | 2 +- src/host/transport/shmemi_transport.cpp | 8 ++++++-- .../transport/rdma/device_qp_manager.cpp | 11 ++++++---- .../transport/rdma/device_qp_manager.h | 6 +++--- src/modules/transport/shmemi_rdma.cpp | 1 + 6 files changed, 26 insertions(+), 22 deletions(-) diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp index 33ad60cd..c3eaab96 100644 --- a/examples/rdma_perftest/rdma_perftest_kernel.cpp +++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp @@ -24,9 +24,8 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64 pipe.InitBuffer(buf, UB_ALIGN_SIZE); AscendC::LocalTensor ubLocalRead = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - int64_t rank = device_state->mype; - int64_t rank_size = device_state->npes; + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Warm up @@ -88,9 +87,8 @@ extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - int64_t rank = device_state->mype; - int64_t rank_size = device_state->npes; + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Actual test @@ -124,9 +122,8 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - int64_t rank = device_state->mype; - int64_t rank_size = device_state->npes; + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Actual test @@ -172,9 +169,8 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD ubLocal64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE); ubLocal64.address_.dataLen = UB_ALIGN_SIZE; - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - int64_t rank = device_state->mype; - int64_t rank_size = device_state->npes; + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Core 0, RDMA diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index d8f955b4..f45871cb 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -110,7 +110,7 @@ int shmem_symmetric_heap::setup_heap() if (i == mype) continue; - if (g_host_state.transport_map[local_offset + i] == 1) { + if (g_host_state.transport_map[local_offset + i] & 1) { SHMEM_CHECK_RET(aclrtReserveMemAddress(&(peer_heap_base_p2p_[i]), alloc_size, 0, nullptr, 1)); } } diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index d542aeac..841066e5 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -20,7 +20,7 @@ uint64_t *host_hash_list; shmemi_host_state_t g_host_state; int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { - g_host_state.num_choosen_transport = 1; // now only support mte; + g_host_state.num_choosen_transport = 2; g_host_state.transport_map = (int *)calloc(g_state.npes * g_state.npes, sizeof(int)); g_host_state.pe_info = (shmemi_transport_pe_info *)calloc(g_state.npes, sizeof(shmemi_transport_pe_info)); @@ -93,6 +93,10 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { } } + for (int i = 0; i < g_state.npes; i++) { + g_state.topo_list[i] = reinterpret_cast(local_map[i]); + } + g_boot_handle.allgather(local_map, g_host_state.transport_map, g_state.npes * sizeof(int), &g_boot_handle); if (local_map) free(local_map); @@ -112,7 +116,7 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) for (int i = 0; i < g_state.npes; i++) { if (i == g_state.mype) continue; - if (g_host_state.transport_map[local_offset + i] == 1) { + if (g_host_state.transport_map[local_offset + i] & 1) { shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); mte_peer_list[mte_peer_num] = peer_info->dev_id; ++mte_peer_num; diff --git a/src/modules/transport/rdma/device_qp_manager.cpp b/src/modules/transport/rdma/device_qp_manager.cpp index d080c714..df2330b7 100644 --- a/src/modules/transport/rdma/device_qp_manager.cpp +++ b/src/modules/transport/rdma/device_qp_manager.cpp @@ -99,7 +99,10 @@ static constexpr uint32_t SEND_CQ_DEPTH = 8192; static constexpr uint32_t RECV_CQ_DEPTH = 128; static constexpr uint32_t MAX_SEND_WR = 8192; static constexpr uint32_t MAX_RECV_WR = 128; +static constexpr uint32_t MAX_SEND_SGE = 1; +static constexpr uint32_t MAX_RECV_SGE = 1; static constexpr uint32_t QP_MODE = 2; +static constexpr uint32_t CALLER_POLL_CQ_CSTM = 1; DeviceQpManager::~DeviceQpManager() noexcept { @@ -492,12 +495,12 @@ int DeviceQpManager::CreateOneQp(ConnQpType qpType, ConnectionChannel &channel) attr.version = 1; attr.cqAttr.sendCqDepth = SEND_CQ_DEPTH; attr.cqAttr.recvDqDepth = RECV_CQ_DEPTH; - attr.qp_attr.cap.max_recv_sge = 1; + attr.qp_attr.cap.max_send_wr = MAX_SEND_WR; + attr.qp_attr.cap.max_send_sge = MAX_SEND_SGE; attr.qp_attr.cap.max_recv_wr = MAX_RECV_WR; - attr.qp_attr.cap.max_recv_sge = 1; + attr.qp_attr.cap.max_recv_sge = MAX_RECV_SGE; attr.qp_attr.qp_type = IBV_QPT_RC; - attr.qp_attr.cap.max_send_wr = MAX_SEND_WR; - attr.data_plane_flag.bs.cq_cstm = 1; + attr.data_plane_flag.bs.cq_cstm = CALLER_POLL_CQ_CSTM; ret = DlHccpApi::RaQpAiCreate(rdmaHandle_, attr, channel.aiQpInfo, channel.qpHandles[qpType]); } else { ret = DlHccpApi::RaQpCreate(rdmaHandle_, 0, QP_MODE, channel.qpHandles[qpType]); diff --git a/src/modules/transport/rdma/device_qp_manager.h b/src/modules/transport/rdma/device_qp_manager.h index 1aacb566..52df0e0e 100644 --- a/src/modules/transport/rdma/device_qp_manager.h +++ b/src/modules/transport/rdma/device_qp_manager.h @@ -8,8 +8,8 @@ * See LICENSE in the root of the software repository for the full text of the License. */ -#ifndef MF_HYBRID_DEVICE_QP_MANAGER_H -#define MF_HYBRID_DEVICE_QP_MANAGER_H +#ifndef DEVICE_QP_MANAGER_H +#define DEVICE_QP_MANAGER_H #include #include @@ -90,4 +90,4 @@ private: std::unordered_map serverConnections_; }; -#endif // MF_HYBRID_DEVICE_QP_MANAGER_H \ No newline at end of file +#endif // DEVICE_QP_MANAGER_H \ No newline at end of file diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index c4057eba..36dddb6e 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -28,6 +28,7 @@ int shmemi_rdma_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_in } else { *access = 1; } + return 0; } int shmemi_rdma_connect_peers(shmemi_transport *t, int *selected_dev_ids, int num_selected_devs, shmemi_device_host_state_t *state) { -- Gitee From 6ae922da6e8caec003cce263e8a3fdce915ac14e Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 4 Nov 2025 19:53:57 +0800 Subject: [PATCH 51/74] Runtime library load refactor. --- src/host/common/shmemi_host_types.h | 2 + src/host/mem/shmemi_global_state.cpp | 38 ++------------- src/host/shmemi_host_common.h | 1 + src/host/transport/shmemi_transport.cpp | 9 +++- src/host/utils.h | 59 +++++++++++++++++++++++ src/modules/transport/rdma/dl_hccp_def.h | 2 + src/modules/transport/rdma/rdma_manager.h | 35 +------------- src/modules/transport/shmemi_rdma.cpp | 2 + 8 files changed, 81 insertions(+), 67 deletions(-) create mode 100644 src/host/utils.h diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index 5039a326..38d22869 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -68,6 +68,8 @@ typedef struct shmemi_transport { void (*amo)(struct shmemi_transport *t, int32_t type, void *dst, void *src, size_t size, int32_t pe); void (*quiet)(struct shmemi_transport *t); void (*fence)(struct shmemi_transport *t); + int32_t logical_dev_id; + int32_t dev_id; } shmemi_transport_t; typedef struct { diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp index cc350e75..34f24f05 100644 --- a/src/host/mem/shmemi_global_state.cpp +++ b/src/host/mem/shmemi_global_state.cpp @@ -13,6 +13,7 @@ #include "host/shmem_host_def.h" #include "common/shmemi_host_types.h" #include "common/shmemi_logger.h" +#include "utils.h" #define LOAD_SYM(TARGET_FUNC, FILE_HANDLE, SYMBOL_NAME) \ dlerror(); \ @@ -36,12 +37,6 @@ int (*halMemReleaseFunc)(drv_mem_handle_t *handle); int (*halMemMapFunc)(void *ptr, size_t size, size_t offset, drv_mem_handle_t *handle, uint64_t flag); int (*halMemUnmapFunc)(void *ptr); -bool g_rt_loaded = false; -static void *rt_handle; -const char *g_rt_lib_name = "libascendcl.so"; - -int (*rtGetLogicDevIdByUserDevIdFunc)(const int32_t, int32_t *const); - int32_t load_hal_library() { char *error; @@ -69,37 +64,17 @@ int32_t load_hal_library() return 0; } -int32_t load_rt_library() -{ - char *error; - std::lock_guard guard(g_mutex); - if (g_rt_loaded) { - return 0; - } - - dlerror(); - - rt_handle = dlopen(g_rt_lib_name, RTLD_NOW); - if (!rt_handle) { - fprintf(stderr, "dlopen failed: %s\n", dlerror()); - return 1; - } - - LOAD_SYM(rtGetLogicDevIdByUserDevIdFunc, rt_handle, "rtGetLogicDevIdByUserDevId"); - - g_rt_loaded = true; - return 0; -} - global_state_reigister::global_state_reigister(int device_id): device_id_{device_id} { SHMEM_CHECK(load_hal_library()); - SHMEM_CHECK(load_rt_library()); SHMEM_CHECK(halMemAddressReserveFunc(&device_ptr_, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1)); int32_t logicDeviceId = -1; - SHMEM_CHECK(rtGetLogicDevIdByUserDevIdFunc(device_id_, &logicDeviceId)); + rtLibLoader& loader = rtLibLoader::getInstance(); + if (loader.isLoaded()) { + loader.getLogicDevId(device_id_, &logicDeviceId); + } drv_mem_prop memprop; memprop.side = 1; @@ -127,9 +102,6 @@ global_state_reigister::~global_state_reigister() if (hal_handle != nullptr) dlclose(hal_handle); - - if (rt_handle != nullptr) - dlclose(rt_handle); } void *global_state_reigister::get_ptr() diff --git a/src/host/shmemi_host_common.h b/src/host/shmemi_host_common.h index ed872c2f..2a6264c4 100644 --- a/src/host/shmemi_host_common.h +++ b/src/host/shmemi_host_common.h @@ -21,5 +21,6 @@ #include "sync/shmemi_sync.h" #include "bootstrap/shmemi_bootstrap.h" #include "transport/shmemi_transport.h" +#include "utils.h" #endif // SHMEM_SHMEMI_HOST_COMMON_H diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index 841066e5..e853c09f 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -52,6 +52,13 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { int32_t device_id; SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); my_info.dev_id = device_id; + int32_t logicDeviceId = -1; + rtLibLoader& loader = rtLibLoader::getInstance(); + if (loader.isLoaded()) { + loader.getLogicDevId(device_id, &logicDeviceId); + } + g_host_state.choosen_transports[1].logical_dev_id = logicDeviceId; + g_host_state.choosen_transports[1].dev_id = device_id; // AllGather All pe's host info g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); @@ -94,7 +101,7 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { } for (int i = 0; i < g_state.npes; i++) { - g_state.topo_list[i] = reinterpret_cast(local_map[i]); + g_state.topo_list[i] = static_cast(local_map[i]); } g_boot_handle.allgather(local_map, g_host_state.transport_map, g_state.npes * sizeof(int), &g_boot_handle); diff --git a/src/host/utils.h b/src/host/utils.h new file mode 100644 index 00000000..9ada47ab --- /dev/null +++ b/src/host/utils.h @@ -0,0 +1,59 @@ +#include +#include +#include "common/shmemi_logger.h" + +class rtLibLoader { +private: + void* rt_handle_; + int (*rtGetLogicDevIdByUserDevIdFunc_)(const int32_t, int32_t *const); + + rtLibLoader() : rt_handle_(nullptr) { + if (!loadLibrary()) { + SHM_LOG_ERROR("Failed to initialize rtLibLoader: could not load liba.so or foo function"); + } + } + + rtLibLoader(const rtLibLoader&) = delete; + rtLibLoader& operator=(const rtLibLoader&) = delete; + + bool loadLibrary() { + rt_handle_ = dlopen("libascendcl.so", RTLD_NOW); + if (!rt_handle_) { + SHM_LOG_ERROR("dlopen failed: " << dlerror()); + return false; + } + + *((void**)&rtGetLogicDevIdByUserDevIdFunc_) = dlsym(rt_handle_, "rtGetLogicDevIdByUserDevId"); + if (!rtGetLogicDevIdByUserDevIdFunc_) { + dlclose(rt_handle_); + rt_handle_ = nullptr; + SHM_LOG_ERROR("Unable to get info from " << "libascendcl.so" << "."); + return SHMEM_INVALID_VALUE; + } + return true; + } + +public: + static rtLibLoader& getInstance() { + static rtLibLoader instance; + return instance; + } + + void getLogicDevId(const int32_t userDeviceId, int32_t *const logicDeviceId) { + if (rtGetLogicDevIdByUserDevIdFunc_) { + rtGetLogicDevIdByUserDevIdFunc_(userDeviceId, logicDeviceId); + } else { + SHM_LOG_ERROR("rtGetLogicDevIdByUserDevIdFunc function is not available"); + } + } + + bool isLoaded() const { + return rt_handle_ != nullptr && rtGetLogicDevIdByUserDevIdFunc_ != nullptr; + } + + ~rtLibLoader() { + if (rt_handle_) { + dlclose(rt_handle_); + } + } +}; \ No newline at end of file diff --git a/src/modules/transport/rdma/dl_hccp_def.h b/src/modules/transport/rdma/dl_hccp_def.h index e149a145..4d905f83 100644 --- a/src/modules/transport/rdma/dl_hccp_def.h +++ b/src/modules/transport/rdma/dl_hccp_def.h @@ -559,6 +559,8 @@ struct TransportOptions { uint32_t protocol; hybm_role_type role; int nic; + int32_t dev_id; + int32_t logic_dev_id; }; struct TransportMemoryRegion { diff --git a/src/modules/transport/rdma/rdma_manager.h b/src/modules/transport/rdma/rdma_manager.h index 6c2ccfa2..fe75d7c9 100644 --- a/src/modules/transport/rdma/rdma_manager.h +++ b/src/modules/transport/rdma/rdma_manager.h @@ -13,10 +13,6 @@ public: rdma_manager() {} ~rdma_manager() { - if (rt_handle_) { - dlclose(rt_handle_); - rt_handle_ = nullptr; - } delete qpManager_; ClearAllRegisterMRs(); tsdOpened_ = false; @@ -38,15 +34,8 @@ public: int OpenDevice(const TransportOptions &options) { - int32_t deviceId = -1; - OpenRTLib(); - auto ret = aclrtGetDevice(&deviceId); - if (ret != 0 || deviceId < 0) { - SHM_LOG_ERROR("AclrtGetDevice() return=" << ret << ", output deviceId=" << deviceId); - return -1; - } - int32_t logicDeviceId = -1; - SHMEM_CHECK_RET(rtGetLogicDevIdByUserDevIdFunc(deviceId, &logicDeviceId)); + int32_t deviceId = options.dev_id; + int32_t logicDeviceId = options.logic_dev_id; deviceId_ = static_cast(logicDeviceId); rankId_ = options.rankId; rankCount_ = options.rankCount; @@ -182,25 +171,6 @@ public: return SHMEM_SUCCESS; } private: - bool OpenRTLib() { - if (rt_handle_) { - return true; - } - rt_handle_ = dlopen(g_rt_lib_name, RTLD_NOW); - if (!rt_handle_) { - SHM_LOG_ERROR("dlopen failed: " << dlerror()); - return false; - } - - *((void**)&rtGetLogicDevIdByUserDevIdFunc) = dlsym(rt_handle_, "rtGetLogicDevIdByUserDevId"); - if (!rtGetLogicDevIdByUserDevIdFunc) { - dlclose(rt_handle_); - rt_handle_ = nullptr; - SHM_LOG_ERROR("Unable to get info from " << "libascendcl.so" << "."); - return SHMEM_INVALID_VALUE; - } - return true; - } bool OpenTsd(uint32_t deviceId, uint32_t rankCount) { if (tsdOpened_) { @@ -460,5 +430,4 @@ private: DeviceQpManager* qpManager_; RegMemResult localMR_; MemoryRegionMap registerMRS_; - void* rt_handle_{nullptr}; }; \ No newline at end of file diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index 36dddb6e..7a711143 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -73,6 +73,8 @@ int shmemi_rdma_init(shmemi_transport *t, shmemi_device_host_state_t *state) { options.rankCount = state->npes; options.protocol = 7; options.nic = 10002; + options.dev_id = t->dev_id; + options.logic_dev_id = t->logical_dev_id; manager->OpenDevice(options); TransportMemoryRegion mr; -- Gitee From d5f61a5285d8ca4e9f35e637cb2c94dd3e605d02 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 4 Nov 2025 20:09:39 +0800 Subject: [PATCH 52/74] Remove unnecessary example. --- examples/CMakeLists.txt | 1 - examples/rdma_test/CMakeLists.txt | 9 - examples/rdma_test/main.cpp | 211 ---------------------- examples/rdma_test/rdma_test_kernel.cpp | 69 ------- src/host/utils.h | 16 +- src/modules/transport/rdma/rdma_manager.h | 17 +- 6 files changed, 31 insertions(+), 292 deletions(-) delete mode 100644 examples/rdma_test/CMakeLists.txt delete mode 100644 examples/rdma_test/main.cpp delete mode 100644 examples/rdma_test/rdma_test_kernel.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 49b00f6e..4fd3339d 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -54,7 +54,6 @@ endfunction() foreach(EXAMPLE allgather - rdma_test # matmul_allreduce rdma_perftest rdma_demo diff --git a/examples/rdma_test/CMakeLists.txt b/examples/rdma_test/CMakeLists.txt deleted file mode 100644 index 5593c77a..00000000 --- a/examples/rdma_test/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -shmem_add_collective_example(rdma_test main.cpp) \ No newline at end of file diff --git a/examples/rdma_test/main.cpp b/examples/rdma_test/main.cpp deleted file mode 100644 index 9467bd78..00000000 --- a/examples/rdma_test/main.cpp +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -#include -#include -#include -#include -#include -#include "acl/acl.h" -#include "shmem_api.h" -#include "shmemi_host_common.h" - -int g_npus = 8; -const char *ipport; -int f_rank = 0; -int f_npu = 0; - -extern void qpinfo_demo(uint32_t block_dim, void* stream, uint8_t* gva, uint32_t destRankId, uint32_t qpIdx); -extern void shm_rdma_write_test_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva); -extern void shm_rdma_write_test_poll_cq_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva); - -int test_shmem_rdma(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - // 初始化ACL和SHMEM - int32_t device_id = rank_id % g_npus + f_npu; - int status = 0; - aclrtStream stream = nullptr; - - status = aclInit(nullptr); - status = aclrtSetDevice(device_id); - status = aclrtCreateStream(&stream); - - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); - - void* dev_ptr; - aclrtMalloc(&dev_ptr, 120 * 8, ACL_MEM_MALLOC_HUGE_FIRST); - uint64_t *xHost; - size_t totalSize = 120; - size_t elementCount = totalSize / sizeof(uint64_t); - aclrtMallocHost((void **)(&xHost), totalSize); - std::fill(xHost, xHost + elementCount, 0); - - for (uint32_t curRank = 0; curRank < n_ranks; curRank++) { - if (curRank == rank_id) { - continue; - } - qpinfo_demo(1, stream, (uint8_t*)dev_ptr + rank_id * totalSize, curRank, 0); - aclrtSynchronizeStream(stream); - sleep(1); - - aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr + rank_id * totalSize, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); - for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { - printf("GetQPInfo srcRank = %d, destRank = %d, index = %d, value = %lu\n", rank_id, curRank, i, xHost[i]); - } - } - - aclrtFreeHost(xHost); - aclrtFree(dev_ptr); - - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); - return 0; -} - -int test_shmem_rdma_put(int rank_id, int n_ranks, uint64_t local_mem_size, uint64_t remote_gva) -{ - // 初始化ACL和SHMEM - int32_t device_id = rank_id % g_npus + f_npu; - int status = 0; - aclrtStream stream = nullptr; - - status = aclInit(nullptr); - status = aclrtSetDevice(device_id); - status = aclrtCreateStream(&stream); - - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); - - void* dev_ptr; - aclrtMalloc(&dev_ptr, 64 * 16, ACL_MEM_MALLOC_HUGE_FIRST); - uint32_t *xHost; - size_t messageSize = 64; - size_t totalSize = messageSize * n_ranks; - aclrtMallocHost((void **)(&xHost), totalSize); - for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { - xHost[i] = rank_id; - } - - if (rank_id == 0) { - aclrtMemcpy((uint8_t*)dev_ptr + rank_id * messageSize + rank_id * messageSize, - messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE); - shm_rdma_write_test_do(stream, (uint8_t*)dev_ptr, messageSize, (uint8_t*)remote_gva); - if (aclrtSynchronizeStream(stream) != 0) { - std::cout << "[ERROR] aclrtSynchronizeStream failed." << std::endl; - } - } - sleep(1); - - aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr + rank_id * messageSize, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); - if (rank_id == 1) { - for (uint32_t i = 0; i < n_ranks; i++) { - if (xHost[i * messageSize / sizeof(uint32_t)] != i) { - std::cout << "[ERROR] Put result check error at " << i << std::endl; - } - } - } - - aclrtFreeHost(xHost); - aclrtFree(dev_ptr); - - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); - return 0; -} - -int test_shmem_rdma_put_poll_cq(int rank_id, int n_ranks, uint64_t local_mem_size, uint64_t remote_gva) -{ - // 初始化ACL和SHMEM - int32_t device_id = rank_id % g_npus + f_npu; - int status = 0; - aclrtStream stream = nullptr; - - status = aclInit(nullptr); - status = aclrtSetDevice(device_id); - status = aclrtCreateStream(&stream); - - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); - - void* dev_ptr = shmem_malloc(1024); - uint32_t *xHost; - size_t messageSize = 64; - size_t totalSize = messageSize * n_ranks; - aclrtMallocHost((void **)(&xHost), totalSize); - for (uint32_t i = 0; i < messageSize / sizeof(uint32_t); i++) { - xHost[i] = rank_id + 10; - } - - if (rank_id == 0) { - aclrtMemcpy((uint8_t*)dev_ptr + 128, - messageSize, xHost, messageSize, ACL_MEMCPY_HOST_TO_DEVICE); - shm_rdma_write_test_poll_cq_do(stream, (uint8_t*)dev_ptr, messageSize, (uint8_t*)dev_ptr); - if (aclrtSynchronizeStream(stream) != 0) { - std::cout << "[ERROR] aclrtSynchronizeStream failed." << std::endl; - } - } - sleep(1); - - if (rank_id == 1) { - aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); - for (uint32_t i = 0; i < n_ranks; i++) { - if ((i != rank_id) && (xHost[i * messageSize / sizeof(uint32_t)] != i + 10)) { - std::cout << "[ERROR] Put result check error at " << i << std::endl; - } - } - } else { - aclrtMemcpy(xHost, totalSize, (uint8_t*)dev_ptr, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); - for (uint32_t i = 0; i < totalSize / sizeof(uint64_t); i++) { - printf("GetQPInfo srcRank = %d, index = %d, value = %lu\n", rank_id, i, ((uint64_t*)xHost)[i]); - } - } - - aclrtFreeHost(xHost); - shmem_free(dev_ptr); - - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); - return 0; -} - -int main(int argc, char *argv[]) -{ - int status = 0; - // 初始化MPI环境 - MPI_Init(&argc, &argv); - - // 获取当前进程的编号(rank) - int n_ranks; - MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); - int rank_id; - MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); - ipport = argv[1]; - g_npus = atoi(argv[2]); - f_rank = atoi(argv[3]); - f_npu = atoi(argv[4]); - uint64_t remote_gva = atol(argv[5]); - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - // status = test_shmem_rdma_put(rank_id, n_ranks, local_mem_size, remote_gva); - status = test_shmem_rdma_put_poll_cq(rank_id, n_ranks, local_mem_size, remote_gva); - std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; - MPI_Finalize(); - - return 0; -} \ No newline at end of file diff --git a/examples/rdma_test/rdma_test_kernel.cpp b/examples/rdma_test/rdma_test_kernel.cpp deleted file mode 100644 index 9247d05b..00000000 --- a/examples/rdma_test/rdma_test_kernel.cpp +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -#ifndef _RDMA_TEST_KERNEL_ -#define _RDMA_TEST_KERNEL_ - -#include "kernel_operator.h" -#include "shmem_api.h" - -constexpr uint32_t MESSAGE_SIZE = 64; - -extern "C" __global__ __aicore__ void shm_rdma_write_qpinfo_test(GM_ADDR gva, uint32_t destRankId, uint32_t qpIdx) -{ - shmemi_roce_qpinfo_test(gva, destRankId, qpIdx); -} - -void qpinfo_demo(uint32_t block_dim, void* stream, uint8_t* gva, uint32_t destRankId, uint32_t qpIdx) -{ - shm_rdma_write_qpinfo_test<<>>(gva, destRankId, qpIdx); -} - -extern "C" __global__ __aicore__ void shm_rdma_write_test(GM_ADDR gva, uint64_t heap_size, GM_ADDR remote_gva) -{ - AscendC::TPipe pipe; - AscendC::TBuf buf; - pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); - AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - auto myRank = 0; - auto totalRank = 2; - for (int i = 0; i < totalRank; i++) { - if (i == myRank) { - continue; - } - shmemi_roce_write(gva + myRank * heap_size + myRank * MESSAGE_SIZE, - remote_gva + i * heap_size + myRank * MESSAGE_SIZE, i, 0, MESSAGE_SIZE, ubLocal64, ubLocal32); - } -} - -void shm_rdma_write_test_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva) -{ - shm_rdma_write_test<<<1, nullptr, stream>>>(gva, heap_size, remote_gva); -} - -extern "C" __global__ __aicore__ void shm_rdma_write_test_poll_cq(GM_ADDR gva, uint64_t heap_size, GM_ADDR remote_gva) -{ - AscendC::TPipe pipe; - AscendC::TBuf buf; - pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); - AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - auto myRank = 0; - auto totalRank = 2; - shmemi_roce_pollcq_test(gva + 128, remote_gva, 1, 0, MESSAGE_SIZE, ubLocal64, ubLocal32, gva); -} - -void shm_rdma_write_test_poll_cq_do(void* stream, uint8_t* gva, uint64_t heap_size, uint8_t* remote_gva) -{ - shm_rdma_write_test_poll_cq<<<1, nullptr, stream>>>(gva, heap_size, remote_gva); -} - -#endif // _RDMA_DEMO_KERNEL_ \ No newline at end of file diff --git a/src/host/utils.h b/src/host/utils.h index 9ada47ab..0d398de9 100644 --- a/src/host/utils.h +++ b/src/host/utils.h @@ -1,3 +1,15 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEM_UTILS_H +#define SHMEM_UTILS_H + #include #include #include "common/shmemi_logger.h" @@ -56,4 +68,6 @@ public: dlclose(rt_handle_); } } -}; \ No newline at end of file +}; + +#endif // SHMEM_UTILS_H \ No newline at end of file diff --git a/src/modules/transport/rdma/rdma_manager.h b/src/modules/transport/rdma/rdma_manager.h index fe75d7c9..9f2fbbe3 100644 --- a/src/modules/transport/rdma/rdma_manager.h +++ b/src/modules/transport/rdma/rdma_manager.h @@ -1,3 +1,16 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef RDMA_MANAGER_H +#define RDMA_MANAGER_H + #include #include #include @@ -430,4 +443,6 @@ private: DeviceQpManager* qpManager_; RegMemResult localMR_; MemoryRegionMap registerMRS_; -}; \ No newline at end of file +}; + +#endif // RDMA_MANAGER_H \ No newline at end of file -- Gitee From 85830181a7fdec91f038bad1d724fc48bfe48a62 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Tue, 25 Nov 2025 11:22:39 +0000 Subject: [PATCH 53/74] !462 uid reconstructed V1 * add env info in readme * Modified review comments * uid reconstructed V1 --- docs/quickstart.md | 24 +- examples/allgather/main.cpp | 51 +- examples/uid_init/CMakeLists.txt | 84 ++ examples/uid_init/main.cpp | 62 + examples/uid_init/run.sh | 21 + include/host/shmem_host_def.h | 19 + include/host/shmem_host_init.h | 13 + include/internal/host/shmemi_host_def.h | 38 + src/CMakeLists.txt | 14 + src/host/bootstrap/shmemi_bootstrap.cpp | 46 +- src/host/bootstrap/shmemi_bootstrap.h | 7 +- src/host/common/shmemi_logger.h | 40 +- src/host/init/shmem_init.cpp | 81 +- .../bootstrap/shmemi_bootstrap_uid.cpp | 1088 ++++++++++++++++- src/modules/bootstrap/socket/uid_socket.cpp | 565 +++++++++ src/modules/bootstrap/socket/uid_socket.h | 113 ++ src/modules/bootstrap/socket/uid_utils.h | 54 + 17 files changed, 2263 insertions(+), 57 deletions(-) create mode 100644 examples/uid_init/CMakeLists.txt create mode 100644 examples/uid_init/main.cpp create mode 100644 examples/uid_init/run.sh create mode 100644 include/internal/host/shmemi_host_def.h create mode 100644 src/modules/bootstrap/socket/uid_socket.cpp create mode 100644 src/modules/bootstrap/socket/uid_socket.h create mode 100644 src/modules/bootstrap/socket/uid_utils.h diff --git a/docs/quickstart.md b/docs/quickstart.md index e86b9392..7ee7020e 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -143,4 +143,26 @@ shm.set_conf_store_tls(True, tls_info) # 开启TLS认证 ```sh torchrun --nproco-per-node=k test.py // k为想运行的ranksize ``` -看到日志中打印出“test.py running success!”即为demo运行成功 \ No newline at end of file +看到日志中打印出“test.py running success!”即为demo运行成功 + +## unique id 初始化方式 + +注:使用unique id的接口初始化,可以手动配置环境变量SHMEM_UID_SESSION_ID或者SHMEM_UID_SOCK_IFNAM,同时配置时只读SHMEM_UID_SESSION_ID,都不配置会自动搜索可用网口。 +SHMEM_UID_SESSION_ID配置示例: +SHMEM_UID_SESSION_ID=127.0.0.1:1234 +SHMEM_UID_SESSION_ID=[6666:6666:6666:6666:6666:6666:6666:6666]:886 +SHMEM_UID_SESSION_ID=[6666:6666:6666:6666:6666:6666:6666:6666%eth]:886 +SHMEM_UID_SOCK_IFNAM配置示例: +SHMEM_UID_SOCK_IFNAM=enpxxxx:inet4 取ipv4 +SHMEM_UID_SOCK_IFNAM=enpxxxx:inet6 取ipv6 +不配置默认取inet4自动搜索可用网口,搜索优先级:非docker、lo>>docker>>lo。 + + +- c++初始化例子 +```cpp +shmemx_uniqueid_t uid; +shmem_init_attr_t *attr; +int ret = shmem_get_uniqueid(&uid); +shmemx_set_attr_uniqueid_args(rank, rank_size, local_mem_size, &uid, &attributes); +status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); +``` \ No newline at end of file diff --git a/examples/allgather/main.cpp b/examples/allgather/main.cpp index bbb42b84..63489423 100644 --- a/examples/allgather/main.cpp +++ b/examples/allgather/main.cpp @@ -50,21 +50,10 @@ extern void allgather_demo(uint32_t block_dim, void *stream, uint64_t fftsAddr, uint8_t *gva, int elements, int magic); template -int test_shmem_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size) +int test_shmem_all_gather(int rank_id, int n_ranks, aclrtStream stream) { // 初始化ACL和SHMEM - int32_t device_id = rank_id % g_npus + f_npu; int status = 0; - aclrtStream stream = nullptr; - - status = aclInit(nullptr); - status = aclrtSetDevice(device_id); - status = aclrtCreateStream(&stream); - - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); - // Prepare FFTS address uint64_t fftsAddr = shmemx_get_ffts_config(); @@ -163,11 +152,6 @@ int test_shmem_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size) } outFile.close(); - - status = shmem_finalize(); - status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); return 0; } @@ -180,16 +164,41 @@ int main(int argc, char *argv[]) MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = aclInit(nullptr); + + int32_t device_id = rank_id % g_npus + f_npu; + status = aclrtSetDevice(device_id); + + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + shmem_init_attr_t *attributes; + shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, + &uid, + &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + if (std::string(data_type) == "int") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks, stream); } else if (std::string(data_type) == "int32_t") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks, stream); } else if (std::string(data_type) == "float16_t") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks, stream); } else if (std::string(data_type) == "bfloat16_t") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks, stream); } + status = shmem_finalize(); + status = aclrtDestroyStream(stream); + status = aclrtResetDevice(device_id); + status = aclFinalize(); if (status) { std::exit(EXIT_FAILURE); } diff --git a/examples/uid_init/CMakeLists.txt b/examples/uid_init/CMakeLists.txt new file mode 100644 index 00000000..954b552c --- /dev/null +++ b/examples/uid_init/CMakeLists.txt @@ -0,0 +1,84 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +cmake_minimum_required(VERSION 3.18) +project(SHMEM) + +# 设置C++标准 +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# 生成位置无关代码 +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# 设置可执行文件输出目录 +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + +# 设置安装路径 +set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/install/shmem) + +# 获取CANN相关环境变量 +if(NOT DEFINED ENV{ASCEND_HOME_PATH}) + message(FATAL_ERROR "Cannot find ASCEND_HOME_PATH, please run set_env.sh.") +else() + set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}) +endif() + +set(CMAKE_COMPILER bisheng) +set(CMAKE_CXX_COMPILER ${CMAKE_COMPILER}) + +add_compile_options( + -O2 -std=c++17 + -Wno-macro-redefined -Wno-ignored-attributes + # avoid ascendc interference + -DL2_CACHE_HINT + -DTILING_KEY_VAR +) + +set(CMAKE_CPP_COMPILE_OPTIONS + -xc++ + "SHELL:-include stdint.h" + "SHELL:-include stddef.h" +) + +include_directories( + ${ASCEND_HOME_PATH}/compiler/tikcpp + ${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw + ${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/impl + ${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/interface + ${ASCEND_HOME_PATH}/include + ${ASCEND_HOME_PATH}/include/experiment/runtime + ${ASCEND_HOME_PATH}/include/experiment/msprof +) + +link_directories( + ${ASCEND_HOME_PATH}/lib64 +) + +link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread) + +find_package(MPI REQUIRED) + +include_directories( + ${MPI_INCLUDE_PATH} + ${ASCEND_HOME_PATH}/lib64 +) + +link_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/../../install/shmem/lib +) + +add_executable(helloword main.cpp) + +target_include_directories(helloword PRIVATE + ${ASCEND_HOME_PATH}/include/ + ${CMAKE_CURRENT_SOURCE_DIR}/../../install/shmem/include +) + +target_link_libraries(helloword PRIVATE MPI::MPI_CXX) +target_link_libraries(helloword PRIVATE shmem) \ No newline at end of file diff --git a/examples/uid_init/main.cpp b/examples/uid_init/main.cpp new file mode 100644 index 00000000..62551db2 --- /dev/null +++ b/examples/uid_init/main.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include +#include "shmem_api.h" +int main(int argc, char* argv[]) +{ + // 初始化MPI环境 + MPI_Init(&argc, &argv); + + // 获取当前进程的编号(rank) + int rank; + int rank_size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &rank_size); + int status = SHMEM_SUCCESS; + aclInit(nullptr); + aclrtSetDevice(rank); + shmem_init_attr_t *attributes; + // status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + shmemx_uniqueid_t uid; + // shmem_init_attr_t *attr; + if (rank == 0) { + std::cout << "shmem_get_uniqueid" < +#include + +typedef enum { + ADDR_IPv4, + ADDR_IPv6 +} addr_type_t; + +typedef struct { + union { + struct sockaddr sa; + struct sockaddr_in addr4; // IPv4地址(含端口) + struct sockaddr_in6 addr6; // IPv6地址(含端口) + } addr; + addr_type_t type; +} sockaddr_t; + +typedef struct { + int32_t version; + sockaddr_t addr; // 动态传入的地址(含端口) + uint64_t magic; + int rank; + int nranks; +} shmemx_bootstrap_uid_state_t; +#endif // SHMEMI_HOST_DEF_H \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 136c47d3..93b37f69 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -85,6 +85,20 @@ if(SHMEM_MPI_SUPPORT) ) endif() +# UID +add_library(shmem_bootstrap_uid SHARED) + +target_sources(shmem_bootstrap_uid PRIVATE modules/bootstrap/socket/uid_socket.cpp + modules/bootstrap/shmemi_bootstrap_uid.cpp) +target_include_directories(shmem_bootstrap_uid + PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/host +) +set_target_properties(shmem_bootstrap_uid PROPERTIES PREFIX "") +install(TARGETS shmem_bootstrap_uid + LIBRARY DESTINATION lib) + set(SHMEM_RDMA_SUPPORT ON) if(SHMEM_RDMA_SUPPORT) diff --git a/src/host/bootstrap/shmemi_bootstrap.cpp b/src/host/bootstrap/shmemi_bootstrap.cpp index 63e64590..9ff0c5fa 100644 --- a/src/host/bootstrap/shmemi_bootstrap.cpp +++ b/src/host/bootstrap/shmemi_bootstrap.cpp @@ -14,6 +14,7 @@ #define BOOTSTRAP_MODULE_UID "shmem_bootstrap_uid.so" #define BOOTSTRAP_PLUGIN_INIT_FUNC "shmemi_bootstrap_plugin_init" +#define BOOTSTRAP_PLUGIN_PREINIT_FUNC "shmemi_bootstrap_plugin_pre_init" shmemi_bootstrap_handle_t g_boot_handle; @@ -35,10 +36,7 @@ int bootstrap_loader_finalize(shmemi_bootstrap_handle_t *handle) return 0; } -// for UID -int32_t shmemi_bootstrap_pre_init() { -} void shmemi_bootstrap_loader() { @@ -63,15 +61,51 @@ void shmemi_bootstrap_free() } } -int32_t shmemi_bootstrap_init(int flags, shmemi_bootstrap_attr_t *attr) { +// rank0 requires preloading uid.so to obtain the getuid capability +int32_t shmemi_bootstrap_pre_init(int flags, shmemi_bootstrap_handle_t *handle) { + int32_t status = SHMEM_SUCCESS; + + if (flags & SHMEMX_INIT_WITH_MPI) { + SHM_LOG_ERROR("Unsupport Type for bootstrap preinit."); + return SHMEM_INVALID_PARAM; + } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { + plugin_name = BOOTSTRAP_MODULE_UID; + } else { + SHM_LOG_ERROR("Unknown Type for bootstrap"); + status = SHMEM_INVALID_PARAM; + } + shmemi_bootstrap_loader(); + + if (!plugin_hdl) { + SHM_LOG_ERROR("Bootstrap unable to load " << plugin_name << ", err is: " << stderr); + shmemi_bootstrap_free(); + return SHMEM_INVALID_VALUE; + } + int (*plugin_pre_init)(shmemi_bootstrap_handle_t *); + *((void **)&plugin_pre_init) = dlsym(plugin_hdl, BOOTSTRAP_PLUGIN_PREINIT_FUNC); + if (!plugin_pre_init) { + SHM_LOG_ERROR("Bootstrap plugin init func dlsym failed"); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } + status = plugin_pre_init(&g_boot_handle); + if (status != 0) { + SHM_LOG_ERROR("Bootstrap plugin init failed for " << plugin_name); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } + return status; +} + +int32_t shmemi_bootstrap_init(int flags, shmem_init_attr_t *attr) { int32_t status = SHMEM_SUCCESS; void *arg; if (flags & SHMEMX_INIT_WITH_MPI) { plugin_name = BOOTSTRAP_MODULE_MPI; - arg = (attr != NULL) ? attr->mpi_comm : NULL; + arg = (attr != NULL) ? attr->comm_args : NULL; } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { plugin_name = BOOTSTRAP_MODULE_UID; - status = shmemi_bootstrap_pre_init(); + arg = (attr != NULL) ? attr->comm_args : NULL; } else { SHM_LOG_ERROR("Unknown Type for bootstrap"); status = SHMEM_INVALID_PARAM; diff --git a/src/host/bootstrap/shmemi_bootstrap.h b/src/host/bootstrap/shmemi_bootstrap.h index a3ad3637..c5e56e56 100644 --- a/src/host/bootstrap/shmemi_bootstrap.h +++ b/src/host/bootstrap/shmemi_bootstrap.h @@ -10,13 +10,16 @@ #ifndef SHMEMI_BOOTSTRAP_H #define SHMEMI_BOOTSTRAP_H +#include "shmem_api.h" #ifdef __cplusplus extern "C" { #endif -int32_t shmemi_bootstrap_pre_init(); +int32_t shmemi_bootstrap_plugin_pre_init(shmemi_bootstrap_handle_t *handle); -int32_t shmemi_bootstrap_init(int flags, shmemi_bootstrap_attr_t *attr); +int32_t shmemi_bootstrap_pre_init(int flags, shmemi_bootstrap_handle_t *handle); + +int32_t shmemi_bootstrap_init(int flags, shmem_init_attr_t *attr); void shmemi_bootstrap_finalize(); diff --git a/src/host/common/shmemi_logger.h b/src/host/common/shmemi_logger.h index a5bbf3fc..3bc24a7b 100644 --- a/src/host/common/shmemi_logger.h +++ b/src/host/common/shmemi_logger.h @@ -172,22 +172,46 @@ private: } \ } while (0) -#define SHMEM_CHECK_RET(x) \ +#define SHMEM_CHECK(x) \ do { \ int32_t check_ret = x; \ if (check_ret != 0) { \ SHM_LOG_ERROR(" return shmem error: " << check_ret); \ - return check_ret; \ + return ; \ } \ } while (0) -#define SHMEM_CHECK(x) \ - do { \ - int32_t check_ret = x; \ - if (check_ret != 0) { \ +#define SHMEM_CHECK_RET(...) \ + _SHMEM_CHECK_RET_HELPER(__VA_ARGS__, _SHMEM_CHECK_RET_WITH_LOG_AND_ERR_CODE, _SHMEM_CHECK_RET_WITH_LOG, _SHMEM_CHECK_RET)(__VA_ARGS__) + +#define _SHMEM_CHECK_RET(x) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ SHM_LOG_ERROR(" return shmem error: " << check_ret); \ - return ; \ - } \ + return check_ret; \ + } \ + } while (0) + +#define _SHMEM_CHECK_RET_WITH_LOG(x, log_str) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << log_str << " return shmem error: " << check_ret); \ + return check_ret; \ + } \ } while (0) +#define _SHMEM_CHECK_RET_WITH_LOG_AND_ERR_CODE(x, log_str, error_code) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << log_str << " return shmem error: " << error_code); \ + return error_code; \ + } \ + } while (0) + +#define _SHMEM_CHECK_RET_HELPER(_1, _2, _3, FUNC, ...) FUNC + + #endif // SHMEM_SHM_OUT_LOGGER_H diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index b4a36cf2..54dba01c 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -19,6 +19,7 @@ #include "acl/acl.h" #include "shmemi_host_common.h" +#include "internal/host/shmemi_host_def.h" using namespace std; @@ -204,9 +205,27 @@ int32_t shmem_init_status() return SHMEM_STATUS_INVALID; } +int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const int64_t local_mem_size, + const shmemx_uniqueid_t *uid, + shmem_init_attr_t **shmem_attr) { + /* Save to uid_args */ + *shmem_attr = &g_attr; + shmemx_bootstrap_uid_state_t *uid_args = (shmemx_bootstrap_uid_state_t *)(uid); + uid_args->rank = my_rank; + uid_args->nranks = n_ranks; + void * comm_args = reinterpret_cast(uid_args); + g_attr.comm_args = comm_args; + g_attr.my_rank = my_rank; + g_attr.n_ranks = n_ranks; + g_attr.local_mem_size = local_mem_size; + + return SHMEM_SUCCESS; +} + int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *attributes) { int32_t ret; + shmem_set_log_level(shm::ERROR_LEVEL); // config init SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); @@ -216,7 +235,7 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a // bootstrap init shmemi_bootstrap_attr_t attr = {}; - SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, &attr)); + SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, attributes)); // shmem basic init #ifdef BACKEND_MF @@ -267,4 +286,62 @@ void shmem_info_get_name(char *name) name[i] = version_str[i]; } name[i] = '\0'; -} \ No newline at end of file +} + +int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid) +{ + return SHMEM_SUCCESS; +} + + +int32_t shmem_get_uniqueid_default(shmemx_uniqueid_t *uid) +{ + int status = 0; + SHMEM_CHECK_RET(shmemi_options_init(), "Bootstrap failed during the preloading step."); + SHMEM_CHECK_RET(shmemi_bootstrap_pre_init(SHMEMX_INIT_WITH_UNIQUEID, &g_boot_handle), "Get uniqueid failed during the bootstrap preloading step."); + + if (g_boot_handle.pre_init_ops) { + SHMEM_CHECK_RET(g_boot_handle.pre_init_ops->get_unique_id((void *)uid), "Get uniqueid failed during the get uniqueid step."); + } else { + SHM_LOG_ERROR("Pre_init_ops is empty, unique_id cannot be obtained."); + status = SHMEM_INVALID_PARAM; + } + + return (status); +} + +int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid){ + shmem_set_log_level(shm::ERROR_LEVEL); +#ifdef BACKEND_MF + SHMEM_CHECK_RET(shmem_get_uniqueid_mf(uid), "shmem_get_uniqueid failed, backend: mf"); + return SHMEM_SUCCESS; +#else + SHMEM_CHECK_RET(shmem_get_uniqueid_default(uid), "shmem_get_uniqueid failed, backend: default"); + return SHMEM_SUCCESS; +#endif +} + +int32_t shmem_set_log_level(int level) +{ + // use env first, input level secondly, user may change level from env instead call func + const char *in_level = std::getenv("SHMEM_LOG_LEVEL"); + if (in_level != nullptr) { + auto tmp_level = std::string(in_level); + if (tmp_level == "DEBUG") { + level = shm::DEBUG_LEVEL; + } else if (tmp_level == "INFO") { + level = shm::INFO_LEVEL; + } else if (tmp_level == "WARN") { + level = shm::WARN_LEVEL; + } else if (tmp_level == "ERROR") { + level = shm::ERROR_LEVEL; + } else if (tmp_level == "FATAL") { + level = shm::FATAL_LEVEL; + } + } + #ifdef BACKEND_MF + smem_set_log_level(level); + #endif + + return shm::shm_out_logger::Instance().set_log_level(static_cast(level)); +} diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp index a3f6249e..b3e4cde3 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -7,36 +7,1090 @@ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ -#ifdef SHMEM_BOOTSTRAP_UID -#include "shmemi_bootstrap.h" -typedef struct { - char ifname[64]; - int af; - int32_t ip, port; -} shmemi_bootstrap_uid_state_t; +#include +#include +#include +#include +#include +#include +#include +#include "socket/uid_socket.h" +#include "socket/uid_utils.h" -static shmemi_bootstrap_uid_state_t shmemi_bootstrap_uid_state; -int shmemi_bootstrap_plugin_init(void *args, shmemi_bootstrap_handle_t *boot_handle) { - // INIT - handle->allgather = shmemi_bootstrap_uid_allgather; - handle->barrier = shmemi_bootstrap_uid_allgather; - handle->finalize = shmemi_bootstrap_uid_finalize; +#define SHMEM_UNIQUEID_INITIALIZER \ + { \ + SHMEM_UNIQUEID_VERSION, \ + { \ + 0 \ + } \ + } \ + + +#define MAX_ATTEMPTS 500 +#define MAX_IFCONFIG_LENGTH 23 +#define MAX_IP 48 +#define DEFAULT_IFNAME_LNEGTH 4 +#define BOOTSTRAP_IN_PLACE (void*)0x1 +#define SOCKET_MAGIC 0x243ab9f2fc4b9d6cULL + +static const char* env_ip_port = nullptr; +static const char* env_ifname = nullptr; +static shmemx_bootstrap_uid_state_t shmemi_bootstrap_uid_state; +static struct bootstrap_netstate priv_info; + +bool is_ipv6_loopback(const struct in6_addr *addr6) { + static const struct in6_addr loopback6 = IN6ADDR_LOOPBACK_INIT; + return memcmp(addr6, &loopback6, sizeof(struct in6_addr)) == 0; +} + +bool is_ipv4_loopback(const struct in_addr *addr4) { + return ((ntohl(addr4->s_addr) >> 24) & 0xFF) == IN_LOOPBACKNET; +} + +static int32_t shmemi_get_uid_magic(shmemx_bootstrap_uid_state_t *innerUId) +{ + std::ifstream urandom("/dev/urandom", std::ios::binary); + if (!urandom) { + SHM_LOG_ERROR("open random failed"); + return SHMEM_INNER_ERROR; + } + + urandom.read(reinterpret_cast(&innerUId->magic), sizeof(innerUId->magic)); + if (urandom.fail()) { + SHM_LOG_ERROR("read random failed."); + return SHMEM_INNER_ERROR; + } + SHM_LOG_DEBUG("init magic id to " << innerUId->magic); + return SHMEM_SUCCESS; +} + + +static int32_t shmemi_uid_parse_interface_with_type(const char *ipInfo, char *IP, sa_family_t &sockType, bool &flag) +{ + const char *delim = ":"; + const char *sep = strchr(ipInfo, delim[0]); + if (sep != nullptr) { + size_t leftLen = sep - ipInfo; + if (leftLen >= MAX_IFCONFIG_LENGTH - 1 || leftLen == 0) { + SHM_LOG_ERROR("Invalid interface prefix length: " << leftLen); + return SHMEM_INVALID_VALUE; + } + strncpy(IP, ipInfo, leftLen); + IP[leftLen] = '\0'; + sockType = (strcmp(sep + 1, "inet6") != 0) ? AF_INET : AF_INET6; + flag = true; + SHM_LOG_INFO("Parse ipInfo success: ifaPrefix=" << IP << ", sockType=" << (sockType == AF_INET ? "IPv4" : "IPv6")); + } + return SHMEM_SUCCESS; +} + +int32_t shmemi_traverse_ifa( + struct ifaddrs *ifaddr, + sa_family_t &sockType, + bool flag, + const char **prefixes, + bool exclude, + shmemx_bootstrap_uid_state_t *uid_args, + bool skipStateCheck = false +) { + for (struct ifaddrs *ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { + if (ifa->ifa_addr == nullptr) continue; + + bool match = false; + const char **p = prefixes; + + while (*p != nullptr) { + if (**p == '\0') { + p++; + continue; + } + size_t prefix_len = strlen(*p); + size_t ifname_len = strlen(ifa->ifa_name); + if (ifname_len < prefix_len) { + p++; + continue; + } + if (strncmp(ifa->ifa_name, *p, prefix_len) == 0) { + match = true; + break; + } + p++; + } + if (exclude && match) continue; + if (!exclude && !match) continue; + + if (!skipStateCheck && (!(ifa->ifa_flags & IFF_UP) || !(ifa->ifa_flags & IFF_RUNNING))) continue; + + if (flag) { + if (ifa->ifa_addr->sa_family != sockType) { + SHM_LOG_DEBUG("Protocol family not match (flag=true), interface: " << ifa->ifa_name << ", get: " << ifa->ifa_addr->sa_family << ", expect: "<< sockType); + continue; + } + } + + if (ifa->ifa_addr->sa_family == AF_INET && (sockType == AF_UNSPEC || sockType == AF_INET)) { + memset(&uid_args->addr.addr.addr4, 0, sizeof(struct sockaddr_in)); + uid_args->addr.type = ADDR_IPv4; + uid_args->addr.addr.addr4 = *(struct sockaddr_in *)ifa->ifa_addr; + uid_args->addr.addr.addr4.sin_port = 0; + sockType = AF_INET; + SHM_LOG_INFO("Assign IPv4 from interface: " << ifa->ifa_name); + return SHMEM_SUCCESS; + } + + if (ifa->ifa_addr->sa_family == AF_INET6 && (sockType == AF_UNSPEC || sockType == AF_INET6)) { + memset(&uid_args->addr.addr.addr6, 0, sizeof(struct sockaddr_in6)); + uid_args->addr.type = ADDR_IPv6; + uid_args->addr.addr.addr6 = *(struct sockaddr_in6 *)ifa->ifa_addr; + uid_args->addr.addr.addr6.sin6_port = 0; + uid_args->addr.addr.addr6.sin6_flowinfo = 0; + + sockType = AF_INET6; + SHM_LOG_INFO("Assign IPv6 from interface: " << ifa->ifa_name <<" scope_id: " << uid_args->addr.addr.addr6.sin6_scope_id); + return SHMEM_SUCCESS; + } + } + return SHMEM_INVALID_PARAM; +} +int32_t shmemi_get_ip_from_ifa(shmemx_bootstrap_uid_state_t *uid_args, const char *ipInfo) { + if (uid_args == nullptr) { + SHM_LOG_ERROR("uid_args is nullptr"); + return SHMEM_INVALID_PARAM; + } + + struct ifaddrs *ifaddr = nullptr; + char ifaPrefix[MAX_IFCONFIG_LENGTH] = {0}; + bool flag = false; + sa_family_t sockType = AF_INET; + bool foundValidIp = false; + + shmemi_get_uid_magic(uid_args); + + bool isIpInfoConfigured = (ipInfo != nullptr && strlen(ipInfo) > 0); + if (isIpInfoConfigured) { + int32_t ret = shmemi_uid_parse_interface_with_type(ipInfo, ifaPrefix, sockType, flag); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("Parse ipInfo failed, ret: " << ret); + return ret; + } + } + + if (getifaddrs(&ifaddr) == -1) { + SHM_LOG_ERROR("getifaddrs failed: " << strerror(errno)); + return SHMEM_INVALID_PARAM; + } + + if (isIpInfoConfigured) { + const char *specifiedPrefixes[] = {ifaPrefix, nullptr}; + SHM_LOG_INFO("Search interface with specified prefix: " << ifaPrefix); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, specifiedPrefixes, false, uid_args) == SHMEM_SUCCESS); + } else { + const char *excludePrefixes[] = {"docker", "lo", nullptr}; + const char *dockerPrefixes[] = {"docker", nullptr}; + const char *loPrefixes[] = {"lo", nullptr}; + + SHM_LOG_INFO("Step 1: Search interfaces exclude 'docker' and 'lo'"); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, excludePrefixes, true, uid_args) == SHMEM_SUCCESS); + + if (!foundValidIp) { + SHM_LOG_WARN("Step 2: Search interfaces match 'docker'"); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, dockerPrefixes, false, uid_args) == SHMEM_SUCCESS); + } + + if (!foundValidIp) { + SHM_LOG_WARN("Step 3: Search interfaces match 'lo' (skip state check)"); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, loPrefixes, false, uid_args, true) == SHMEM_SUCCESS); + } + } + + if (!foundValidIp) { + SHM_LOG_ERROR("Failed to get any valid IP address from interfaces"); + freeifaddrs(ifaddr); + return SHMEM_INVALID_PARAM; + } + + freeifaddrs(ifaddr); + SHM_LOG_INFO("Assign IP/Port from interface success"); + return SHMEM_SUCCESS; +} + +int32_t shmemi_get_ip_from_env(shmemx_bootstrap_uid_state_t *uid_args, const char *ipPort) { + if (uid_args == nullptr || ipPort == nullptr || strlen(ipPort) == 0) { + SHM_LOG_ERROR("Invalid param: uid_args is null or ipPort is empty"); + return SHMEM_INVALID_PARAM; + } + + shmemi_get_uid_magic(uid_args); + SHM_LOG_DEBUG("get env SHMEM_UID_SESSION_ID value: " << ipPort); + std::string ipPortStr = ipPort; + + if (ipPort[0] == '[') { + size_t bracket_end = ipPortStr.find_last_of(']'); + if (bracket_end == std::string::npos || ipPortStr.length() - bracket_end <= 1) { + SHM_LOG_ERROR("Invalid IPv6 format: no closing ']'"); + return SHMEM_INVALID_PARAM; + } + + std::string ip_with_scope = ipPortStr.substr(1, bracket_end - 1); + size_t scope_sep = ip_with_scope.find('%'); + std::string ipStr; + std::string if_name; + + memset(&uid_args->addr.addr.addr6, 0, sizeof(struct sockaddr_in6)); + uid_args->addr.type = ADDR_IPv6; + uid_args->addr.addr.addr6.sin6_family = AF_INET6; + + if (scope_sep != std::string::npos) { + ipStr = ip_with_scope.substr(0, scope_sep); + if_name = ip_with_scope.substr(scope_sep + 1); + uid_args->addr.addr.addr6.sin6_scope_id = if_nametoindex(if_name.c_str()); + if (uid_args->addr.addr.addr6.sin6_scope_id == 0) { + SHM_LOG_WARN("Interface " << if_name.c_str() << "not found, scope_id set to 0"); + } + } else { + ipStr = ip_with_scope; + uid_args->addr.addr.addr6.sin6_scope_id = 0; + } + + std::string portStr = ipPortStr.substr(bracket_end + 2); + if (portStr.empty()) { + SHM_LOG_ERROR("IPv6 port is empty"); + return SHMEM_INVALID_PARAM; + } + uint16_t port = static_cast(std::stoi(portStr)); + uid_args->addr.addr.addr6.sin6_port = htons(port); + uid_args->addr.addr.addr6.sin6_flowinfo = 0; + + if (inet_pton(AF_INET6, ipStr.c_str(), &uid_args->addr.addr.addr6.sin6_addr) <= 0) { + SHM_LOG_ERROR("inet_pton IPv6 failed: " << strerror(errno)); + return SHMEM_NOT_INITED; + } + } else { + size_t colon_pos = ipPortStr.find_last_of(':'); + if (colon_pos == std::string::npos || ipPortStr.length() - colon_pos <= 1) { + SHM_LOG_ERROR("Invalid IPv4 format: no colon separator"); + return SHMEM_INVALID_PARAM; + } + + std::string ipStr = ipPortStr.substr(0, colon_pos); + std::string portStr = ipPortStr.substr(colon_pos + 1); + + memset(&uid_args->addr.addr.addr4, 0, sizeof(struct sockaddr_in)); + uid_args->addr.type = ADDR_IPv4; + uid_args->addr.addr.addr4.sin_family = AF_INET; + uint16_t port = static_cast(std::stoi(portStr)); + uid_args->addr.addr.addr4.sin_port = htons(port); + + if (inet_pton(AF_INET, ipStr.c_str(), &uid_args->addr.addr.addr4.sin_addr) <= 0) { + SHM_LOG_ERROR("inet_pton IPv4 failed: " << strerror(errno)); + return SHMEM_NOT_INITED; + } + } + + SHM_LOG_INFO("Assign IP/Port from env success"); + return SHMEM_SUCCESS; +} + +int32_t shmemi_set_ip_info(void *uid, sa_family_t &sockType, char *pta_env_ip, uint16_t pta_env_port, + bool is_from_ifa) +{ + // init default uid + shmemx_bootstrap_uid_state_t *innerUID = (shmemx_bootstrap_uid_state_t *)(uid); + SHM_LOG_INFO(" ENV IP: " << pta_env_ip << " ENV port: " << pta_env_port << " sockType: " << sockType); + SHMEM_CHECK_RET(shmemi_get_uid_magic(innerUID)); + + // fill ip port as part of uid + uint16_t port = 0; + if (is_from_ifa) { + SHM_LOG_DEBUG("Automatically obtain the value of port. port: " << port); + } else { + port = pta_env_port; + SHM_LOG_DEBUG("Get the port from the environment variable. port: " << port); + } + + if (sockType == AF_INET) { + SHM_LOG_INFO("SockType is AF_INET."); + innerUID->addr.addr.addr4.sin_family = AF_INET; + if (inet_pton(AF_INET, pta_env_ip, &(innerUID->addr.addr.addr4.sin_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv4 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.addr.addr4.sin_port = htons(port); + innerUID->addr.type = ADDR_IPv4; + } else if (sockType == AF_INET6) { + SHM_LOG_INFO("SockType is AF_INET6."); + innerUID->addr.addr.addr6.sin6_family = AF_INET6; + if (inet_pton(AF_INET6, pta_env_ip, &(innerUID->addr.addr.addr6.sin6_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv6 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.addr.addr6.sin6_port = htons(port); + innerUID->addr.type = ADDR_IPv6; + } else { + SHM_LOG_ERROR("IP Type is not IPv4 or IPv6"); + return SHMEM_INVALID_PARAM; + } + SHM_LOG_INFO("gen unique id success."); + return SHMEM_SUCCESS; +} + +static int shmemi_bootstrap_uid_finalize(shmemi_bootstrap_handle_t *handle) { + if (!handle) { + return SHMEM_SUCCESS; + } + + if (handle->bootstrap_state) { + uid_bootstrap_state* state = (uid_bootstrap_state*) handle->bootstrap_state; + + socket_close(&state->listen_sock); + socket_close(&state->ring_send_sock); + socket_close(&state->ring_recv_sock); + + SHMEM_BOOTSTRAP_PTR_FREE(state->peer_addrs); + state->peer_addrs = nullptr; + SHMEM_BOOTSTRAP_PTR_FREE(state); + handle->bootstrap_state = nullptr; + } + + if (handle->pre_init_ops) { + SHMEM_BOOTSTRAP_PTR_FREE(handle->pre_init_ops); + handle->pre_init_ops = nullptr; + } + + return SHMEM_SUCCESS; +} + + +static int shmemi_bootstrap_uid_allgather(const void *in, void *out, int len, shmemi_bootstrap_handle_t *handle) { + if (!in || !out || !handle || !handle->bootstrap_state) { + SHM_LOG_ERROR("bootstrap allgather: invalid arguments."); + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*) handle->bootstrap_state; + int rank = state->rank; + int nranks = state->nranks; + char* send_buf = (char*)in; + + if (state->ring_send_sock.state != SOCKET_STATE_READY || + state->ring_recv_sock.state != SOCKET_STATE_READY) { + SHM_LOG_ERROR("bootstrap allgather: rank " << rank << ": sockets not ready for allgather"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (in != BOOTSTRAP_IN_PLACE) { + memcpy((char*)out + (rank % nranks) * len, send_buf, len); + } + + for (int i = 0; i < nranks - 1; i++) { + size_t rslice = (rank - i - 1 + nranks) % nranks; + size_t sslice = (rank - i + nranks) % nranks; + + SHMEM_CHECK_RET(socket_send(&state->ring_send_sock, ((char*)out + sslice * len), len), "rank " << rank << ": barrier send failed"); + SHMEM_CHECK_RET(socket_recv(&state->ring_recv_sock, ((char*)out + rslice * len), len), "rank " << rank << ": barrier recv failed"); + } + return SHMEM_SUCCESS; } -int shmemi_bootstrap_uid_finalize(shmemi_bootstrap_handle_t *boot_handle) { +static int shmemi_bootstrap_uid_barrier(shmemi_bootstrap_handle_t *handle) { + if (!handle || !handle->bootstrap_state) { + SHM_LOG_ERROR("bootstrap barrier: invalid arguments"); + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*) handle->bootstrap_state; + int rank = state->rank; + int nranks = state->nranks; + + if (nranks == 1) { + return SHMEM_SUCCESS; + } + + if (state->ring_send_sock.state != SOCKET_STATE_READY || + state->ring_recv_sock.state != SOCKET_STATE_READY) { + SHM_LOG_ERROR("bootstrap barrier: rank " << rank << ": sockets not ready for barrier"); + return SHMEM_BOOTSTRAP_ERROR; + } + + char token = 0; + if (rank == 0) { + SHMEM_CHECK_RET(socket_send(&state->ring_send_sock, &token, 1), "rank 0: barrier send failed"); + SHMEM_CHECK_RET(socket_recv(&state->ring_recv_sock, &token, 1), "rank 0: barrier recv failed"); + } else { + SHMEM_CHECK_RET(socket_recv(&state->ring_recv_sock, &token, 1), "rank " << rank << ": barrier recv failed"); + SHMEM_CHECK_RET(socket_send(&state->ring_send_sock, &token, 1), "rank " << rank << ": barrier send failed"); + } + return SHMEM_SUCCESS; } -int shmemi_bootstrap_uid_allgather(void *dst, void *src, size_t size, shmemi_bootstrap_handle_t *boot_handle) { +static int shmemi_bootstrap_uid_alltoall(const void *sendbuf, void *recvbuf, int length, + shmemi_bootstrap_handle_t *handle) { } -int shmemi_bootstrap_uid_barrier(shmemi_bootstrap_handle_t *boot_handle) { +static void shmemi_bootstrap_uid_global_exit(int status) { + +} + +static bool is_loopback_addr(const sockaddr_t* addr) { + if (addr == nullptr) { + return false; + } + if (addr->type == ADDR_IPv4) { + return is_ipv4_loopback(&addr->addr.addr4.sin_addr); + } else if (addr->type == ADDR_IPv6) { + return is_ipv6_loopback(&addr->addr.addr6.sin6_addr) != 0; + } else { + return false; + } +} + +static bool matchSubnet(struct ifaddrs local_if, sockaddr_t* remote) { + int family; + bool is_lo_interface = (strncmp(local_if.ifa_name, "lo", 2) == 0); + if (remote->type == ADDR_IPv4) { + family = AF_INET; + } else if (remote->type == ADDR_IPv6) { + family = AF_INET6; + } else { + return false; + } + is_loopback_addr(remote); + SHM_LOG_DEBUG("local_if family: " << local_if.ifa_addr->sa_family << " remote family: " << family); + if (family != local_if.ifa_addr->sa_family) { + SHM_LOG_DEBUG(" matchSubnet family unmatch."); + return false; + } + + if (family == AF_INET) { + struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr); + struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask); + struct sockaddr_in* remote_addr = &remote->addr.addr4; + + uint32_t local_subnet = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr; + uint32_t remote_subnet = remote_addr->sin_addr.s_addr & mask->sin_addr.s_addr; + SHM_LOG_DEBUG("ipv4 matchSubnet result:" << (local_subnet == remote_subnet)); + return local_subnet == remote_subnet; + } else if (family == AF_INET6) { + struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr); + struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask); + struct sockaddr_in6* remote_addr = &remote->addr.addr6; + bool same = true; + for (int c = 0; c < 16; c++) { + uint8_t l = local_addr->sin6_addr.s6_addr[c] & mask->sin6_addr.s6_addr[c]; + uint8_t r = remote_addr->sin6_addr.s6_addr[c] & mask->sin6_addr.s6_addr[c]; + if (l != r) { + same = false; + break; + } + } + if (is_lo_interface) { + SHM_LOG_DEBUG("IPv6 on lo interface, skipping sin6_scope_id validation"); + SHM_LOG_DEBUG("ipv6 matchSubnet result:" << same); + return same; + } + same &= (local_addr->sin6_scope_id == remote_addr->sin6_scope_id); + SHM_LOG_DEBUG("ipv6 matchSubnet result:" << same << " local_addr->sin6_scope_id: " <sin6_scope_id << " remote_addr->sin6_scope_id: "<< remote_addr->sin6_scope_id); + return same; + } + return false; } -#endif // SHMEM_BOOTSTRAP_UID \ No newline at end of file +static int find_interface_match_subnet(char* ifNames, sockaddr_t* localAddrs, sockaddr_t* remoteAddr) { + int found = 0; + struct ifaddrs *interfaces, *interface; + if (getifaddrs(&interfaces) != 0) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (remoteAddr) { + if (remoteAddr->type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN]; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &remoteAddr->addr.addr4.sin_addr, ip_str, INET_ADDRSTRLEN) == nullptr, "convert remote ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(remoteAddr->addr.addr4.sin_port); + SHM_LOG_INFO(" Type: IPv4" << " IP: " << ip_str <<" Port: " << (port ? port : 0) << " (0 means not set)"); + } else if (remoteAddr->type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN]; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &remoteAddr->addr.addr6.sin6_addr, ip_str, INET6_ADDRSTRLEN) == nullptr, "convert remote ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(remoteAddr->addr.addr6.sin6_port); + SHM_LOG_INFO(" Type: IPv6" << " IP: " << ip_str <<" Port: " << (port ? port : 0) << " (0 means not set)"); + } else { + SHM_LOG_ERROR(" remoteAddr: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + } else { + SHM_LOG_ERROR(" remoteAddr is NULL."); + return SHMEM_BOOTSTRAP_ERROR; + } + + bool remote_is_loopback = is_loopback_addr(remoteAddr); + if (remote_is_loopback) { + SHM_LOG_DEBUG("Remote address is loopback, check lo interface first"); + for (interface = interfaces; interface && !found; interface = interface->ifa_next) { + if (interface->ifa_addr == NULL) continue; + int family = interface->ifa_addr->sa_family; + if (family != AF_INET && family != AF_INET6) continue; + if (strcmp(interface->ifa_name, "lo") != 0) continue; + + if (matchSubnet(*interface, remoteAddr)) { + if (family == AF_INET) { + localAddrs->type = ADDR_IPv4; + memcpy(&localAddrs->addr.addr4, interface->ifa_addr, sizeof(struct sockaddr_in)); + } else { + localAddrs->type = ADDR_IPv6; + memcpy(&localAddrs->addr.addr6, interface->ifa_addr, sizeof(struct sockaddr_in6)); + } + strncpy(ifNames, interface->ifa_name, MAX_IF_NAME_SIZE); + ifNames[MAX_IF_NAME_SIZE] = '\0'; + found = 1; + break; + } + } + } + if (!found) { + for (interface = interfaces; interface && !found; interface = interface->ifa_next) { + if (interface->ifa_addr == NULL) continue; + int family = interface->ifa_addr->sa_family; + if (family != AF_INET && family != AF_INET6) continue; + + if (!remote_is_loopback && strcmp(interface->ifa_name, "lo") == 0) { + continue; + } + + if (matchSubnet(*interface, remoteAddr)) { + if (family == AF_INET) { + localAddrs->type = ADDR_IPv4; + memcpy(&localAddrs->addr.addr4, interface->ifa_addr, sizeof(struct sockaddr_in)); + } else { + localAddrs->type = ADDR_IPv6; + memcpy(&localAddrs->addr.addr6, interface->ifa_addr, sizeof(struct sockaddr_in6)); + } + strncpy(ifNames, interface->ifa_name, MAX_IF_NAME_SIZE); + ifNames[MAX_IF_NAME_SIZE] = '\0'; + found = 1; + break; + } + } + } + + freeifaddrs(interfaces); + return (found == 0) ? SHMEM_BOOTSTRAP_ERROR : SHMEM_SUCCESS; +} + +static int bootstrap_get_sock_addr(socket_t* sock, sockaddr_t* addr) { + if (sock == NULL) return SHMEM_BOOTSTRAP_ERROR; + struct sockaddr_storage temp_storage; + memset(&temp_storage, 0, sizeof(temp_storage)); + struct sockaddr* temp_addr = reinterpret_cast(&temp_storage); + socklen_t addr_len = 0; + int ret = socket_get_sainfo(sock, temp_addr, &addr_len); + if (ret != 0) { + return SHMEM_BOOTSTRAP_ERROR; + } + + if (temp_storage.ss_family == AF_INET) { + addr->type = ADDR_IPv4; + const struct sockaddr_in* ipv4_src = reinterpret_cast(&temp_storage); + memcpy(&addr->addr.addr4, ipv4_src, sizeof(struct sockaddr_in)); + } else if (temp_storage.ss_family == AF_INET6) { + addr->type = ADDR_IPv6; + const struct sockaddr_in6* ipv6_src = reinterpret_cast(&temp_storage); + memcpy(&addr->addr.addr6, ipv6_src, sizeof(struct sockaddr_in6)); + } else { + SHM_LOG_ERROR("Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + return SHMEM_SUCCESS; +} + +// Network Initialization (Locating Local Interface Matching Subnet / initialize root node UID information when is_arg_init == false) +static int shmemi_bootstrap_net_init(shmemx_bootstrap_uid_state_t* uid_args, bool is_arg_init = true) { + SHM_LOG_INFO(" Network Initialization, Finding Interfaces Matching Subnets"); + pthread_mutex_lock(&priv_info.bootstrap_netlock); + + if (!is_arg_init) { + SHM_LOG_INFO("net_init uid_args is NULL, get uid arg"); + bool is_from_ifa = false; + + if (env_ip_port != nullptr) { + SHM_LOG_INFO("Environment variable SHMEM_UID_SESSION_ID has been set."); + SHMEM_CHECK_RET(shmemi_get_ip_from_env(uid_args, env_ip_port), + "No available addresses were found with env_ip_port."); + } else { + SHM_LOG_INFO("Environment variable SHMEM_UID_SESSION_ID is not set, automatically obtaining ipPort."); + is_from_ifa = true; + SHMEM_CHECK_RET(shmemi_get_ip_from_ifa(uid_args, env_ifname), + "No available addresses were found with auto."); + } + + SHM_LOG_INFO("Get uid arg success."); + is_arg_init = true; + } + + if (priv_info.bootstrap_netinitdone) { + // Initialized, printing currently saved information + SHM_LOG_INFO(" priv_info already inited: " << " bootstrap_netifname: " << (priv_info.bootstrap_netifname ? priv_info.bootstrap_netifname : "nullptr")); + if (priv_info.bootstrap_netifaddr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &priv_info.bootstrap_netifaddr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv4): " << ip_str << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr4.sin_port)); + } else if (priv_info.bootstrap_netifaddr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &priv_info.bootstrap_netifaddr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv6): " << ip_str << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr6.sin6_port)); + } else { + SHM_LOG_ERROR(" bootstrap_netifaddr: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + pthread_mutex_unlock(&priv_info.bootstrap_netlock); + return SHMEM_SUCCESS; + } + + // Print the root node address to be matched (uid_args->addr) + if (uid_args->addr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &uid_args->addr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert uid_args addr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" Root address (IPv4): " << ip_str << ":" << ntohs(uid_args->addr.addr.addr4.sin_port)); + } else if (uid_args->addr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &uid_args->addr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert uid_args addr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" Root address (IPv6): " << ip_str << ":" << ntohs(uid_args->addr.addr.addr6.sin6_port)); + } else { + SHM_LOG_ERROR(" Root address: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + // Find the local interface that matches the remote address + SHM_LOG_INFO("Trying to find interface matching root address."); + int find_result = find_interface_match_subnet(priv_info.bootstrap_netifname, + &priv_info.bootstrap_netifaddr, + &uid_args->addr); + if (find_result != 0) { + SHM_LOG_ERROR(" Failed to find matching interface."); + pthread_mutex_unlock(&priv_info.bootstrap_netlock); + return SHMEM_BOOTSTRAP_ERROR; + } + + // Print the information of priv_info. + if (priv_info.bootstrap_netifaddr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &priv_info.bootstrap_netifaddr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv4): " << ip_str + << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr4.sin_port)); + } else if (priv_info.bootstrap_netifaddr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &priv_info.bootstrap_netifaddr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv6): " << ip_str + << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr6.sin6_port)); + } else { + SHM_LOG_ERROR(" Root bootstrap_netifaddr: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + priv_info.bootstrap_netinitdone = 1; + pthread_mutex_unlock(&priv_info.bootstrap_netlock); + SHM_LOG_INFO(" Net init success, priv_info.bootstrap_netinitdone = 1"); + return SHMEM_SUCCESS; +} + +static int set_files_limit() { + struct rlimit files_limit, old_limit; + + SHMEM_CHECK_RET(getrlimit(RLIMIT_NOFILE, &old_limit), "getrlimit failed", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_DEBUG("Original file descriptor limit - soft limit: " << old_limit.rlim_cur << ", hard limit: " << old_limit.rlim_max); + + files_limit = old_limit; + files_limit.rlim_cur = files_limit.rlim_max; + SHMEM_CHECK_RET(setrlimit(RLIMIT_NOFILE, &files_limit), "setrlimit failed", SHMEM_BOOTSTRAP_ERROR); + + struct rlimit new_limit; + getrlimit(RLIMIT_NOFILE, &new_limit); + SHM_LOG_DEBUG("Updated file descriptor limit - soft limit: " << new_limit.rlim_cur << ", hard limit: " << new_limit.rlim_max); + + return SHMEM_SUCCESS; +} + +static void* bootstrap_root(void* rargs) { + struct bootstrap_root_args* args = (struct bootstrap_root_args*)rargs; + if (args == NULL || args->listen_sock == NULL) { + SHM_LOG_ERROR("bootstrap_root: invalid args"); + return NULL; + } + + socket_t* listen_sock = args->listen_sock; + uint64_t magic = args->magic; + int root_version = args->version; + int nranks = 0; + int c = 0; // Number of received nodes. + bootstrap_ext_info info; + sockaddr_t* zero_addr = nullptr; + SHMEM_BOOTSTRAP_CALLOC(&zero_addr, 1); + sockaddr_t* rank_addrs = NULL; // Store the common listening addresses of all nodes. + sockaddr_t* rank_addrs_root = NULL; // Store the dedicated root addresses for all nodes. + + if (zero_addr == NULL) { + SHM_LOG_ERROR("bootstrap_root: calloc zero_addr failed"); + SHMEM_BOOTSTRAP_PTR_FREE(args); + return NULL; + } + + // Adjusting file descriptor limits (the root node needs to handle multiple connections) + if (set_files_limit() != 0) { + SHM_LOG_ERROR("bootstrap_root: set_files_limit failed"); + SHMEM_BOOTSTRAP_PTR_FREE(zero_addr); + SHMEM_BOOTSTRAP_PTR_FREE(args); + return NULL; + } + + // Continuously receive connections and information from all slave nodes + while (1) { + socket_t client_sock; + // Initialize client socket (for receiving connections from a single slave node) + if (socket_init(&client_sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, NULL) != 0) { + SHM_LOG_ERROR("bootstrap_root: socket_init failed"); + break; + } + + // Accept connections from the node (blocking wait) + if (socket_accept(&client_sock, listen_sock) != 0) { + SHM_LOG_ERROR("bootstrap_root: socket_accept failed"); + socket_close(&client_sock); + break; + } + + // Version verification + int peer_version; + if (socket_recv(&client_sock, &peer_version, sizeof(peer_version)) != 0) { + SHM_LOG_ERROR("bootstrap_root: recv peer_version failed"); + socket_close(&client_sock); + break; + } + if (socket_send(&client_sock, &root_version, sizeof(root_version)) != 0) { + SHM_LOG_ERROR("bootstrap_root: send root_version failed"); + socket_close(&client_sock); + break; + } + if (peer_version != root_version) { + SHM_LOG_ERROR("bootstrap_root: version mismatch"); + socket_close(&client_sock); + break; + } + + // Receive address information from the node + if (socket_recv(&client_sock, &info, sizeof(info)) != 0) { + SHM_LOG_ERROR("bootstrap_root: recv info failed"); + socket_close(&client_sock); + break; + } + socket_close(&client_sock); + + // Initialize the address array upon first reception + if (c == 0) { + nranks = info.nranks; + if (nranks <= 0) { + SHM_LOG_ERROR("bootstrap_root: invalid nranks"); + break; + } + SHMEM_BOOTSTRAP_CALLOC(&rank_addrs, nranks); + SHMEM_BOOTSTRAP_CALLOC(&rank_addrs_root, nranks); + if (rank_addrs == NULL || rank_addrs_root == NULL) { + SHM_LOG_ERROR("bootstrap_root: calloc addr arrays failed"); + break; + } + } + + if (info.nranks != nranks || info.rank < 0 || info.rank >= nranks) { + SHM_LOG_ERROR("bootstrap_root: invalid info from rank " << info.rank); + break; + } + // Check if the rank is duplicated + if (memcmp(zero_addr, &rank_addrs_root[info.rank], sizeof(sockaddr_t)) != 0) { + SHM_LOG_ERROR("bootstrap_root: duplicate rank " << info.rank); + break; + } + + memcpy(&rank_addrs_root[info.rank], &info.ext_address_listen_root, sizeof(sockaddr_t)); + memcpy(&rank_addrs[info.rank], &info.ext_addr_listen, sizeof(sockaddr_t)); + c++; + + if (c >= nranks) { + SHM_LOG_INFO("bootstrap_root: Address receiving completed"); + break; + } + } + + if (c == nranks && rank_addrs != NULL && rank_addrs_root != NULL) { + SHM_LOG_INFO("bootstrap_root: Start distributing addresses."); + for (int r = 0; r < nranks; r++) { + int next_rank = (r + 1) % nranks; + socket_t send_sock; + + if (socket_init(&send_sock, SOCKET_TYPE_BOOTSTRAP, magic, &rank_addrs_root[r]) != 0) { + SHM_LOG_ERROR("bootstrap_root: init send_sock for rank " << r << " failed"); + break; + } + + if (socket_connect(&send_sock) != 0) { + SHM_LOG_ERROR("bootstrap_root: connect to rank " << r << " failed"); + socket_close(&send_sock); + break; + } + + if (socket_send(&send_sock, &rank_addrs[next_rank], sizeof(sockaddr_t)) != 0) { + SHM_LOG_ERROR("bootstrap_root: send next_addr to rank " << r << " failed"); + socket_close(&send_sock); + break; + } + + socket_close(&send_sock); + } + } + + SHMEM_BOOTSTRAP_PTR_FREE(zero_addr); + SHMEM_BOOTSTRAP_PTR_FREE(rank_addrs); + SHMEM_BOOTSTRAP_PTR_FREE(rank_addrs_root); + if (listen_sock != NULL) { + socket_close(listen_sock); + SHMEM_BOOTSTRAP_PTR_FREE(listen_sock); + } + SHMEM_BOOTSTRAP_PTR_FREE(args); + return NULL; +} + +static int bootstrap_create_root(shmemx_bootstrap_uid_state_t* uid_args) { + if (uid_args == NULL) { + SHM_LOG_ERROR("bootstrap_create_root: invalid uid_args"); + return SHMEM_BOOTSTRAP_ERROR; + } + + // 1. Create a dedicated listening socket for the root node. + socket_t* listen_sock_root = nullptr; + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&listen_sock_root, 1), "bootstrap_create_root: malloc listen_sock_root failed"); + + // 2. Initialize the listening socket (using the global network interface address) + SHMEM_CHECK_RET(socket_init(listen_sock_root, SOCKET_TYPE_BOOTSTRAP, uid_args->magic, &uid_args->addr), "bootstrap_create_root: socket_init failed"); + + SHMEM_CHECK_RET(socket_listen(listen_sock_root), "Listen_sock_root failed while executing listen. fd=" << listen_sock_root->fd); + + // 3. Write the root node's listening address into uid_args (for slave nodes to connect to). + memcpy(&uid_args->addr, &listen_sock_root->addr, sizeof(sockaddr_t)); + + // 4. Prepare thread parameters + struct bootstrap_root_args* args = nullptr; + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&args, 1), "bootstrap_create_root: malloc args failed"); + + args->listen_sock = listen_sock_root; + args->magic = uid_args->magic; + args->version = uid_args->version; + + // 5. Create detached thread + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); + int ret = pthread_create(&priv_info.bootstrap_root, &attr, bootstrap_root, args); + if (ret != 0) { + SHM_LOG_ERROR("bootstrap_create_root: pthread_create failed"); + SHMEM_BOOTSTRAP_PTR_FREE(args); + socket_close(listen_sock_root); + SHMEM_BOOTSTRAP_PTR_FREE(listen_sock_root); + return SHMEM_BOOTSTRAP_ERROR; + } + pthread_attr_destroy(&attr); + return SHMEM_SUCCESS; +} + + + +int shmemi_bootstrap_get_unique_id(void* uid) { + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)uid; + + if (env_ip_port == nullptr) { + const char* envip = std::getenv("SHMEM_UID_SESSION_ID"); + if (envip != nullptr) { + env_ip_port = envip; + SHM_LOG_DEBUG("SHMEM_UID_SESSION_ID is: " << env_ip_port); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SESSION_ID is not set."); + } + } + + if (env_ifname == nullptr) { + const char* envinfo = std::getenv("SHMEM_UID_SOCK_IFNAME"); + if (envinfo != nullptr) { + env_ifname = envinfo; + SHM_LOG_DEBUG("SHMEM_UID_SOCK_IFNAME is: " << env_ifname); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SOCK_IFNAME is not set."); + } + } + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args, false), "rank 0: failed to init bootstrap net."); + SHMEM_CHECK_RET(bootstrap_create_root(uid_args), "rank 0: failed to create root thread"); + return SHMEM_SUCCESS; +} + +// Plugin pre-initialization entry function. +int shmemi_bootstrap_plugin_pre_init(shmemi_bootstrap_handle_t* handle) { + if (handle->pre_init_ops == nullptr) { + SHM_LOG_DEBUG(" bootstrap plugin pre init start."); + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1)); + handle->pre_init_ops->get_unique_id = shmemi_bootstrap_get_unique_id; + handle->pre_init_ops->cookie = nullptr; + SHM_LOG_DEBUG(" bootstrap plugin pre init end."); + } else { + SHM_LOG_DEBUG(" pre_init_ops had already prepared."); + } + return SHMEM_SUCCESS; +} + + +int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) { + + if (comm == nullptr || handle == nullptr) { + SHM_LOG_ERROR(" shmemi_bootstrap_plugin_init: invalid arguments (nullptr)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + socket_t sock, listen_sock_root; + uid_bootstrap_state* state = nullptr; + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&state, 1)); + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)comm; + sockaddr_t next_addr; + bootstrap_ext_info info = {}; + + int rank = uid_args->rank; + int nranks = uid_args->nranks; + uint64_t magic = uid_args->magic; + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args), " rank: " << rank << ": network interface init failed."); + + if (state == nullptr) { + SHM_LOG_ERROR(" rank: " << rank << ": failed to allocate uid_bootstrap_state"); + return SHMEM_BOOTSTRAP_ERROR; + } + + state->rank = rank; + state->nranks = nranks; + state->magic = magic; + + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&state->peer_addrs, nranks)); + + if (state->peer_addrs == nullptr) { + SHM_LOG_ERROR(" rank: " << rank << ": failed to allocate peer_addrs"); + SHMEM_BOOTSTRAP_PTR_FREE(state); + return SHMEM_BOOTSTRAP_ERROR; + } + + handle->bootstrap_state = state; + handle->mype = rank; + handle->npes = nranks; + + SHMEM_CHECK_RET(socket_init(&state->listen_sock, SOCKET_TYPE_BOOTSTRAP, state->magic, &priv_info.bootstrap_netifaddr), "State's listen_sock failed while executing init. fd=" << state->listen_sock.fd); + SHMEM_CHECK_RET(socket_listen(&state->listen_sock), "State's listen_sock failed while executing listen. fd=" << state->listen_sock.fd); + SHMEM_CHECK_RET(bootstrap_get_sock_addr(&state->listen_sock, &info.ext_addr_listen), "Get addr failed, the listen_sock in state maybe null. fd=" << state->listen_sock.fd); + + SHMEM_CHECK_RET(socket_init(&listen_sock_root, SOCKET_TYPE_BOOTSTRAP, state->magic, &priv_info.bootstrap_netifaddr), "Listen_sock_root failed while executing init. fd=" << listen_sock_root.fd); + SHMEM_CHECK_RET(socket_listen(&listen_sock_root), "listen_sock_root failed while executing listen. fd=" << listen_sock_root.fd); + SHMEM_CHECK_RET(bootstrap_get_sock_addr(&listen_sock_root, &info.ext_address_listen_root), "Get addr failed, the listen_sock_root maybe null. fd=" << listen_sock_root.fd); + + + SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, magic, &uid_args->addr), "Sock failed while executing init. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_connect(&sock), "Sock failed while executing connect. fd=" << sock.fd); + int peer_version = uid_args->version; + int root_version; + SHMEM_CHECK_RET(socket_send(&sock, &peer_version, sizeof(peer_version)), "Sock failed while executing send peer_version. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_recv(&sock, &root_version, sizeof(root_version)), "Sock failed while executing recv root_version. fd=" << sock.fd); + SHMEM_CHECK_RET(peer_version != root_version, " rank: " << rank << " . version mismatch with root", SHMEM_SMEM_ERROR); + + info.rank = rank; + info.nranks = nranks; + + if (info.ext_addr_listen.type == ADDR_IPv4) { + struct sockaddr_in* ipv4 = &info.ext_addr_listen.addr.addr4; + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_addr_listen ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv4->sin_port); + SHM_LOG_INFO(" Ext_addr_listen socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); + + } else if (info.ext_addr_listen.type == ADDR_IPv6) { + struct sockaddr_in6* ipv6 = &info.ext_addr_listen.addr.addr6; + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_addr_listen ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv6->sin6_port); + SHM_LOG_INFO(" Ext_addr_listen socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv6->sin6_family); + } else { + SHM_LOG_ERROR(" Ext_address_listen_root socket: Type: Unknown address type is not within IPv4 or IPv6. (type=" << info.ext_addr_listen.type << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (info.ext_address_listen_root.type == ADDR_IPv4) { + struct sockaddr_in* ipv4 = &info.ext_address_listen_root.addr.addr4; + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_address_listen_root ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv4->sin_port); + SHM_LOG_INFO(" Ext_address_listen_root socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); + + } else if (info.ext_address_listen_root.type == ADDR_IPv6) { + struct sockaddr_in6* ipv6 = &info.ext_address_listen_root.addr.addr6; + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_address_listen_root ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv6->sin6_port); + SHM_LOG_INFO(" Ext_address_listen_root socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv6->sin6_family); + } else { + SHM_LOG_ERROR(" Ext_address_listen_root socket: Type: Unknown address type is not within IPv4 or IPv6. (type=" << info.ext_address_listen_root.type << ")"); + return SHMEM_BOOTSTRAP_ERROR; + + } + + + SHMEM_CHECK_RET(socket_send(&sock, &info, sizeof(info)), "Sock failed while executing send info. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_close(&sock), "Sock failed while executing close. fd=" << sock.fd); + + + SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, nullptr), "Sock failed while executing init. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_accept(&sock, &listen_sock_root), "Sock failed while executing accept listen_sock_root. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_recv(&sock, &next_addr, sizeof(next_addr)), "Sock failed while executing recv next_addr. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_close(&sock), "Sock failed while executing close. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_close(&listen_sock_root), "Listen_sock_root failed while executing close. fd=" << listen_sock_root.fd); + + + if (next_addr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &next_addr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert next_addr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(next_addr.addr.addr4.sin_port); + SHM_LOG_INFO(" Received next socket: Type: IPv4, IP: " << ip_str << ", Port: " << port); + } else if (next_addr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &next_addr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert next_addr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(next_addr.addr.addr6.sin6_port); + SHM_LOG_INFO(" Received next socket: Type: IPv6, IP: " << ip_str << ", Port: " << port); + } else { + SHM_LOG_ERROR(" Received next socket: Type: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + // Initialize ring send socket + SHMEM_CHECK_RET(socket_init(&state->ring_send_sock, SOCKET_TYPE_BOOTSTRAP, magic, &next_addr), "State's ring_send_sock failed while executing init. fd=" << state->ring_send_sock.fd); + SHMEM_CHECK_RET(socket_connect(&state->ring_send_sock), "State's ring_send_sock failed while executing connect. fd=" << state->ring_send_sock.fd); + SHMEM_CHECK_RET(socket_init(&state->ring_recv_sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, nullptr), "State's ring_recv_sock failed while executing init. fd=" << state->ring_recv_sock.fd); + SHMEM_CHECK_RET(socket_accept(&state->ring_recv_sock, &state->listen_sock),"State's ring_recv_sock failed while executing accept State's listen_sock. fd=" << state->ring_recv_sock.fd); + SHMEM_CHECK_RET(bootstrap_get_sock_addr(&state->listen_sock, state->peer_addrs + handle->mype), "Get addr failed, the listen_sock in state maybe null. fd=" << state->listen_sock.fd); + + SHMEM_CHECK_RET(shmemi_bootstrap_uid_allgather(BOOTSTRAP_IN_PLACE, state->peer_addrs, sizeof(sockaddr_t), handle), "Bootstrap_uid_allgather failed"); + + handle->allgather = shmemi_bootstrap_uid_allgather; + handle->barrier = shmemi_bootstrap_uid_barrier; + handle->finalize = shmemi_bootstrap_uid_finalize; + handle->alltoall = nullptr; + handle->global_exit = nullptr; + + SHM_LOG_INFO("rank " << rank << ": bootstrap plugin initialized successfully"); + return SHMEM_SUCCESS; +} \ No newline at end of file diff --git a/src/modules/bootstrap/socket/uid_socket.cpp b/src/modules/bootstrap/socket/uid_socket.cpp new file mode 100644 index 00000000..3103ee86 --- /dev/null +++ b/src/modules/bootstrap/socket/uid_socket.cpp @@ -0,0 +1,565 @@ + +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include +#include +#include "uid_socket.h" + +static int socket_progress(int op, socket_t* sock, void* ptr, int size, int* offset, bool block = false, bool state_check = true) { + if (sock == nullptr || ptr == nullptr || offset == nullptr || size < 0 || *offset < 0 || *offset > size) { + SHM_LOG_ERROR("Invalid arguments: sock=" << sock << ", ptr=" << ptr + << ", size=" << size << ", offset=" << *offset); + return SHMEM_BOOTSTRAP_ERROR; + } + if (state_check && sock->state != SOCKET_STATE_READY) { + SHM_LOG_ERROR("socket_progress: invalid state " << sock->state << " (expected READY)"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + int bytes = 0; + int closed = 0; + char* data = (char*)(ptr); + SHM_LOG_DEBUG("socket_progress: start"); + do { + if (op == SOCKET_TYPE_RECV) { + int flags = block ? 0 : MSG_DONTWAIT; + SHM_LOG_DEBUG("Executing RECV operation - fd: " << sock->fd << ", buffer offset: " << *offset << ", remaining size: " << (size - *offset) << ", flags: " << flags); + bytes = recv(sock->fd, data + *offset, size - *offset, flags); + SHM_LOG_DEBUG("RECV result - bytes received: " << bytes); + } else if (op == SOCKET_TYPE_SEND) { + int flags = block ? MSG_NOSIGNAL : (MSG_DONTWAIT | MSG_NOSIGNAL); + SHM_LOG_DEBUG("Executing SEND operation - fd: " << sock->fd << ", buffer offset: " << *offset << ", remaining size: " << (size - *offset) << ", flags: " << flags); + bytes = send(sock->fd, data + *offset, size - *offset, flags); + SHM_LOG_DEBUG("SEND result - bytes sent: " << bytes); + } else { + SHM_LOG_ERROR("Invalid operation type: " << op); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (op == SOCKET_TYPE_RECV && bytes == 0) { + SHM_LOG_DEBUG("RECV operation got 0 bytes - remote peer closed the connection (fd: " << sock->fd << ")"); + closed = 1; + break; + } + + if (bytes == -1) { + int err = errno; + if (err != EINTR && err != EWOULDBLOCK && err != EAGAIN) { + SHM_LOG_ERROR("Socket operation failed (fd: " << sock->fd << ", op: " << op << ") - error: " << strerror(err) << " (errno: " << err << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } else { + SHM_LOG_DEBUG("Socket operation would block (fd: " << sock->fd << ", op: " << op << ") - errno: " << err << ", setting bytes to 0"); + bytes = 0; + } + } + + *offset += bytes; + SHM_LOG_DEBUG("Updated buffer offset - current offset: " << *offset << ", total size: " << size); + } while (bytes > 0 && *offset < size); + + if (closed) { + SHM_LOG_ERROR("Loop exited - remote connection closed (fd: " << sock->fd << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_progress: success"); + + return SHMEM_SUCCESS; +} + +static int socket_wait(int op, socket_t* sock, void* ptr, int size, int* offset, bool block = false, bool state_check = true) { + while (*offset < size) + if (socket_progress(op, sock, ptr, size, offset, block, state_check) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("socket_wait fail!"); + return SHMEM_BOOTSTRAP_ERROR; + } + return SHMEM_SUCCESS; +} + +int socket_send(socket_t* sock, void* ptr, int size) { + SHM_LOG_DEBUG("socket_send: start"); + int offset = 0; + if (sock == NULL || ptr == NULL || size <= 0 ) { + SHM_LOG_ERROR("send sock == NULL"); + return SHMEM_BOOTSTRAP_ERROR; + } + + return socket_wait(SOCKET_TYPE_SEND, sock, ptr, size, &offset); +} + +int socket_recv(socket_t* sock, void* ptr, int size) { + SHM_LOG_DEBUG("socket_recv: start"); + int offset = 0; + if (sock == NULL) { + SHM_LOG_ERROR("recv sock == NULL"); + return SHMEM_BOOTSTRAP_ERROR; + } + return socket_wait(SOCKET_TYPE_RECV, sock, ptr, size, &offset); +} + + +int socket_close(socket_t* sock) { + if (sock) { + if (sock->fd >= 0) { + shutdown(sock->fd, SHUT_RDWR); + close(sock->fd); + } + sock->fd = -1; + sock->accept_fd = -1; + sock->state = SOCKET_STATE_CLOSED; + } else { + SHM_LOG_DEBUG("socket_close: sock is null"); + } + SHM_LOG_DEBUG("socket_close: success"); + return SHMEM_SUCCESS; +} + +int socket_get_sainfo(socket_t* sock, sockaddr* sa, socklen_t* addr_len) { + if (sock == nullptr || sa == nullptr || addr_len == nullptr) { + SHM_LOG_ERROR("Some of sock, sa and addr_len are null."); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->addr.type == ADDR_IPv4) { + SHM_LOG_DEBUG("socket_get_sainfo memcpy addr4"); + memcpy(sa, &sock->addr.addr.addr4, sizeof(struct sockaddr_in)); + *addr_len = sizeof(struct sockaddr_in); + } else { + SHM_LOG_DEBUG("socket_get_sainfo memcpy addr6"); + memcpy(sa, &sock->addr.addr.addr6, sizeof(struct sockaddr_in6)); + *addr_len = sizeof(struct sockaddr_in6); + } + + return SHMEM_SUCCESS; +} + + +int socket_listen(socket_t* sock) { + if (!sock || sock->fd < 0 || sock->state == SOCKET_STATE_ERROR) { + SHM_LOG_ERROR("socket_listen Precondition failed! " + << "sock is null: " << (sock == nullptr) + << ", invalid fd: " << (sock ? (sock->fd < 0) : true) + << ", state is error: " << (sock ? (sock->state == SOCKET_STATE_ERROR) : false)); + if (sock) sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_INFO("socket_listen Entering. sock fd: " << (sock ? sock->fd : -1) + << ", current state: " << (sock ? sock->state : -1)); + + if (sock->state == SOCKET_STATE_CREATED) { + SHM_LOG_DEBUG("socket_listen State is CREATED, starting bind process"); + struct sockaddr_storage sa_storage; + memset(&sa_storage, 0, sizeof(sa_storage)); + struct sockaddr* sa = reinterpret_cast(&sa_storage); + socklen_t addr_len; + + SHMEM_CHECK_RET(socket_get_sainfo(sock, sa, &addr_len),"socket_listen socket_get_sainfo failed"); + + + std::string target_ip = "unknown"; + uint16_t target_port = 0; + if (sa->sa_family == AF_INET) { + struct sockaddr_in* ipv4 = reinterpret_cast(sa); + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + target_ip = ip_str; + target_port = ntohs(ipv4->sin_port); + } else if (sa->sa_family == AF_INET6) { + struct sockaddr_in6* ipv6 = reinterpret_cast(sa); + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + target_ip = ip_str; + target_port = ntohs(ipv6->sin6_port); + } + SHM_LOG_DEBUG("socket_listen socket_get_sainfo succeeded, addr_len: " << addr_len + << ", target IP: " << target_ip << ", target port: " << target_port); + + int opt = 1; + if (setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + SHM_LOG_ERROR("socket_listen setsockopt(SO_REUSEADDR) failed! " + << "errno: " << errno << ", reason: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_listen setsockopt(SO_REUSEADDR) succeeded"); + + if (bind(sock->fd, sa, addr_len) < 0) { + SHM_LOG_ERROR("socket_listen bind failed! " + << "errno: " << errno << ", reason: " << strerror(errno) + << ", fd: " << sock->fd << ", addr_len: " << addr_len + << ", target IP: " << target_ip << ", target port: " << target_port); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("[socket_listen] bind succeeded"); + + if (getsockname(sock->fd, &sock->addr.addr.sa, &addr_len) < 0) { + SHM_LOG_ERROR("socket_listen getsockname failed! " + << "errno: " << errno << ", reason: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->addr.type == ADDR_IPv4) { + struct sockaddr_in* ipv4 = &sock->addr.addr.addr4; + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv4->sin_port); + SHM_LOG_DEBUG(" Stored IPv4 address: " << ip_str << ":" << port << " sa_family: " << ipv4->sin_family << " (expected AF_INET=" << AF_INET << ")"); + } else if (sock->addr.type == ADDR_IPv6) { + struct sockaddr_in6* ipv6 = &sock->addr.addr.addr6; + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv6->sin6_port); + SHM_LOG_DEBUG(" Stored IPv6 address: " << ip_str << ":" << port << " sa_family: " << ipv6->sin6_family << " (expected AF_INET6=" << AF_INET6 << ")"); + } else { + SHM_LOG_ERROR(" Stored address type: unknown (type=" << sock->addr.type << ")"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + sock->state = SOCKET_STATE_BOUND; + SHM_LOG_DEBUG("socket_listen State updated to BOUND"); + } + + if (sock->state == SOCKET_STATE_BOUND) { + SHM_LOG_DEBUG("socket_listen State is BOUND, starting listen"); + if (listen(sock->fd, SOCKET_BACKLOG) < 0) { + SHM_LOG_ERROR("socket_listen] listen failed! " + << "errno: " << errno << ", reason: " << strerror(errno) + << ", fd: " << sock->fd << ", backlog: " << SOCKET_BACKLOG); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + sock->accept_fd = sock->fd; + sock->state = SOCKET_STATE_LISTENING; + SHM_LOG_DEBUG("socket_listen listen succeeded. New state: LISTENING, accept_fd: " << sock->accept_fd); + } else { + SHM_LOG_ERROR("socket_listen Skip listen: current state is " << sock->state << " (expected BOUND)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + SHM_LOG_DEBUG("socket_listen Exiting with success"); + return SHMEM_SUCCESS; +} + +static int socket_try_accept(socket_t* sock) { + if (sock->state != SOCKET_STATE_ACCEPTING) { + SHM_LOG_ERROR("socket_try_accept: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + + struct sockaddr sa; + socklen_t socklen = sizeof(sa); + + sock->fd = accept(sock->accept_fd, &sa, &socklen); + if (sock->fd != -1) { + if (sa.sa_family == AF_INET) { + sock->addr.type = ADDR_IPv4; + memcpy(&sock->addr.addr.addr4, &sa, sizeof(struct sockaddr_in)); + } else { + sock->addr.type = ADDR_IPv6; + memcpy(&sock->addr.addr.addr6, &sa, sizeof(struct sockaddr_in6)); + } + sock->state = SOCKET_STATE_ACCEPTED; + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + SHM_LOG_ERROR("socket_try_accept failed: " << strerror(errno)); + return SHMEM_BOOTSTRAP_ERROR; + } + + return SHMEM_SUCCESS; +} + +static int socket_finalize_accept(socket_t* sock) { + if (sock->state != SOCKET_STATE_ACCEPTED) { + SHM_LOG_ERROR("socket_finalize_accept: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + + uint64_t magic; + socket_type_t type; + int received = 0; + const int one = 1; + + if (setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) < 0) { + SHM_LOG_ERROR("setsockopt TCP_NODELAY failed: " << strerror(errno)); + close(sock->fd); + sock->fd = -1; + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + if (socket_progress(SOCKET_TYPE_RECV, sock, &magic, sizeof(magic), &received, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (received == 0) return SHMEM_SUCCESS; + if (socket_wait(SOCKET_TYPE_RECV, sock, &magic, sizeof(magic), &received, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + + if (magic != sock->magic) { + SHM_LOG_DEBUG("socket_finalize_accept: wrong magic " << magic << " != " << sock->magic); + close(sock->fd); + sock->fd = -1; + sock->state = SOCKET_STATE_ACCEPTING; + return SHMEM_SUCCESS; + } + + received = 0; + if (socket_wait(SOCKET_TYPE_RECV, sock, &type, sizeof(type), &received, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (type != sock->type) { + SHM_LOG_ERROR("socket_finalize_accept: wrong type " << type << " != " << sock->type); + close(sock->fd); + sock->fd = -1; + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + sock->state = SOCKET_STATE_READY; + return SHMEM_SUCCESS; +} + +static int socket_start_connect(socket_t* sock) { + if (sock->state != SOCKET_STATE_CONNECTING) { + SHM_LOG_ERROR("socket_start_connect: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + + struct sockaddr_storage sa_storage; + memset(&sa_storage, 0, sizeof(sa_storage)); + struct sockaddr* sa = reinterpret_cast(&sa_storage); + socklen_t addr_len; + if (socket_get_sainfo(sock, sa, &addr_len) != 0) { + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + int ret = connect(sock->fd, sa, addr_len); + if (ret == 0) { + sock->state = SOCKET_STATE_CONNECTED; + SHM_LOG_DEBUG("socket_start_connect: success!"); + } else if (errno == ECONNREFUSED) { + SHM_LOG_DEBUG("socket_start_connect: refused retry time:" << sock->refused_retries); + if (++sock->refused_retries >= RETRY_REFUSED_TIMES) { + SHM_LOG_ERROR("exceeded refused retries"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + usleep(SLEEP_INT); + } else if (errno == ETIMEDOUT) { + SHM_LOG_DEBUG("socket_start_connect: timeout retry time:" << sock->timeout_retries); + if (++sock->timeout_retries >= RETRY_TIMEDOUT_TIMES) { + SHM_LOG_ERROR("exceeded timeout retries"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + usleep(SLEEP_INT); + } else { + SHM_LOG_ERROR("connect failed: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_start_connect: end!"); + + return SHMEM_SUCCESS; +} + + +static int socket_finalize_connect(socket_t* sock) { + SHM_LOG_DEBUG("socket_finalize_connect socket_finalize_connect: start!"); + if (sock->state != SOCKET_STATE_CONNECTED) { + SHM_LOG_ERROR("socket_finalize_connect: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + + int sent = 0; + if (socket_progress(SOCKET_TYPE_SEND, sock, &sock->magic, sizeof(sock->magic), &sent, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (sent == 0) return SHMEM_SUCCESS; + if (socket_wait(SOCKET_TYPE_SEND, sock, &sock->magic, sizeof(sock->magic), &sent, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + + sent = 0; + if (socket_wait(SOCKET_TYPE_SEND, sock, &sock->type, sizeof(sock->type), &sent, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_finalize_connect socket_finalize_connect: end!"); + + sock->state = SOCKET_STATE_READY; + return SHMEM_SUCCESS; +} + +static int socket_progress_state(socket_t* sock) { + if (sock == nullptr) { + SHM_LOG_ERROR("socket_progress_state: null socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->state == SOCKET_STATE_ACCEPTING) { + SHMEM_CHECK_RET(socket_try_accept(sock), "socket_try_accept failed"); + } + if (sock->state == SOCKET_STATE_ACCEPTED) { + SHMEM_CHECK_RET(socket_finalize_accept(sock), "socket_finalize_accept failed"); + } + if (sock->state == SOCKET_STATE_CONNECTING) { + SHMEM_CHECK_RET(socket_start_connect(sock), "socket_start_connect failed"); + } + + if (sock->state == SOCKET_STATE_CONNECTED) { + SHMEM_CHECK_RET(socket_finalize_connect(sock), "socket_finalize_connect failed"); + } + + return SHMEM_SUCCESS; +} + +int socket_connect(socket_t* sock) { + if (sock == nullptr) { + SHM_LOG_ERROR("socket_connect: NULL socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + if (sock->fd == -1) { + SHM_LOG_ERROR("socket_connect: invalid fd (-1)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->state != SOCKET_STATE_CREATED) { + SHM_LOG_ERROR("socket_connect: invalid state " << sock->state << " (expected CREATED)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + const int one = 1; + // Disabling the Nagle algorithm + if (setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) < 0) { + SHM_LOG_ERROR("setsockopt TCP_NODELAY failed: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + sock->state = SOCKET_STATE_CONNECTING; + SHM_LOG_DEBUG("socket_connect: start!"); + do { + if (socket_progress_state(sock) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + } while (sock->state == SOCKET_STATE_CONNECTING || + sock->state == SOCKET_STATE_CONNECTED); + + switch (sock->state) { + case SOCKET_STATE_READY: + return SHMEM_SUCCESS; + case SOCKET_STATE_ERROR: + return SHMEM_BOOTSTRAP_ERROR; + default: + return SHMEM_BOOTSTRAP_ERROR; + } +} + +int socket_accept(socket_t* client_sock, socket_t* listen_sock) { + if (listen_sock == nullptr || client_sock == nullptr) { + SHM_LOG_ERROR("socket_accept: NULL socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (listen_sock->state != SOCKET_STATE_LISTENING) { + SHM_LOG_ERROR("socket_accept: listen socket state " << listen_sock->state << " (expected LISTENING)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (client_sock->accept_fd == -1) { + client_sock->addr = listen_sock->addr; + client_sock->magic = listen_sock->magic; + client_sock->type = listen_sock->type; + client_sock->refused_retries = 0; + client_sock->timeout_retries = 0; + client_sock->accept_fd = listen_sock->fd; + client_sock->fd = -1; + client_sock->state = SOCKET_STATE_ACCEPTING; + } + SHM_LOG_DEBUG("socket_accept: start!"); + do { + if (socket_progress_state(client_sock) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + } while (client_sock->state == SOCKET_STATE_ACCEPTING || + client_sock->state == SOCKET_STATE_ACCEPTED); + + switch (client_sock->state) { + case SOCKET_STATE_READY: + return SHMEM_SUCCESS; + case SOCKET_STATE_ERROR: + return SHMEM_BOOTSTRAP_ERROR; + default: + return SHMEM_BOOTSTRAP_ERROR; + } +} + +int socket_init(socket_t* sock, socket_type_t type, uint64_t magic, const sockaddr_t* init_addr) { + if (sock == nullptr) { + SHM_LOG_ERROR("socket_init: NULL socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_init: start"); + memset(sock, 0, sizeof(socket_t)); + sock->fd = -1; + sock->accept_fd = -1; + sock->state = SOCKET_STATE_CREATED; + sock->type = type; + sock->magic = magic; + sock->refused_retries = 0; + sock->timeout_retries = 0; + + if (init_addr != nullptr) { + int family; + if (init_addr->type == ADDR_IPv4) { + family = AF_INET; + memcpy(&sock->addr.addr.addr4, &init_addr->addr.addr4, sizeof(struct sockaddr_in)); + } else if (init_addr->type == ADDR_IPv6) { + family = AF_INET6; + memcpy(&sock->addr.addr.addr6, &init_addr->addr.addr6, sizeof(struct sockaddr_in6)); + } else { + SHM_LOG_ERROR("socket_init: unsupported address type " << init_addr->type); + return SHMEM_BOOTSTRAP_ERROR; + } + sock->addr.type = init_addr->type; + + sock->fd = socket(family, SOCK_STREAM, 0); + if (sock->fd == -1) { + SHM_LOG_ERROR("socket_init: create socket failed: " << strerror(errno)); + return SHMEM_BOOTSTRAP_ERROR; + } + } else { + SHM_LOG_DEBUG("socket_init: init_addr is null"); + memset(&sock->addr, 0, sizeof(sock->addr)); + sock->addr.type = ADDR_IPv4; + } + + // set blocking + if (sock->fd >= 0) { + int32_t value = 1; + if ((value = fcntl(sock->fd, F_GETFL)) == -1) { + SHM_LOG_ERROR("sock: " << sock->fd <<" failed to get control value"); + return SHMEM_BOOTSTRAP_ERROR; + } + int new_flags = value & ~O_NONBLOCK; + if (fcntl(sock->fd, F_SETFL, new_flags) == -1) { + SHM_LOG_ERROR("sock: " << sock->fd << "Failed to set control value of link"); + return SHMEM_BOOTSTRAP_ERROR; + } + } + + SHM_LOG_DEBUG("socket_init: success"); + return SHMEM_SUCCESS; +} \ No newline at end of file diff --git a/src/modules/bootstrap/socket/uid_socket.h b/src/modules/bootstrap/socket/uid_socket.h new file mode 100644 index 00000000..faab6093 --- /dev/null +++ b/src/modules/bootstrap/socket/uid_socket.h @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEM_SOCKET_H +#define SHMEM_SOCKET_H + +#include "host/shmem_host_def.h" +#include "common/shmemi_logger.h" +#include "common/shmemi_host_types.h" +#include "internal/host/shmemi_host_def.h" +#include "bootstrap/shmemi_bootstrap.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define MAX_IF_NAME_SIZE 16 +#define SOCKET_TYPE_SEND 0 +#define SOCKET_TYPE_RECV 1 + +#define RETRY_REFUSED_TIMES 50 +#define RETRY_TIMEDOUT_TIMES 50 +#define SLEEP_INT 1000 // 重试间隔(微秒) + +#define SOCKET_BACKLOG 16384 + +typedef enum { + // 初始状态:套接字刚创建,未执行任何操作 + SOCKET_STATE_CREATED, + + // 服务器端状态 + SOCKET_STATE_BOUND, // 已绑定地址(bind 成功) + SOCKET_STATE_LISTENING, // 正在监听连接(listen 成功) + SOCKET_STATE_ACCEPTING, // 正在等待接受连接(准备调用 accept) + SOCKET_STATE_ACCEPTED, // 已接受连接(accept 成功,未完成验证) + + // 客户端状态 + SOCKET_STATE_CONNECTING, // 正在发起连接(connect 调用中) + SOCKET_STATE_CONNECTED, // 已建立连接(未完成验证) + + // 公共状态 + SOCKET_STATE_READY, // 连接已验证,可进行数据传输(最终就绪状态) + SOCKET_STATE_ERROR, // 发生错误 + SOCKET_STATE_CLOSED // 已关闭 +} socket_state_t; + +typedef enum { + SOCKET_TYPE_BOOTSTRAP, // 用于初始化信息交换 + SOCKET_TYPE_DATA // 用于实际数据传输 +} socket_type_t; + +typedef struct { + int fd; // 套接字fd + int accept_fd; // 监听用fd,初始化-1 + sockaddr_t addr; // 存储地址信息(socket_init阶段初始化) + socket_state_t state; + uint64_t magic; + socket_type_t type; + int refused_retries; // 连接被拒绝重试计数 + int timeout_retries; // 超时重试计数 +} socket_t; + +struct bootstrap_root_args { + socket_t* listen_sock; + uint64_t magic; + int version; +}; + +// 其他内部结构体 +typedef struct { + int rank; + int nranks; + sockaddr_t ext_addr_listen; + sockaddr_t ext_address_listen_root; +} bootstrap_ext_info; + +struct bootstrap_netstate { + char bootstrap_netifname[MAX_IF_NAME_SIZE + 1]; /* Socket Interface Name */ + sockaddr_t bootstrap_netifaddr; /* Socket Interface Address */ + int bootstrap_netinitdone = 0; /* Socket Interface Init Status */ + pthread_mutex_t bootstrap_netlock = PTHREAD_MUTEX_INITIALIZER; /* Socket Interface Lock */ + pthread_t bootstrap_root; /* Socket Root Thread for phoning root to non-root peers */ +}; + +typedef struct { + int rank; + int nranks; + uint64_t magic; + socket_t listen_sock; + socket_t ring_send_sock; + socket_t ring_recv_sock; + sockaddr_t* peer_addrs; +} uid_bootstrap_state; + +int socket_init(socket_t* sock, socket_type_t type, uint64_t magic, const sockaddr_t* init_addr); +int socket_listen(socket_t* sock); +int socket_connect(socket_t* sock); +int socket_accept(socket_t* client_sock, socket_t* listen_sock); +int socket_send(socket_t* sock, void* ptr, int size); +int socket_recv(socket_t* sock, void* ptr, int size); +int socket_close(socket_t* sock); +int socket_get_sainfo(socket_t* sock, sockaddr* sa, socklen_t* addr_len); + +#ifdef __cplusplus +} +#endif +#endif // SHMEM_SOCKET_H \ No newline at end of file diff --git a/src/modules/bootstrap/socket/uid_utils.h b/src/modules/bootstrap/socket/uid_utils.h new file mode 100644 index 00000000..5159b9bc --- /dev/null +++ b/src/modules/bootstrap/socket/uid_utils.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef SHMEM_UID_UTILS_H +#define SHMEM_UID_UTILS_H + + +#include "common/shmemi_logger.h" +#include // malloc, free +#include // memset + +template +inline int bootstrap_calloc(T** ptr, size_t nelem, const char* file, int line) { + if (ptr == nullptr || nelem == 0) { // 校验输入:指针为空或元素数为 0 均为无效 + SHM_LOG_ERROR("Invalid arguments: ptr=" << ptr << ", nelem=" << nelem + << " (" << file << ":" << line << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } + size_t total_size = nelem * sizeof(T); // 计算总内存大小 + void* p = malloc(total_size); + if (p == nullptr) { + SHM_LOG_ERROR("Allocation failed: " << total_size << " bytes (nelem=" << nelem + << ") at " << file << ":" << line); + return SHMEM_BOOTSTRAP_ERROR; + } + + memset(p, 0, total_size); // 内存清零 + *ptr = static_cast(p); // 类型转换,赋值给输出指针 + + // 调试日志:输出分配信息 + SHM_LOG_DEBUG("Allocated " << total_size << " bytes (" << nelem + << " elements of " << sizeof(T) << " bytes) at " + << static_cast(p) << " (" << file << ":" << line << ")"); + return SHMEM_SUCCESS; +} +#define SHMEM_BOOTSTRAP_CALLOC(ptr, nelem) \ + bootstrap_calloc((ptr), (nelem), __FILE__, __LINE__) + + +#define SHMEM_BOOTSTRAP_PTR_FREE(ptr) \ + do { \ + if ((ptr) != NULL) { \ + free(ptr); \ + } \ + } while (0) + +#endif //SHMEM_UID_UTILS_H \ No newline at end of file -- Gitee From 32a957cc90a71a2a87839acf90bec1b0756035cf Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Mon, 17 Nov 2025 10:23:01 +0800 Subject: [PATCH 54/74] Add RDMA option and fix compile issue. --- CMakeLists.txt | 2 + examples/CMakeLists.txt | 13 +++- examples/rdma_demo/main.cpp | 1 + examples/rdma_perftest/main.cpp | 4 ++ scripts/build.sh | 4 ++ src/CMakeLists.txt | 1 - .../default/shmemi_init_default.cpp | 3 +- .../default/shmemi_init_default.h | 1 + src/host/transport/shmemi_transport.cpp | 71 +++++++++++-------- src/host/transport/shmemi_transport.h | 2 +- .../transport/rdma/device_qp_manager.cpp | 2 - .../transport/rdma/device_qp_manager.h | 1 - 12 files changed, 65 insertions(+), 40 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4aef7718..a19109db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,8 @@ message(STATUS "USE_UNIT_TEST:${USE_UNIT_TEST}") message(STATUS "USE_EXAMPLES:${USE_EXAMPLES}") option(USE_FUZZ_TEST "USE_FUZZ_TEST" OFF) message(STATUS "USE_FUZZ_TEST:${USE_FUZZ_TEST}") +option(SHMEM_RDMA_SUPPORT "SHMEM_RDMA_SUPPORT" OFF) +message(STATUS "SHMEM_RDMA_SUPPORT:${SHMEM_RDMA_SUPPORT}") set(CMAKE_COMPILER bisheng) set(CMAKE_C_COMPILER ${CMAKE_COMPILER}) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4fd3339d..e3a43512 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -55,8 +55,15 @@ endfunction() foreach(EXAMPLE allgather # matmul_allreduce - rdma_perftest - rdma_demo ) add_subdirectory(${EXAMPLE}) -endforeach() \ No newline at end of file +endforeach() + +if(SHMEM_RDMA_SUPPORT) + foreach(EXAMPLE + rdma_perftest + rdma_demo + ) + add_subdirectory(${EXAMPLE}) +endforeach() +endif() \ No newline at end of file diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index fd14fd7f..30771125 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -37,6 +37,7 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint8_t *ptr = (uint8_t*)shmem_malloc(1024); diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp index 446b22b6..5c969464 100644 --- a/examples/rdma_perftest/main.cpp +++ b/examples/rdma_perftest/main.cpp @@ -45,6 +45,7 @@ int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uin shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); @@ -90,6 +91,7 @@ int test_shmem_rdma_postsend_cost(int rank_id, int n_ranks, uint64_t local_mem_s shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); @@ -133,6 +135,7 @@ int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_me shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); @@ -174,6 +177,7 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); shmem_mte_set_ub_params(0, 128 * 1024, 0); diff --git a/scripts/build.sh b/scripts/build.sh index 9ce9d8ee..284a0342 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -275,6 +275,10 @@ while [[ $# -gt 0 ]]; do COMPILE_OPTIONS="${COMPILE_OPTIONS} -DUSE_EXAMPLES=ON" shift ;; + -enable_rdma) + COMPILE_OPTIONS="${COMPILE_OPTIONS} -DSHMEM_RDMA_SUPPORT=ON" + shift + ;; -python_extension) PYEXPAND_TYPE=ON shift diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 93b37f69..f778cd2b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -100,7 +100,6 @@ install(TARGETS shmem_bootstrap_uid LIBRARY DESTINATION lib) -set(SHMEM_RDMA_SUPPORT ON) if(SHMEM_RDMA_SUPPORT) add_library( shmem_transport_rdma SHARED diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index 908f0713..2d92ebb6 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -14,6 +14,7 @@ shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr) { mype = attr->my_rank; npes = attr->n_ranks; + option_attr_ = attr->option_attr; auto status = aclrtGetDevice(&device_id); if (status != 0) { SHM_LOG_ERROR("Get Device_id error"); @@ -86,7 +87,7 @@ int shmemi_init_default::release_heap() int shmemi_init_default::transport_init(shmemi_device_host_state_t &g_state) { - SHMEM_CHECK_RET(shmemi_transport_init(g_state)); // mte init && rdma init + SHMEM_CHECK_RET(shmemi_transport_init(g_state, option_attr_)); // mte init && rdma init SHMEM_CHECK_RET(shmemi_build_transport_map(g_state)); // build transport_map SHMEM_CHECK_RET(shmemi_transport_setup_connections(g_state)); // connect_endpoints by transpost_map return SHMEM_SUCCESS; diff --git a/src/host/init/init_backends/default/shmemi_init_default.h b/src/host/init/init_backends/default/shmemi_init_default.h index ed6295b5..d4a2f72c 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.h +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -51,6 +51,7 @@ private: // heap_obj shmem_symmetric_heap *heap_obj = nullptr; + shmem_init_optional_attr_t option_attr_; }; #endif // SHMEMI_INIT_NORMAL_H \ No newline at end of file diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index e853c09f..cec439fe 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -19,8 +19,9 @@ uint64_t *host_hash_list; shmemi_host_state_t g_host_state; -int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { - g_host_state.num_choosen_transport = 2; +int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_optional_attr_t& option_attr) { + // Initialize MTE by default + g_host_state.num_choosen_transport = 1; g_host_state.transport_map = (int *)calloc(g_state.npes * g_state.npes, sizeof(int)); g_host_state.pe_info = (shmemi_transport_pe_info *)calloc(g_state.npes, sizeof(shmemi_transport_pe_info)); @@ -48,32 +49,35 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state) { shmemi_transport_pe_info_t my_info; my_info.pe = g_state.mype; my_info.host_hash = g_state.host_hash; - - int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); - my_info.dev_id = device_id; - int32_t logicDeviceId = -1; - rtLibLoader& loader = rtLibLoader::getInstance(); - if (loader.isLoaded()) { - loader.getLogicDevId(device_id, &logicDeviceId); - } - g_host_state.choosen_transports[1].logical_dev_id = logicDeviceId; - g_host_state.choosen_transports[1].dev_id = device_id; - + // AllGather All pe's host info g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); - SHMEM_CHECK_RET(init_mte_fn(&g_host_state.choosen_transports[0], &g_state)); - transport_init_func init_rdma_fn; - init_rdma_fn = (transport_init_func)dlsym(transport_rdma_lib, "shmemi_rdma_init"); - if (!init_rdma_fn) { - dlclose(transport_rdma_lib); - transport_rdma_lib = NULL; - SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_rdma.so" << "."); - return SHMEM_INVALID_VALUE; + // If enable RDMA + if (option_attr.data_op_engine_type & SHMEM_DATA_OP_ROCE) { + g_host_state.num_choosen_transport++; + int32_t device_id; + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); + my_info.dev_id = device_id; + int32_t logicDeviceId = -1; + rtLibLoader& loader = rtLibLoader::getInstance(); + if (loader.isLoaded()) { + loader.getLogicDevId(device_id, &logicDeviceId); + } + g_host_state.choosen_transports[g_host_state.num_choosen_transport - 1].logical_dev_id = logicDeviceId; + g_host_state.choosen_transports[g_host_state.num_choosen_transport - 1].dev_id = device_id; + + transport_init_func init_rdma_fn; + init_rdma_fn = (transport_init_func)dlsym(transport_rdma_lib, "shmemi_rdma_init"); + if (!init_rdma_fn) { + dlclose(transport_rdma_lib); + transport_rdma_lib = NULL; + SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_rdma.so" << "."); + return SHMEM_INVALID_VALUE; + } + SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[g_host_state.num_choosen_transport - 1], &g_state)); } - SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[1], &g_state)); return SHMEM_SUCCESS; } @@ -131,8 +135,11 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) } t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); - t = g_host_state.choosen_transports[1]; - t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); + + if (g_host_state.num_choosen_transport > 1) { + t = g_host_state.choosen_transports[1]; + t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); + } return 0; } @@ -147,13 +154,15 @@ int32_t shmemi_transport_finalize() { dlclose(transport_mte_lib); transport_mte_lib = NULL; } - - t = g_host_state.choosen_transports[1]; - t.finalize(&t, &g_state); - if (transport_rdma_lib != NULL) { - dlclose(transport_rdma_lib); - transport_rdma_lib = NULL; + if (g_host_state.num_choosen_transport > 1) { + t = g_host_state.choosen_transports[1]; + t.finalize(&t, &g_state); + + if (transport_rdma_lib != NULL) { + dlclose(transport_rdma_lib); + transport_rdma_lib = NULL; + } } return 0; } diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h index 7ddb4da5..b958c908 100644 --- a/src/host/transport/shmemi_transport.h +++ b/src/host/transport/shmemi_transport.h @@ -12,7 +12,7 @@ typedef int(*transport_init_func)(shmemi_transport_t *transport, shmemi_device_host_state_t *g_state); -int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state); +int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_optional_attr_t &option_attr); int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state); diff --git a/src/modules/transport/rdma/device_qp_manager.cpp b/src/modules/transport/rdma/device_qp_manager.cpp index df2330b7..dc37f32a 100644 --- a/src/modules/transport/rdma/device_qp_manager.cpp +++ b/src/modules/transport/rdma/device_qp_manager.cpp @@ -391,7 +391,6 @@ int DeviceQpManager::WaitConnectionsReady(std::unordered_mapsecond.remoteIp.s_addr, it->first); } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); auto role = (&connections == &clientConnections_) ? 1 : 0; auto ret = DlHccpApi::RaGetSockets(role, socketInfos.data(), socketInfos.size(), successCount); if (ret != 0) { @@ -481,7 +480,6 @@ int DeviceQpManager::CreateQpWaitingReady(std::unordered_map #include #include "dl_hccp_api.h" -#include class DeviceQpManager { public: -- Gitee From 3499276acd861c157c47aef17d739f30c6dd845a Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Wed, 26 Nov 2025 14:55:58 +0800 Subject: [PATCH 55/74] RDMA transport minor fix. --- src/host/transport/shmemi_transport.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index cec439fe..99475a73 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -65,8 +65,8 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op if (loader.isLoaded()) { loader.getLogicDevId(device_id, &logicDeviceId); } - g_host_state.choosen_transports[g_host_state.num_choosen_transport - 1].logical_dev_id = logicDeviceId; - g_host_state.choosen_transports[g_host_state.num_choosen_transport - 1].dev_id = device_id; + g_host_state.choosen_transports[1].logical_dev_id = logicDeviceId; + g_host_state.choosen_transports[1].dev_id = device_id; transport_init_func init_rdma_fn; init_rdma_fn = (transport_init_func)dlsym(transport_rdma_lib, "shmemi_rdma_init"); @@ -76,7 +76,7 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_rdma.so" << "."); return SHMEM_INVALID_VALUE; } - SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[g_host_state.num_choosen_transport - 1], &g_state)); + SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[1], &g_state)); } return SHMEM_SUCCESS; @@ -137,8 +137,22 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); if (g_host_state.num_choosen_transport > 1) { + int *rdma_peer_list; + int rdma_peer_num = 0; + rdma_peer_list = (int *)calloc(g_state.npes, sizeof(int)); + + int local_offset = g_state.mype * g_state.npes; + for (int i = 0; i < g_state.npes; i++) { + if (i == g_state.mype) + continue; + if (g_host_state.transport_map[local_offset + i] & 2) { + shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); + rdma_peer_list[rdma_peer_num] = peer_info->dev_id; + ++rdma_peer_num; + } + } t = g_host_state.choosen_transports[1]; - t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); + t.connect_peers(&t, rdma_peer_list, rdma_peer_num, &g_state); } return 0; -- Gitee From 0ced6c7a795b98f613f6838e5ae33616c2532084 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Thu, 27 Nov 2025 10:27:16 +0800 Subject: [PATCH 56/74] Support build MTE on 910_93 && Support 910B 16 cards topo --- examples/allgather/README.md | 14 ++++- examples/allgather/scripts/data_gen.py | 2 + src/host/mem/shmemi_heap.cpp | 79 ++++++++++++------------- src/host/mem/shmemi_heap.h | 22 +++++-- src/host/transport/shmemi_transport.cpp | 29 +++++---- src/modules/transport/shmemi_mte.cpp | 32 ++++++++-- 6 files changed, 115 insertions(+), 63 deletions(-) diff --git a/examples/allgather/README.md b/examples/allgather/README.md index 039fb7b3..1f291184 100644 --- a/examples/allgather/README.md +++ b/examples/allgather/README.md @@ -6,4 +6,16 @@ # 完成RANKS卡下的allgather同时验证精度,性能数据会输出在result.csv中。 # RANKS : [2, 4, 8] # TYPES : [int, int32_t, float16_t, bfloat16_t] - bash run.sh -ranks ${RANKS} -type ${TYPES} \ No newline at end of file + bash run.sh -ranks ${RANKS} -type ${TYPES} + +跨机使用方式: +1.在shmem/目录编译: + bash scripts/build.sh + +2.在两台机器上shmem/examples/allgather目录中分别生成golden数据: + rm -rf ./golden + mkdir -p golden + python3 ./scripts/data_gen.py 8 "int" + +3. 在其中一台机器上shmem/examples/allgather执行(ip_host1和ip_host2为各自机器的ip地址, PROJECT_ROOT为shmem/目录) + mpirun -host ip_host1:4,ip_host2:4 -np 8 ${PROJECT_ROOT}/build/bin/allgather \ No newline at end of file diff --git a/examples/allgather/scripts/data_gen.py b/examples/allgather/scripts/data_gen.py index aae577c8..c5671434 100644 --- a/examples/allgather/scripts/data_gen.py +++ b/examples/allgather/scripts/data_gen.py @@ -12,6 +12,8 @@ import os from ml_dtypes import bfloat16 +# Set seed for multi-node situation +np.random.seed(42) def gen_random_data(size, dtype): return np.random.uniform(low=0.0, high=10.0, size=size).astype(dtype) diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index f45871cb..9c30b860 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -14,9 +14,8 @@ shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size, int dev_id): mype(pe_id), npes(pe_size), device_id(dev_id) { - physical_handle_list.resize(pe_size); - share_handle_list.resize(pe_size); pid_list.resize(pe_size); + sdid_list.resize(pe_size); memprop.handleType = ACL_MEM_HANDLE_TYPE_NONE; memprop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; @@ -50,8 +49,11 @@ int shmem_symmetric_heap::reserve_heap(size_t size) int shmem_symmetric_heap::export_memory() { - // Get share_handle - SHMEM_CHECK_RET(aclrtMemExportToShareableHandle(local_handle, memprop.handleType, 0, &share_handle)); + // Get memory_name + char memoryName[IPC_NAME_SIZE] = {}; + SHMEM_CHECK_RET(rtIpcSetMemoryName(peer_heap_base_p2p_[mype], alloc_size, memoryName, IPC_NAME_SIZE)); + + memory_name = memoryName; return SHMEM_SUCCESS; } @@ -59,6 +61,12 @@ int shmem_symmetric_heap::export_pid() { // Get local pid SHMEM_CHECK_RET(aclrtDeviceGetBareTgid(&my_pid)); + + // Get Sdid + const int rtModuleTypeSystem = 0; + const int infoTypeSdid = 26; + SHMEM_CHECK_RET(rtGetDeviceInfo(device_id, rtModuleTypeSystem, infoTypeSdid, &my_sdid)); + return SHMEM_SUCCESS; } @@ -67,36 +75,34 @@ int shmem_symmetric_heap::import_pid() // Get all pids g_boot_handle.allgather(&my_pid, pid_list.data(), 1 * sizeof(int), &g_boot_handle); - // Add Pid into white list - std::vector share_pid = {}; + // Get all sdids + g_boot_handle.allgather(&my_sdid, sdid_list.data(), 1 * sizeof(int64_t), &g_boot_handle); + + // Set Sdid and pid into Shared Memory + int local_offset = mype * npes; for (int i = 0; i < npes; i++) { - if (i == mype) { - continue; - } - // Check if p2p connected - if (peer_heap_base_p2p_[i] == NULL) { + if (i == mype || !(g_host_state.transport_map[local_offset + i] & 0x1)) { continue; } - share_pid.push_back(pid_list[i]); + SHMEM_CHECK_RET(rtSetIpcMemorySuperPodPid(memory_name.c_str(), sdid_list[i], &pid_list[i], 1)); } - SHMEM_CHECK_RET(aclrtMemSetPidToShareableHandle(share_handle, share_pid.data(), npes - 1)); return SHMEM_SUCCESS; } int shmem_symmetric_heap::import_memory() { - g_boot_handle.allgather(&share_handle, share_handle_list.data(), 1 * sizeof(uint64_t), &g_boot_handle); + g_boot_handle.allgather(memory_name.c_str(), names, IPC_NAME_SIZE, &g_boot_handle); + + static std::mutex mut; + std::lock_guard lock(mut); + + int local_offset = mype * npes; for (int i = 0; i < npes; i++) { - if (i == mype) { - physical_handle_list[i] = local_handle; + if (i == mype || !(g_host_state.transport_map[local_offset + i] & 0x1)) { continue; } - // Check if p2p connected - if (peer_heap_base_p2p_[i] == NULL) { - continue; - } - SHMEM_CHECK_RET(aclrtMemImportFromShareableHandle(share_handle_list[i], device_id, &(physical_handle_list[i]))); + SHMEM_CHECK_RET(rtIpcOpenMemory(reinterpret_cast(&peer_heap_base_p2p_[i]), names[i])); } return SHMEM_SUCCESS; @@ -104,39 +110,30 @@ int shmem_symmetric_heap::import_memory() int shmem_symmetric_heap::setup_heap() { - // MTE p2p_heap_base_ reserve - int local_offset = mype * npes; - for (int i = 0; i < npes; i++) { - if (i == mype) - continue; - - if (g_host_state.transport_map[local_offset + i] & 1) { - SHMEM_CHECK_RET(aclrtReserveMemAddress(&(peer_heap_base_p2p_[i]), alloc_size, 0, nullptr, 1)); - } - } - SHMEM_CHECK_RET(export_memory()); SHMEM_CHECK_RET(export_pid()); SHMEM_CHECK_RET(import_pid()); SHMEM_CHECK_RET(import_memory()); - // Shareable Handle Map - for (int i = 0; i < npes; i++) { - // Check if p2p connected - if (i != mype && peer_heap_base_p2p_[i] != NULL) { - SHMEM_CHECK_RET(aclrtMapMem(peer_heap_base_p2p_[i], alloc_size, 0, physical_handle_list[i], 0)); - } - } return SHMEM_SUCCESS; } int shmem_symmetric_heap::remove_heap() { for (int i = 0; i < npes; i++) { - if (peer_heap_base_p2p_[i] != NULL) { - SHMEM_CHECK_RET(aclrtUnmapMem(peer_heap_base_p2p_[i])); + if (i == mype || peer_heap_base_p2p_[i] == NULL) { + continue; } + SHMEM_CHECK_RET(rtIpcCloseMemory(static_cast(peer_heap_base_p2p_[i]))); + peer_heap_base_p2p_[i] = NULL; } + + // This barrier is necessary, otherwise Unmap will fail. + g_boot_handle.barrier(&g_boot_handle); + + SHMEM_CHECK_RET(rtIpcDestroyMemoryName(memory_name.c_str())); + + SHMEM_CHECK_RET(aclrtUnmapMem(peer_heap_base_p2p_[mype])); return SHMEM_SUCCESS; } diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h index 68385d0a..4dfbc4b7 100644 --- a/src/host/mem/shmemi_heap.h +++ b/src/host/mem/shmemi_heap.h @@ -13,12 +13,21 @@ #include #include #include +#include #include +#include "internal/host_device/shmemi_types.h" #include "common/shmemi_host_types.h" #include "bootstrap/shmemi_bootstrap.h" +#include "runtime/kernel.h" +#include "runtime/mem.h" +#include "runtime/dev.h" +#include "runtime/rt_ffts.h" + +const int IPC_NAME_SIZE = 65; + class shmem_symmetric_heap { public: shmem_symmetric_heap() {} @@ -53,15 +62,18 @@ private: // handle used to map local virtual ptr aclrtPhysicalMemProp memprop; aclrtDrvMemHandle local_handle; - std::vector physical_handle_list = {}; - // pid used to set white list + // names used to share memory + std::string memory_name; + char names[SHMEM_MAX_RANKS][IPC_NAME_SIZE]; + + // pid set to memory_name int32_t my_pid = 0UL; std::vector pid_list = {}; - // handle used to share physical memory - uint64_t share_handle = 0UL; - std::vector share_handle_list = {}; + // sdid set to memory_name in 910_93 + int64_t my_sdid = 0UL; + std::vector sdid_list = {}; }; diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index 99475a73..9fde6037 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -30,11 +30,6 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_mte.so" << ", err is: " << stderr); return SHMEM_INVALID_VALUE; } - transport_rdma_lib = dlopen("shmem_transport_rdma.so", RTLD_NOW); - if (!transport_rdma_lib) { - SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_rdma.so" << ", err is: " << stderr); - return SHMEM_INVALID_VALUE; - } transport_init_func init_mte_fn; init_mte_fn = (transport_init_func)dlsym(transport_mte_lib, "shmemi_mte_init"); @@ -46,8 +41,12 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op } // Package my_info + int32_t device_id; + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); + shmemi_transport_pe_info_t my_info; my_info.pe = g_state.mype; + my_info.dev_id = device_id; my_info.host_hash = g_state.host_hash; // AllGather All pe's host info @@ -57,9 +56,7 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op // If enable RDMA if (option_attr.data_op_engine_type & SHMEM_DATA_OP_ROCE) { g_host_state.num_choosen_transport++; - int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); - my_info.dev_id = device_id; + int32_t logicDeviceId = -1; rtLibLoader& loader = rtLibLoader::getInstance(); if (loader.isLoaded()) { @@ -68,6 +65,12 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op g_host_state.choosen_transports[1].logical_dev_id = logicDeviceId; g_host_state.choosen_transports[1].dev_id = device_id; + transport_rdma_lib = dlopen("shmem_transport_rdma.so", RTLD_NOW); + if (!transport_rdma_lib) { + SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_rdma.so" << ", err is: " << stderr); + return SHMEM_INVALID_VALUE; + } + transport_init_func init_rdma_fn; init_rdma_fn = (transport_init_func)dlsym(transport_rdma_lib, "shmemi_rdma_init"); if (!init_rdma_fn) { @@ -127,10 +130,14 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) for (int i = 0; i < g_state.npes; i++) { if (i == g_state.mype) continue; - if (g_host_state.transport_map[local_offset + i] & 1) { + /* Check if MTE connected. */ + if (g_host_state.transport_map[local_offset + i] & 0x1) { shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); - mte_peer_list[mte_peer_num] = peer_info->dev_id; - ++mte_peer_num; + // Only PEs in the same Node need to build up MTE connection. + if (g_state.host_hash == peer_info->host_hash) { + mte_peer_list[mte_peer_num] = peer_info->dev_id; + ++mte_peer_num; + } } } diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index 539f7250..911dbd68 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -22,12 +22,34 @@ extern "C" { #endif int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport *t, shmemi_device_host_state_t *g_state) { - // host_id same return 1, otherwise 0 - if (g_state->host_hash == peer_info->host_hash) { - *access = 1; - } else { - *access = 0; + // origin access set to 0. + *access = 0; + + auto sName = aclrtGetSocName(); + std::string socName{sName}; + if (socName.find("Ascend910B") != std::string::npos) { // Ascend910B Topo + int64_t hccs_connected = -1; + SHMEM_CHECK_RET(rtGetPairDevicesInfo(g_state->mype, peer_info->dev_id, 0, &hccs_connected)); + + // In 910B, Flag 0 -> HCCS. + const static int SELF_FLAG = 0; + if (hccs_connected == SELF_FLAG) { + *access = 1; + } + } else if (socName.find("Ascend910_93") != std::string::npos) { // Ascend910_93 Topo + int64_t hccs_connected = -1; + /* TODO: This func now doesn't support 910_93 multiNode HCCS Check. Only Check in the same Node. */ + SHMEM_CHECK_RET(rtGetPairDevicesInfo(g_state->mype, peer_info->dev_id, 0, &hccs_connected)); + + // In 910_93, Flag 0 -> SELF, 5 -> SIO, 6 -> HCCS. + const static int SELF_FLAG = 0; + const static int SIO_FLAG = 5; + const static int HCCS_FLAG = 6; + if (hccs_connected == SELF_FLAG || hccs_connected == SIO_FLAG || hccs_connected == HCCS_FLAG) { + *access = 1; + } } + return 0; } -- Gitee From 18151fcb9340ffb119280e66397c23d12c8dd26b Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Fri, 28 Nov 2025 15:37:57 +0800 Subject: [PATCH 57/74] develop fix ut --- src/host/common/shmemi_host_types.h | 1 + src/host/init/shmem_init.cpp | 16 + src/host/init/shmemi_init.h | 2 + .../bootstrap/shmemi_bootstrap_uid.cpp | 28 ++ src/modules/bootstrap/socket/uid_socket.h | 2 +- tests/fuzz/device/mem/shmem_ptr_kernel.cpp | 4 +- tests/fuzz/device/sync/barrier_kernel.cpp | 8 +- tests/fuzz/device/sync/order_kernel.cpp | 4 +- tests/fuzz/device/team/team_kernel.cpp | 4 +- tests/fuzz/host/sync/barrier_host_fuzz.cpp | 8 +- .../mem/atomic_add/atomic_add_kernel.cpp | 4 +- .../device/mem/rdma_mem/rdma_mem_kernel.cpp | 16 +- .../unittest/device/mem/shmem_ptr_kernel.cpp | 4 +- .../device/mem/ub_mem/ub_mem_kernel.cpp | 2 - .../ub_non_contiguous_kernel.cpp | 2 - .../device/sync/barrier/barrier_kernel.cpp | 8 +- .../device/sync/order/order_kernel.cpp | 4 +- .../unittest/device/team/team/team_kernel.cpp | 4 +- tests/unittest/host/init/init_host_test.cpp | 300 ++++-------------- tests/unittest/host/main_test.cpp | 34 +- .../mem/atomic_add/atomic_add_host_test.cpp | 2 +- .../host/mem/gm_mem/gm_mem_host_test.cpp | 2 +- .../gm_mem_disable_L2_host_test.cpp | 2 +- .../gm_non_contiguous_host_test.cpp | 2 +- .../host/mem/rdma_mem/rdma_mem_host_test.cpp | 16 +- .../host/mem/shmem_host_get_stream_test.cpp | 4 +- .../host/mem/shmem_host_heap_test.cpp | 30 +- .../unittest/host/mem/shmem_ptr_host_test.cpp | 6 +- .../host/mem/ub_mem/ub_mem_host_test.cpp | 2 +- .../ub_non_contiguous_host_test.cpp | 2 +- .../host/sync/barrier/barrier_host_test.cpp | 22 +- .../host/team/team/team_host_test.cpp | 2 +- .../include/unittest/mem_host_direct_test.cpp | 6 +- .../unittest/mem_host_get_and_put_test.cpp | 6 +- .../unittest/mem_putmem_signal_test.cpp | 4 +- 35 files changed, 231 insertions(+), 332 deletions(-) diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index a5d771c5..280855b1 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -26,6 +26,7 @@ typedef struct shmemi_bootstrap_attr { typedef struct shmemi_bootstrap_init_ops { void *cookie; int (*get_unique_id)(void *cookit); + int (*get_unique_id_static_magic)(void *uid_info, bool is_root); } shmemi_bootstrap_init_ops_t; typedef struct shmemi_bootstrap_handle { diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index 1929bc83..0fbe1f56 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -329,6 +329,22 @@ int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid){ #endif } +int32_t shmemi_get_uniqueid_static_magic(shmemx_uniqueid_t *uid, bool is_root) { + shmem_set_log_level(shm::INFO_LEVEL); + int status = 0; + SHMEM_CHECK_RET(shmemi_options_init(), "Bootstrap failed during the preloading step."); + SHMEM_CHECK_RET(shmemi_bootstrap_pre_init(SHMEMX_INIT_WITH_UNIQUEID, &g_boot_handle), "Get uniqueid failed during the bootstrap preloading step."); + + if (g_boot_handle.pre_init_ops) { + SHMEM_CHECK_RET(g_boot_handle.pre_init_ops->get_unique_id_static_magic((void *)uid, is_root), "Get uniqueid failed during the get uniqueid step."); + } else { + SHM_LOG_ERROR("Pre_init_ops is empty, unique_id cannot be obtained."); + status = SHMEM_INVALID_PARAM; + } + return SHMEM_SUCCESS; + +} + int32_t shmem_set_log_level(int level) { // use env first, input level secondly, user may change level from env instead call func diff --git a/src/host/init/shmemi_init.h b/src/host/init/shmemi_init.h index f7d2091d..4a18da6b 100644 --- a/src/host/init/shmemi_init.h +++ b/src/host/init/shmemi_init.h @@ -26,4 +26,6 @@ int32_t shmemi_control_barrier_all(); int32_t update_device_state(void); +int32_t shmemi_get_uniqueid_static_magic(shmemx_uniqueid_t *uid, bool is_root); + #endif // SHMEMI_INIT_H diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp index b3e4cde3..17b261d0 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -40,6 +40,7 @@ static const char* env_ip_port = nullptr; static const char* env_ifname = nullptr; static shmemx_bootstrap_uid_state_t shmemi_bootstrap_uid_state; static struct bootstrap_netstate priv_info; +static int static_magic_count = 1; bool is_ipv6_loopback(const struct in6_addr *addr6) { static const struct in6_addr loopback6 = IN6ADDR_LOOPBACK_INIT; @@ -936,12 +937,39 @@ int shmemi_bootstrap_get_unique_id(void* uid) { return SHMEM_SUCCESS; } +int shmemi_bootstrap_get_unique_id_static_magic(void* uid, bool is_root) { + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)uid; + + if (env_ip_port == nullptr) { + const char* envip = std::getenv("SHMEM_UID_SESSION_ID"); + if (envip != nullptr) { + env_ip_port = envip; + SHM_LOG_DEBUG("SHMEM_UID_SESSION_ID is: " << env_ip_port); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SESSION_ID is not set."); + } + } + if (env_ip_port == nullptr) { + SHM_LOG_ERROR("Using method get_unique_id_static_magic requires setting SHMEM_UID_SESSION_ID."); + return SHMEM_BOOTSTRAP_ERROR; + } + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args, false), "rank 0: failed to init bootstrap net."); + uid_args->magic = SOCKET_MAGIC + static_magic_count; + static_magic_count++; + if (is_root) { + SHMEM_CHECK_RET(bootstrap_create_root(uid_args), "rank 0: failed to create root thread"); + } + return SHMEM_SUCCESS; +} + // Plugin pre-initialization entry function. int shmemi_bootstrap_plugin_pre_init(shmemi_bootstrap_handle_t* handle) { if (handle->pre_init_ops == nullptr) { SHM_LOG_DEBUG(" bootstrap plugin pre init start."); SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1)); handle->pre_init_ops->get_unique_id = shmemi_bootstrap_get_unique_id; + handle->pre_init_ops->get_unique_id_static_magic = shmemi_bootstrap_get_unique_id_static_magic; handle->pre_init_ops->cookie = nullptr; SHM_LOG_DEBUG(" bootstrap plugin pre init end."); } else { diff --git a/src/modules/bootstrap/socket/uid_socket.h b/src/modules/bootstrap/socket/uid_socket.h index faab6093..ed0e4780 100644 --- a/src/modules/bootstrap/socket/uid_socket.h +++ b/src/modules/bootstrap/socket/uid_socket.h @@ -24,7 +24,7 @@ extern "C" { #define SOCKET_TYPE_SEND 0 #define SOCKET_TYPE_RECV 1 -#define RETRY_REFUSED_TIMES 50 +#define RETRY_REFUSED_TIMES 1e5 // 100s超时 #define RETRY_TIMEDOUT_TIMES 50 #define SLEEP_INT 1000 // 重试间隔(微秒) diff --git a/tests/fuzz/device/mem/shmem_ptr_kernel.cpp b/tests/fuzz/device/mem/shmem_ptr_kernel.cpp index 498d07c3..b4812499 100644 --- a/tests/fuzz/device/mem/shmem_ptr_kernel.cpp +++ b/tests/fuzz/device/mem/shmem_ptr_kernel.cpp @@ -13,8 +13,8 @@ public: __aicore__ inline void Init(GM_ADDR gva) { gva_gm = (__gm__ int *)gva; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_my_pe(); } __aicore__ inline void Process() { diff --git a/tests/fuzz/device/sync/barrier_kernel.cpp b/tests/fuzz/device/sync/barrier_kernel.cpp index be13d7c5..9915458c 100644 --- a/tests/fuzz/device/sync/barrier_kernel.cpp +++ b/tests/fuzz/device/sync/barrier_kernel.cpp @@ -25,7 +25,7 @@ extern "C" SHMEM_GLOBAL void increase(uint64_t config, GM_ADDR addr, int rank_id uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmem_barrier_all(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmem_barrier_all(); #endif @@ -39,7 +39,7 @@ extern "C" SHMEM_GLOBAL void increase_vec(uint64_t config, GM_ADDR addr, int ran uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmemx_barrier_all_vec(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmemx_barrier_all_vec(); #endif @@ -61,7 +61,7 @@ extern "C" SHMEM_GLOBAL void increase_odd_team(uint64_t config, GM_ADDR addr, in shmem_barrier(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmem_barrier(team_id); @@ -78,7 +78,7 @@ extern "C" SHMEM_GLOBAL void increase_vec_odd_team(uint64_t config, GM_ADDR addr shmemx_barrier_vec(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmemx_barrier_vec(team_id); diff --git a/tests/fuzz/device/sync/order_kernel.cpp b/tests/fuzz/device/sync/order_kernel.cpp index 3bcd8fa1..73e28ad4 100644 --- a/tests/fuzz/device/sync/order_kernel.cpp +++ b/tests/fuzz/device/sync/order_kernel.cpp @@ -24,7 +24,7 @@ extern "C" SHMEM_GLOBAL void quiet_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 32)); seen_b = shmemi_load(remote + 32); @@ -51,7 +51,7 @@ extern "C" SHMEM_GLOBAL void fence_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 16)); seen_b = shmemi_load(remote + 16); diff --git a/tests/fuzz/device/team/team_kernel.cpp b/tests/fuzz/device/team/team_kernel.cpp index 85002089..8dd5821c 100644 --- a/tests/fuzz/device/team/team_kernel.cpp +++ b/tests/fuzz/device/team/team_kernel.cpp @@ -18,8 +18,8 @@ public: gva_gm = (__gm__ int *)gva; team_idx= team_id; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_my_pe(); } __aicore__ inline void Process() { diff --git a/tests/fuzz/host/sync/barrier_host_fuzz.cpp b/tests/fuzz/host/sync/barrier_host_fuzz.cpp index 6f74daa2..700bb1a0 100644 --- a/tests/fuzz/host/sync/barrier_host_fuzz.cpp +++ b/tests/fuzz/host/sync/barrier_host_fuzz.cpp @@ -73,7 +73,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_barrier_black_box_success) ASSERT_EQ(aclrtSynchronizeStream(scope.stream), ACL_SUCCESS); ASSERT_EQ(aclrtMemcpy(addr_host, size, addr_dev, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } ASSERT_EQ(aclrtFreeHost(addr_host), ACL_SUCCESS); @@ -111,7 +111,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_vec_barrier_black_box_success) ASSERT_EQ(aclrtMemcpy(addr_host_vec, size, addr_dev_vec, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } ASSERT_EQ(aclrtFreeHost(addr_host_vec), ACL_SUCCESS); @@ -156,7 +156,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_barrier_black_box_odd_team_success) ASSERT_EQ(aclrtSynchronizeStream(scope.stream), ACL_SUCCESS); ASSERT_EQ(aclrtMemcpy(addr_host, size, addr_dev, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } } @@ -205,7 +205,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_vec_barrier_black_box_odd_team_success) ASSERT_EQ(aclrtMemcpy(addr_host_vec, size, addr_dev_vec, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } } diff --git a/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp b/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp index 045f8440..29c27274 100644 --- a/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp +++ b/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp @@ -17,8 +17,8 @@ constexpr uint64_t MESSAGE_SIZE = 64; extern "C" __global__ __aicore__ void test_atomic_add_##NAME##_kernel(GM_ADDR gva, uint64_t config) \ { \ shmemx_set_ffts_config(config); \ - int64_t rank = smem_shm_get_global_rank(); \ - int64_t rank_size = smem_shm_get_global_rank_size(); \ + int64_t rank = shmem_my_pe(); \ + int64_t rank_size = shmem_my_pe(); \ GM_ADDR dst_addr; \ \ for (int64_t peer = 0; peer < rank_size; peer++) \ diff --git a/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp index 4e558c3e..1cab4f1c 100644 --- a/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp +++ b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp @@ -20,8 +20,8 @@ extern "C" __global__ __aicore__ void RDMAGetTestLowLevel(GM_ADDR gva, uint64_t pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); AscendC::LocalTensor ubLocal = buf.GetWithOffset(UB_ALIGN_SIZE * 2, 0); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR dest_addr; for (int64_t peer = 0; peer < rank_size; peer++) { @@ -46,8 +46,8 @@ extern "C" __global__ __aicore__ void RDMAPutTestLowLevel(GM_ADDR gva, uint64_t pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); AscendC::LocalTensor ubLocal = buf.GetWithOffset(UB_ALIGN_SIZE * 2, 0); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR src_addr; for (int64_t peer = 0; peer < rank_size; peer++) { @@ -67,8 +67,8 @@ void test_rdma_put_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uin extern "C" __global__ __aicore__ void RDMAGetTestHighLevel(GM_ADDR gva, uint64_t config) { shmemx_set_ffts_config(config); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR dest_addr; for (int64_t peer = 0; peer < rank_size; peer++) { @@ -89,8 +89,8 @@ void test_rdma_get_high_level(uint32_t block_dim, void* stream, uint8_t* gva, ui extern "C" __global__ __aicore__ void RDMAPutTestHighLevel(GM_ADDR gva, uint64_t config) { shmemx_set_ffts_config(config); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR src_addr; for (int64_t peer = 0; peer < rank_size; peer++) { diff --git a/tests/unittest/device/mem/shmem_ptr_kernel.cpp b/tests/unittest/device/mem/shmem_ptr_kernel.cpp index e1b19cb3..142add0e 100644 --- a/tests/unittest/device/mem/shmem_ptr_kernel.cpp +++ b/tests/unittest/device/mem/shmem_ptr_kernel.cpp @@ -16,8 +16,8 @@ public: __aicore__ inline void Init(GM_ADDR gva) { gva_gm = (__gm__ int *)gva; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + rank = shmem_my_pe(); + rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp b/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp index dce768d8..04457e00 100644 --- a/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp +++ b/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp @@ -8,8 +8,6 @@ * See LICENSE in the root of the software repository for the full text of the License. */ #include "kernel_operator.h" -#include "smem_shm_aicore_base_api.h" - #include "shmem_api.h" #include "unittest/utils/func_type.h" diff --git a/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp b/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp index 651b2f1c..0b6d645b 100644 --- a/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp +++ b/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp @@ -8,8 +8,6 @@ * See LICENSE in the root of the software repository for the full text of the License. */ #include "kernel_operator.h" -#include "smem_shm_aicore_base_api.h" - #include "shmem_api.h" #include "unittest/utils/func_type.h" diff --git a/tests/unittest/device/sync/barrier/barrier_kernel.cpp b/tests/unittest/device/sync/barrier/barrier_kernel.cpp index f115a1c7..7b52e842 100644 --- a/tests/unittest/device/sync/barrier/barrier_kernel.cpp +++ b/tests/unittest/device/sync/barrier/barrier_kernel.cpp @@ -23,7 +23,7 @@ extern "C" SHMEM_GLOBAL void increase(uint64_t config, GM_ADDR addr, int rank_id uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmem_barrier_all(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmem_barrier_all(); #endif @@ -36,7 +36,7 @@ extern "C" SHMEM_GLOBAL void increase_vec(uint64_t config, GM_ADDR addr, int ran uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmemx_barrier_all_vec(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmemx_barrier_all_vec(); #endif @@ -57,7 +57,7 @@ extern "C" SHMEM_GLOBAL void increase_odd_team(uint64_t config, GM_ADDR addr, in shmem_barrier(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmem_barrier(team_id); @@ -73,7 +73,7 @@ extern "C" SHMEM_GLOBAL void increase_vec_odd_team(uint64_t config, GM_ADDR addr shmemx_barrier_vec(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmemx_barrier_vec(team_id); diff --git a/tests/unittest/device/sync/order/order_kernel.cpp b/tests/unittest/device/sync/order/order_kernel.cpp index 6b5bd284..610ba46c 100644 --- a/tests/unittest/device/sync/order/order_kernel.cpp +++ b/tests/unittest/device/sync/order/order_kernel.cpp @@ -22,7 +22,7 @@ extern "C" SHMEM_GLOBAL void quiet_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 32)); seen_b = shmemi_load(remote + 32); @@ -48,7 +48,7 @@ extern "C" SHMEM_GLOBAL void fence_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 16)); seen_b = shmemi_load(remote + 16); diff --git a/tests/unittest/device/team/team/team_kernel.cpp b/tests/unittest/device/team/team/team_kernel.cpp index 26ff4ce8..92d8d562 100644 --- a/tests/unittest/device/team/team/team_kernel.cpp +++ b/tests/unittest/device/team/team/team_kernel.cpp @@ -19,8 +19,8 @@ public: gva_gm = (__gm__ int *)gva; team_idx= team_id; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + rank = shmem_my_pe(); + rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/unittest/host/init/init_host_test.cpp b/tests/unittest/host/init/init_host_test.cpp index c4b59298..cc557b6d 100644 --- a/tests/unittest/host/init/init_host_test.cpp +++ b/tests/unittest/host/init/init_host_test.cpp @@ -29,55 +29,30 @@ void test_shmem_init(int rank_id, int n_ranks, uint64_t local_mem_size) int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + shmemx_uniqueid_t uid; + if (rank_id == 0) { + status = shmemi_get_uniqueid_static_magic(&uid, true); + } else { + status = shmemi_get_uniqueid_static_magic(&uid, false); + } EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); - status = shmem_finalize(); + shmem_init_attr_t* attributes; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, + &uid, + &attributes); EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); - } -} - -void test_shmem_init_attr(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - shmem_init_attr_t *attributes = new shmem_init_attr_t{ - rank_id, n_ranks, {}, local_mem_size, {0, SHMEM_DATA_OP_MTE, 120, 120, 120}}; - std::copy_n(test_global_ipport, SHMEM_MAX_IP_PORT_LEN, attributes->ip_port); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); + EXPECT_EQ(g_state.mype, rank_id); + EXPECT_EQ(g_state.npes, n_ranks); + EXPECT_NE(g_state.heap_base, nullptr); + EXPECT_NE(g_state.p2p_heap_host_base[rank_id], nullptr); + EXPECT_NE(g_state.p2p_heap_device_base[rank_id], nullptr); + EXPECT_EQ(g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); + EXPECT_NE(g_state.team_pools[0], nullptr); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); status = shmem_finalize(); - delete attributes; EXPECT_EQ(status, SHMEM_SUCCESS); EXPECT_EQ(aclrtResetDevice(device_id), 0); EXPECT_EQ(aclFinalize(), 0); @@ -93,10 +68,17 @@ void test_shmem_init_invalid_rank_id(int rank_id, int n_ranks, uint64_t local_me int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(erank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmemi_get_uniqueid_static_magic(&uid, true); + } else { + shmemi_get_uniqueid_static_magic(&uid, false); + } + shmem_init_attr_t* attributes; + shmemx_set_attr_uniqueid_args(erank_id, n_ranks, local_mem_size, + &uid, + &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -114,14 +96,18 @@ void test_shmem_init_invalid_n_ranks(int rank_id, int n_ranks, uint64_t local_me int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmemi_get_uniqueid_static_magic(&uid, true); + } else { + shmemi_get_uniqueid_static_magic(&uid, false); + } + shmem_init_attr_t* attributes; + shmemx_set_attr_uniqueid_args(rank_id, en_ranks, local_mem_size, + &uid, + &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, en_ranks, local_mem_size, test_global_ipport, &attributes); - EXPECT_EQ(status, SHMEM_INVALID_VALUE); - status = shmem_init_attr(attributes); - EXPECT_TRUE(status != 0); - attributes->n_ranks = en_ranks; - status = shmem_init_attr(attributes); EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -138,10 +124,17 @@ void test_shmem_init_rank_id_over_size(int rank_id, int n_ranks, uint64_t local_ int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id + n_ranks, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmemi_get_uniqueid_static_magic(&uid, true); + } else { + shmemi_get_uniqueid_static_magic(&uid, false); + } + shmem_init_attr_t* attributes; + shmemx_set_attr_uniqueid_args(rank_id + n_ranks, n_ranks, local_mem_size, + &uid, + &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_INVALID_PARAM); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -159,32 +152,18 @@ void test_shmem_init_zero_mem(int rank_id, int n_ranks, uint64_t local_mem_size) int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_INVALID_VALUE); - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmemi_get_uniqueid_static_magic(&uid, true); + } else { + shmemi_get_uniqueid_static_magic(&uid, false); } -} - -void test_shmem_init_invalid_mem(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - // local_mem_size = invalid - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_SMEM_ERROR); + shmem_init_attr_t* attributes; + shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, + &uid, + &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); EXPECT_EQ(aclrtResetDevice(device_id), 0); @@ -194,85 +173,6 @@ void test_shmem_init_invalid_mem(int rank_id, int n_ranks, uint64_t local_mem_si } } -void test_shmem_init_set_config(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - - shmem_set_data_op_engine_type(attributes, SHMEM_DATA_OP_MTE); - shmem_set_timeout(attributes, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); - - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); - - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); - - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); - status = shmem_finalize(); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); - } -} - -void test_shmem_global_exit(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - status = shmem_set_conf_store_tls(false, nullptr, 0); - EXPECT_EQ(status, 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - - shmem_set_data_op_engine_type(attributes, SHMEM_DATA_OP_MTE); - shmem_set_timeout(attributes, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); - - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); - - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); - - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); - shmem_global_exit(0); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); - } -} - TEST(TestInitAPI, TestShmemInit) { const int process_count = test_gnpu_num; @@ -280,13 +180,6 @@ TEST(TestInitAPI, TestShmemInit) test_mutil_task(test_shmem_init, local_mem_size, process_count); } -TEST(TestInitAPI, TestShmemInitAttr) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - test_mutil_task(test_shmem_init_attr, local_mem_size, process_count); -} - TEST(TestInitAPI, TestShmemInitErrorInvalidRankId) { const int process_count = test_gnpu_num; @@ -315,20 +208,6 @@ TEST(TestInitAPI, TestShmemInitErrorZeroMem) test_mutil_task(test_shmem_init_zero_mem, local_mem_size, process_count); } -TEST(TestInitAPI, TestShmemInitErrorInvalidMem) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL; - test_mutil_task(test_shmem_init_invalid_mem, local_mem_size, process_count); -} - -TEST(TestInitAPI, TestSetConfig) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - test_mutil_task(test_shmem_init_set_config, local_mem_size, process_count); -} - TEST(TestInitAPI, TestInfoGetVersion) { int major = 0; @@ -365,13 +244,6 @@ TEST(TestInitAPI, TestInfoGetNameNull) EXPECT_EQ(input, nullptr); } -TEST(TestInitAPI, TestShmemGlobalExit) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - test_mutil_task(test_shmem_global_exit, local_mem_size, process_count); -} - TEST(TestInitAPI, TestShmemSetLogLevel) { auto ret = shmem_set_log_level(shm::DEBUG_LEVEL); @@ -393,52 +265,4 @@ TEST(TestInitAPI, TestShmemSetLogLevel) EXPECT_EQ(shmem_set_log_level(-1), 0); unsetenv("SHMEM_LOG_LEVEL"); -} - -TEST(TestInitAPI, TestShmemSetExternLogger) -{ - auto ret = shmem_set_extern_logger(shm::logger_test_example); - EXPECT_EQ(ret, 0); -} - -TEST(TestInitAPI, TestShmemGetUniqueId) -{ - const char *ipInfo = std::getenv("SHMEM_UID_SOCK_IFNAM"); - if (ipInfo == nullptr) { - return; - } - - for (int i = 0; i < 10; i++) { - shmem_uniqueid_t uid; - int ret = shmem_get_uniqueid(&uid); - EXPECT_EQ(ret, SHMEM_SUCCESS); - - shmem_uniqueid_inner_t *innerUID = reinterpret_cast(&uid); - - // test bind ip:port again - int sockfd = ::socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) { - std::cout << "create socket failed" << std::endl; - return; - } - - int reuse = 1; - ::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)); - - struct sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = innerUID->addr.addr.addr4.sin_port; - addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); // 绑定 127.0.0.1 - - bool inUse = (::bind(sockfd, reinterpret_cast(&addr), sizeof(addr)) != 0); - if (inUse) { - auto errorNum = errno; - std::cout << "the address is in use" << errorNum << std::endl; - EXPECT_TRUE(false); - break; - } - - ::close(sockfd); - EXPECT_TRUE(true); - } } \ No newline at end of file diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index a91dbaba..002ec244 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -12,6 +12,7 @@ #include #include "acl/acl.h" #include "shmem_api.h" +#include "shmemi_host_common.h" #include "unittest_main_test.h" int test_global_ranks; @@ -35,11 +36,19 @@ void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *s aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - + shmemx_uniqueid_t uid; + if (rank_id == 0) { + status = shmemi_get_uniqueid_static_magic(&uid, true); + } else { + status = shmemi_get_uniqueid_static_magic(&uid, false); + } + EXPECT_EQ(status, SHMEM_SUCCESS); shmem_init_attr_t* attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - status = shmem_init_attr(attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, + &uid, + &attributes); + EXPECT_EQ(status, SHMEM_SUCCESS); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, 0); *st = stream; } @@ -59,12 +68,21 @@ int32_t test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtS aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - + shmemx_uniqueid_t uid; + if (rank_id == 0) { + status = shmemi_get_uniqueid_static_magic(&uid, true); + } else { + status = shmemi_get_uniqueid_static_magic(&uid, false); + } + EXPECT_EQ(status, SHMEM_SUCCESS); shmem_init_attr_t* attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, + &uid, + &attributes); + EXPECT_EQ(status, SHMEM_SUCCESS); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + EXPECT_EQ(status, 0); *st = stream; return status; } diff --git a/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp b/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp index 9074db02..0f126f4a 100644 --- a/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp +++ b/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp @@ -69,7 +69,7 @@ SHMEM_ATOMIC_ADD_FUNC_TYPE_HOST(TEST_SHMEM_ATOMIC_ADD_HOST); aclrtStream stream; \ test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ - test_atomic_add_##NAME##_host(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + test_atomic_add_##NAME##_host(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) \ diff --git a/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp b/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp index fe75b96c..995ee698 100644 --- a/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp +++ b/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp @@ -76,7 +76,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - test_##NAME##_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + test_##NAME##_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp b/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp index 1aa2398a..daacfa5a 100644 --- a/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp +++ b/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp @@ -74,7 +74,7 @@ void test_shmemx_mte_mem(int rank_id, int n_ranks, uint64_t local_mem_size) test_init(rank_id, n_ranks, local_mem_size, &stream); ASSERT_NE(stream, nullptr); - test_shmemx_mte_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); + test_shmemx_mte_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; test_finalize(stream, device_id); if (::testing::Test::HasFailure()) { diff --git a/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp b/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp index d9cdb1dc..0143f5ef 100644 --- a/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp +++ b/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp @@ -85,7 +85,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_NON_CONTIGUOUS_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - test_##NAME##_non_contiguous_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + test_##NAME##_non_contiguous_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp index 0de8b59d..72d1bdc4 100644 --- a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp +++ b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp @@ -44,40 +44,40 @@ static void test_rdma_put_get(aclrtStream stream, uint8_t *gva, uint32_t rank_id } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_put_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_get_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_put_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_get_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); diff --git a/tests/unittest/host/mem/shmem_host_get_stream_test.cpp b/tests/unittest/host/mem/shmem_host_get_stream_test.cpp index 262e9f7f..dc375f89 100644 --- a/tests/unittest/host/mem/shmem_host_get_stream_test.cpp +++ b/tests/unittest/host/mem/shmem_host_get_stream_test.cpp @@ -118,7 +118,7 @@ static void host_test_put_get_mem_stream(int rank_id, int rank_size, uint64_t lo void *ptr = shmem_malloc(1024); host_putmem(ptr, dev_ptr, rank_id, input_size); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); sleep(sleep_time); ASSERT_EQ(aclrtMemcpy(out.data(), total_size, ptr, total_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); @@ -136,7 +136,7 @@ static void host_test_put_get_mem_stream(int rank_id, int rank_size, uint64_t lo std::cout << std::endl; size_t ele_size = 16; host_test_getmem_stream((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_size, ele_size, stream); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); sleep(sleep_time); ASSERT_EQ(aclrtMemcpy(input.data(), input_size, dev_ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index f7df19ec..f3b4d039 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -27,17 +27,27 @@ protected: int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - status = shmem_init_attr(attributes); + shmemx_uniqueid_t uid; + if (rank_id == 0) { + status = shmemi_get_uniqueid_static_magic(&uid, true); + } else { + status = shmemi_get_uniqueid_static_magic(&uid, false); + } EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); + shmem_init_attr_t* attributes; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, + &uid, + &attributes); + EXPECT_EQ(status, SHMEM_SUCCESS); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + EXPECT_EQ(status, SHMEM_SUCCESS); + EXPECT_EQ(g_state.mype, rank_id); + EXPECT_EQ(g_state.npes, n_ranks); + EXPECT_NE(g_state.heap_base, nullptr); + EXPECT_NE(g_state.p2p_heap_host_base[rank_id], nullptr); + EXPECT_NE(g_state.p2p_heap_device_base[rank_id], nullptr); + EXPECT_EQ(g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); + EXPECT_NE(g_state.team_pools[0], nullptr); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); testingRank = true; diff --git a/tests/unittest/host/mem/shmem_ptr_host_test.cpp b/tests/unittest/host/mem/shmem_ptr_host_test.cpp index c4714178..c091f72a 100644 --- a/tests/unittest/host/mem/shmem_ptr_host_test.cpp +++ b/tests/unittest/host/mem/shmem_ptr_host_test.cpp @@ -92,9 +92,9 @@ TEST(TestMemApi, TestShmemMteSetUbParams) uint32_t event_id = 0; ASSERT_EQ(shmem_mte_set_ub_params(offset, ub_size, event_id), 0); - ASSERT_EQ(shm::g_state.mte_config.shmem_ub, offset); - ASSERT_EQ(shm::g_state.mte_config.ub_size, ub_size); - ASSERT_EQ(shm::g_state.mte_config.event_id, event_id); + ASSERT_EQ(g_state.mte_config.shmem_ub, offset); + ASSERT_EQ(g_state.mte_config.ub_size, ub_size); + ASSERT_EQ(g_state.mte_config.event_id, event_id); test_finalize(stream, device_id); }, local_mem_size, process_count); diff --git a/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp b/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp index fc539c8e..26d20ddb 100644 --- a/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp +++ b/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp @@ -75,7 +75,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_UB_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - test_ub_##NAME##_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + test_ub_##NAME##_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp b/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp index d5313b13..69abd3fb 100644 --- a/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp +++ b/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp @@ -83,7 +83,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_UB_NON_CONTIGUOUS_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - TestUB##NAME##NonContiguousPutGet(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + TestUB##NAME##NonContiguousPutGet(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/sync/barrier/barrier_host_test.cpp b/tests/unittest/host/sync/barrier/barrier_host_test.cpp index fbf71afc..d4310073 100644 --- a/tests/unittest/host/sync/barrier/barrier_host_test.cpp +++ b/tests/unittest/host/sync/barrier/barrier_host_test.cpp @@ -37,7 +37,7 @@ static void test_barrier_black_box(int32_t rank_id, int32_t n_ranks, uint64_t lo ASSERT_EQ(aclrtSynchronizeStream(stream), 0); ASSERT_EQ(aclrtMemcpy(addr_host, sizeof(uint64_t), addr_dev, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), 0); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } uint64_t *addr_dev_vec = static_cast(shmem_malloc(sizeof(uint64_t))); @@ -52,7 +52,7 @@ static void test_barrier_black_box(int32_t rank_id, int32_t n_ranks, uint64_t lo ASSERT_EQ( aclrtMemcpy(addr_host_vec, sizeof(uint64_t), addr_dev_vec, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), 0); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } ASSERT_EQ(aclrtFreeHost(addr_host), 0); @@ -89,29 +89,33 @@ static void test_barrier_black_box_odd_team(int32_t rank_id, int32_t n_ranks, ui uint64_t *addr_host_vec; ASSERT_EQ(aclrtMallocHost(reinterpret_cast(&addr_host_vec), sizeof(uint64_t)), 0); - if (rank_id & 1) { - for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + + for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + if (rank_id & 1) { std::cout << "[TEST] barriers test blackbox rank_id: " << rank_id << " time: " << i << std::endl; increase_do_odd_team(stream, shmemx_get_ffts_config(), (uint8_t *)addr_dev, rank_id, n_ranks, team_odd); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); ASSERT_EQ(aclrtMemcpy(addr_host, sizeof(uint64_t), addr_dev, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), - 0); + 0); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); } + shmemi_control_barrier_all(); + } - for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + if (rank_id & 1) { std::cout << "[TEST] vec barriers test blackbox rank_id: " << rank_id << " time: " << i << std::endl; increase_vec_do_odd_team(stream, shmemx_get_ffts_config(), (uint8_t *)addr_dev_vec, rank_id, n_ranks, - team_odd); + team_odd); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); ASSERT_EQ( aclrtMemcpy(addr_host_vec, sizeof(uint64_t), addr_dev_vec, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), 0); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); } + shmemi_control_barrier_all(); } + ASSERT_EQ(aclrtFreeHost(addr_host), 0); shmem_free(addr_dev); diff --git a/tests/unittest/host/team/team/team_host_test.cpp b/tests/unittest/host/team/team/team_host_test.cpp index 3c17926f..eb72efff 100644 --- a/tests/unittest/host/team/team/team_host_test.cpp +++ b/tests/unittest/host/team/team/team_host_test.cpp @@ -132,7 +132,7 @@ void test_shmem_team(int rank_id, int n_ranks, uint64_t local_mem_size) // #################### device代码测试 ############################## - auto status = test_get_device_state(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks, team_odd, stride); + auto status = test_get_device_state(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks, team_odd, stride); EXPECT_EQ(status, SHMEM_SUCCESS); // #################### 相关资源释放 ################################ diff --git a/tests/unittest/include/unittest/mem_host_direct_test.cpp b/tests/unittest/include/unittest/mem_host_direct_test.cpp index e6f22874..a8ef5dd4 100644 --- a/tests/unittest/include/unittest/mem_host_direct_test.cpp +++ b/tests/unittest/include/unittest/mem_host_direct_test.cpp @@ -124,7 +124,7 @@ SHMEM_MEM_PUT_GET_FUNC(GET_MEM_TEST) void *ptr = shmem_malloc(1024); \ host_test_##NAME##_put((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_id, rank_size, is_nbi); \ \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(input.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \ @@ -137,7 +137,7 @@ SHMEM_MEM_PUT_GET_FUNC(GET_MEM_TEST) std::cout << std::endl; \ host_test_##NAME##_get((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_id, rank_size, is_nbi); \ \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(input.data(), input_size, dev_ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \ @@ -166,7 +166,7 @@ SHMEM_MEM_PUT_GET_FUNC(PUT_GET_TEST) aclrtStream stream; \ test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ - host_test_##NAME##_put_get((uint8_t *)shm::g_state.heap_base, rank_id, n_ranks, is_nbi); \ + host_test_##NAME##_put_get((uint8_t *)g_state.heap_base, rank_id, n_ranks, is_nbi); \ \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ diff --git a/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp b/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp index 844f8990..8bffcf07 100644 --- a/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp +++ b/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp @@ -111,7 +111,7 @@ static void host_test_put_get_mem(int rank_id, int rank_size, uint64_t local_mem void *ptr = shmem_malloc(1024); host_test_putmem(ptr, dev_ptr, rank_id, input_size); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); sleep(sleep_time); ASSERT_EQ(aclrtMemcpy(input.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); @@ -124,7 +124,7 @@ static void host_test_put_get_mem(int rank_id, int rank_size, uint64_t local_mem std::cout << std::endl; size_t ele_size = 16; host_test_getmem((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_size, ele_size); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); ASSERT_EQ(aclrtMemcpy(input.data(), input_size, dev_ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); @@ -167,7 +167,7 @@ void test_host_shmem_putmem_and_getmem(int rank_id, int n_ranks, uint64_t local_ \ void *ptr = shmem_malloc(input_size); \ shmem_##NAME##_p(static_cast(ptr), static_cast(rank_id + 10), rank_id); \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ TYPE msg; \ diff --git a/tests/unittest/include/unittest/mem_putmem_signal_test.cpp b/tests/unittest/include/unittest/mem_putmem_signal_test.cpp index 9f9264c4..883356dc 100644 --- a/tests/unittest/include/unittest/mem_putmem_signal_test.cpp +++ b/tests/unittest/include/unittest/mem_putmem_signal_test.cpp @@ -81,7 +81,7 @@ SHMEM_MEM_PUT_GET_FUNC(PUT_MEM_SIGNAL) void *ptr = shmem_malloc(1024); \ int32_t signal = 6; \ putmem_##NAME##_signal_test((TYPE *)ptr, (TYPE *)dev_ptr, (uint8_t *)signal_addr, signal, rank_id, sig_op); \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(output.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \ @@ -180,7 +180,7 @@ SHMEM_MEM_PUT_GET_FUNC(PUT_MEM_SIGNAL_NBI) int32_t signal = 6; \ putmem_signal_##NAME##_test_nbi((TYPE *)ptr, (TYPE *)dev_ptr, (uint8_t *)signal_addr, signal, rank_id, \ sig_op); \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(output.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \ -- Gitee From 8e9afc3abebc40dd5fe774224a0c850e13d5f5b9 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Fri, 28 Nov 2025 17:31:49 +0800 Subject: [PATCH 58/74] examples support 7/14 1.0 default backend --- examples/CMakeLists.txt | 14 ++-- examples/allgather/main.cpp | 33 ++++---- examples/allgather_matmul/main.cpp | 72 ++++++++++++----- examples/allgather_matmul/scripts/run.sh | 7 +- examples/allgather_matmul_padding/main.cpp | 72 +++++++++++------ .../allgather_matmul_padding/scripts/run.sh | 7 +- .../main.cpp | 74 +++++++++++------ .../scripts/run.sh | 7 +- examples/dispatch_gmm_combine/main.cpp | 80 +++++++++++++------ examples/dispatch_gmm_combine/scripts/run.sh | 10 ++- examples/kv_shuffle/main.cpp | 57 ++++++++----- examples/kv_shuffle/scripts/run.sh | 10 ++- examples/matmul_allreduce/main.cpp | 70 +++++++++++----- examples/matmul_allreduce/scripts/run.sh | 7 +- 14 files changed, 340 insertions(+), 180 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 66304b40..cbc14a7c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -29,7 +29,7 @@ function(shmem_add_fusion_example NAME) if(DEFINED ENABLE_ASCENDC_DUMP AND ENABLE_ASCENDC_DUMP) target_link_libraries(${NAME} PRIVATE ascend_dump) endif() - target_link_libraries(${NAME} PRIVATE shmem ${MPI_CXX_COMPILE_FLAGS}) + target_link_libraries(${NAME} PRIVATE shmem MPI::MPI_CXX) target_compile_options(${NAME} PRIVATE ${MPI_CXX_COMPILE_FLAGS}) endfunction() @@ -62,13 +62,13 @@ function(shmem_add_collective_example NAME) endfunction() foreach(EXAMPLE - # kv_shuffle + kv_shuffle allgather - # allgather_matmul - # allgather_matmul_with_gather_result - # allgather_matmul_padding - # dispatch_gmm_combine - # matmul_allreduce + allgather_matmul + allgather_matmul_with_gather_result + allgather_matmul_padding + dispatch_gmm_combine + matmul_allreduce # matmul_reduce_scatter # matmul_reduce_scatter_padding # dynamic_tiling diff --git a/examples/allgather/main.cpp b/examples/allgather/main.cpp index fe6f43b9..ad6322da 100644 --- a/examples/allgather/main.cpp +++ b/examples/allgather/main.cpp @@ -47,10 +47,13 @@ constexpr int64_t UB_DMA_MAX_SIZE = 190 * 1024; constexpr int64_t GVA_BUFF_MAX_SIZE = 100 * 1024 * 1024; template -int test_shmem_all_gather(int rank_id, int n_ranks, aclrtStream stream) +int test_shmem_all_gather(int rank_id, int n_ranks) { - // 初始化ACL和SHMEM + // ACLStream init int status = 0; + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + // Prepare FFTS address uint64_t fftsAddr = shmemx_get_ffts_config(); @@ -144,7 +147,8 @@ int test_shmem_all_gather(int rank_id, int n_ranks, aclrtStream stream) } outFile.close(); - return 0; + status = aclrtDestroyStream(stream); + return status; } int main(int argc, char *argv[]) @@ -156,6 +160,7 @@ int main(int argc, char *argv[]) MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + // Shmem uid pre-init shmemx_uniqueid_t uid; if (rank_id == 0) { shmem_get_uniqueid(&uid); @@ -163,32 +168,26 @@ int main(int argc, char *argv[]) std::cout << "MPI_Bcast!" << std::endl; MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - status = aclInit(nullptr); - + // Acl && Shmem init int32_t device_id = rank_id % g_npus + f_npu; + status = aclInit(nullptr); status = aclrtSetDevice(device_id); - - aclrtStream stream = nullptr; - status = aclrtCreateStream(&stream); shmem_init_attr_t *attributes; - shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, - &uid, - &attributes); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); if (std::string(data_type) == "int") { - status = test_shmem_all_gather(rank_id, n_ranks, stream); + status = test_shmem_all_gather(rank_id, n_ranks); } else if (std::string(data_type) == "int32_t") { - status = test_shmem_all_gather(rank_id, n_ranks, stream); + status = test_shmem_all_gather(rank_id, n_ranks); } else if (std::string(data_type) == "float16_t") { - status = test_shmem_all_gather(rank_id, n_ranks, stream); + status = test_shmem_all_gather(rank_id, n_ranks); } else if (std::string(data_type) == "bfloat16_t") { - status = test_shmem_all_gather(rank_id, n_ranks, stream); + status = test_shmem_all_gather(rank_id, n_ranks); } status = shmem_finalize(); - status = aclrtDestroyStream(stream); status = aclrtResetDevice(device_id); status = aclFinalize(); if (status) { diff --git a/examples/allgather_matmul/main.cpp b/examples/allgather_matmul/main.cpp index be795a7b..2369bce8 100644 --- a/examples/allgather_matmul/main.cpp +++ b/examples/allgather_matmul/main.cpp @@ -9,6 +9,8 @@ */ #include +#include + #include #include #include @@ -217,47 +219,67 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + int32_t device_id = rank_id % n_ranks; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + + shmem_init_attr_t *attributes; + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; + // int rankSize = options.rankSize; + // int rankId = options.rankId; + // std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << std::endl; + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + // status = shmem_set_conf_store_tls(false, nullptr, 0); size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); - size_t cSize = static_cast(m) * rankSize * n * sizeof(__fp16); + size_t cSize = static_cast(m) * n_ranks * n * sizeof(__fp16); uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *cDevice; @@ -293,7 +315,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), cHost, cSize); - if (rankId == 0) { + if (rank_id == 0) { std::printf("test finished\n"); } @@ -306,11 +328,17 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(cDevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } \ No newline at end of file diff --git a/examples/allgather_matmul/scripts/run.sh b/examples/allgather_matmul/scripts/run.sh index d0016b10..79d2c769 100644 --- a/examples/allgather_matmul/scripts/run.sh +++ b/examples/allgather_matmul/scripts/run.sh @@ -40,9 +40,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - done + # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & + # done + mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit wait diff --git a/examples/allgather_matmul_padding/main.cpp b/examples/allgather_matmul_padding/main.cpp index d2efdc3f..d1f7da2b 100644 --- a/examples/allgather_matmul_padding/main.cpp +++ b/examples/allgather_matmul_padding/main.cpp @@ -9,6 +9,8 @@ */ #include +#include + #include #include #include @@ -219,30 +221,48 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + int32_t device_id = rank_id % n_ranks; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + + shmem_init_attr_t *attributes; + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; + // int rankSize = options.rankSize; + // int rankId = options.rankId; + // std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; - - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << std::endl; - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; LayoutB layoutB{k, n}; constexpr uint32_t alignByByte = 512; @@ -256,20 +276,20 @@ int main(int argc, char **argv) size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); - size_t cSize = static_cast(m) * rankSize * n * sizeof(__fp16); + size_t cSize = static_cast(m) * n_ranks * n * sizeof(__fp16); uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *cDevice; @@ -313,7 +333,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), cHost, cSize); - if (rankId == 0) { + if (rank_id == 0) { std::printf("test finished\n"); } @@ -329,11 +349,17 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(workspaceDevice)); } - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } \ No newline at end of file diff --git a/examples/allgather_matmul_padding/scripts/run.sh b/examples/allgather_matmul_padding/scripts/run.sh index 9bec4c3e..1d8b9541 100644 --- a/examples/allgather_matmul_padding/scripts/run.sh +++ b/examples/allgather_matmul_padding/scripts/run.sh @@ -40,9 +40,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - done + # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & + # done + mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit wait diff --git a/examples/allgather_matmul_with_gather_result/main.cpp b/examples/allgather_matmul_with_gather_result/main.cpp index 413628e2..22d0fa9a 100644 --- a/examples/allgather_matmul_with_gather_result/main.cpp +++ b/examples/allgather_matmul_with_gather_result/main.cpp @@ -9,6 +9,8 @@ */ #include +#include + #include #include #include @@ -228,48 +230,66 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + int32_t device_id = rank_id % n_ranks; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + + shmem_init_attr_t *attributes; + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; + // int rankSize = options.rankSize; + // int rankId = options.rankId; + // std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; - - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << std::endl; - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); - size_t cSize = static_cast(m) * rankSize * n * sizeof(__fp16); - size_t gatherASize = static_cast(m) * rankSize * k * sizeof(__fp16); + size_t cSize = static_cast(m) * n_ranks * n * sizeof(__fp16); + size_t gatherASize = static_cast(m) * n_ranks * k * sizeof(__fp16); uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *cDevice; @@ -298,7 +318,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtSynchronizeStream(stream)); std::cout << "After calling AG_MM kernel " << std::endl; - if (rankId == 0) { + if (rank_id == 0) { ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(gatherAHost, gatherASize, gatherADevice, gatherASize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), cHost, cSize); @@ -317,11 +337,17 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(cDevice)); ACL_CHECK(aclrtFree(gatherADevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } diff --git a/examples/allgather_matmul_with_gather_result/scripts/run.sh b/examples/allgather_matmul_with_gather_result/scripts/run.sh index 436fbcfe..6a9d8543 100644 --- a/examples/allgather_matmul_with_gather_result/scripts/run.sh +++ b/examples/allgather_matmul_with_gather_result/scripts/run.sh @@ -40,9 +40,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - done + # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & + # done + mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit wait diff --git a/examples/dispatch_gmm_combine/main.cpp b/examples/dispatch_gmm_combine/main.cpp index 758fa0ff..6c37ccc4 100644 --- a/examples/dispatch_gmm_combine/main.cpp +++ b/examples/dispatch_gmm_combine/main.cpp @@ -9,6 +9,8 @@ */ #include +#include + #include #include @@ -283,25 +285,45 @@ void InitData(uint8_t **hostPtr, uint8_t **devicePtr, size_t aSize, std::string int main(int argc, char **argv) { int status = SHMEM_SUCCESS; - int rankSize = atoi(argv[1]); - int rankId = atoi(argv[2]); - std::string ipport = argv[3]; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + int32_t device_id = rank_id % n_ranks; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - ACL_CHECK(aclInit(nullptr)); - int32_t deviceId = atoi(argv[4]) + rankId % gNpuNum; - ACL_CHECK(aclrtSetDevice(deviceId)); - aclrtStream stream = nullptr; - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipport.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + // Kernel-need params parse + // int rankSize = atoi(argv[1]); + // int rankId = atoi(argv[2]); + // std::string ipport = argv[3]; + + // status = shmem_set_conf_store_tls(false, nullptr, 0); uint32_t m = atoi(argv[5]); uint32_t k = atoi(argv[6]); uint32_t n = atoi(argv[7]); - uint32_t EP = rankSize; + uint32_t EP = n_ranks; uint32_t expertPerRank = atoi(argv[8]); uint32_t dataType = atoi(argv[9]); uint32_t weightNz = atoi(argv[10]); @@ -360,13 +382,13 @@ int main(int argc, char **argv) "_" + std::to_string(dataType) + "_1_" + std::to_string(m) + "_" + std::to_string(k) + "_" + std::to_string(n) + "_" + std::to_string(expertPerRank) + "_" + std::to_string(EP) + "_1.bin"; - InitData(&b1Host, &b1Device, b1Size, filePrefix + "matrix_b1_" + std::to_string(rankId) + fileSuffix); - InitData(&b2Host, &b2Device, b2Size, filePrefix + "matrix_b2_" + std::to_string(rankId) + fileSuffix); + InitData(&b1Host, &b1Device, b1Size, filePrefix + "matrix_b1_" + std::to_string(rank_id) + fileSuffix); + InitData(&b2Host, &b2Device, b2Size, filePrefix + "matrix_b2_" + std::to_string(rank_id) + fileSuffix); InitData(&cHost, &cDevice, cSize); InitData(&scale1Host, &scale1Device, dequantScale1Size, - filePrefix + "matrix_dequant_scale1_" + std::to_string(rankId) + fileSuffix); + filePrefix + "matrix_dequant_scale1_" + std::to_string(rank_id) + fileSuffix); InitData(&scale2Host, &scale2Device, dequantScale2Size, - filePrefix + "matrix_dequant_scale2_" + std::to_string(rankId) + fileSuffix); + filePrefix + "matrix_dequant_scale2_" + std::to_string(rank_id) + fileSuffix); InitData(&probsHost, &probsDevice, probsSize, filePrefix + "probs" + fileSuffix); uint8_t *expertIdx; @@ -386,9 +408,9 @@ int main(int argc, char **argv) int64_t quantMode = 1; std::string dispatchFileSuffix = ""; InitData(&aHost, &aDevice, m * k * sizeof(float16_t), - filePrefix + "matrix_a_" + std::to_string(rankId) + fileSuffix); + filePrefix + "matrix_a_" + std::to_string(rank_id) + fileSuffix); InitData(&expertIdxHost, &expertIdx, m * topK * sizeof(int32_t), - filePrefix + "expert_idx_" + std::to_string(rankId) + fileSuffix); + filePrefix + "expert_idx_" + std::to_string(rank_id) + fileSuffix); moeInitRoutingQuantV2Scale = nullptr; moeInitRoutingQuantV2Offset = nullptr; @@ -405,7 +427,7 @@ int main(int argc, char **argv) size_t initRoutingWorkspace = moeInitRoutingQuantV2TilingBase.workspaceSize_; workspaceSize += initRoutingWorkspace; printf("!!!!!!!!!! initRoutingQuantTilingKey %lu\n\n", initRoutingQuantTilingKey); - if (rankId == 0) { + if (rank_id == 0) { moeInitRoutingQuantV2TilingBase.ShowTilingData(); } @@ -444,8 +466,8 @@ int main(int argc, char **argv) ACL_CHECK(aclrtSynchronizeStream(stream)); ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile("./out/output_" + std::to_string(rankId) + ".bin", cHost, cSize); - if (rankId == 0) { + WriteFile("./out/output_" + std::to_string(rank_id) + ".bin", cHost, cSize); + if (rank_id == 0) { std::printf("\ntest finished\n"); } shmem_free(symmPtr); @@ -458,11 +480,17 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFreeHost(expertIdxHost)); ACL_CHECK(aclrtFree(expertIdx)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } diff --git a/examples/dispatch_gmm_combine/scripts/run.sh b/examples/dispatch_gmm_combine/scripts/run.sh index 2b506b1a..2274a287 100644 --- a/examples/dispatch_gmm_combine/scripts/run.sh +++ b/examples/dispatch_gmm_combine/scripts/run.sh @@ -115,10 +115,12 @@ EXEC_BIN=${PROJECT_ROOT}/build/bin/dispatch_gmm_combine cd ${PROJECT_ROOT}/examples/dispatch_gmm_combine/ echo "Test Case, M: ${M}, K: ${K}, N: ${N}, expertPerRank: ${expertPerRank}" -export LD_LIBRARY_PATH=${PROJECT_ROOT}/install/shmem/lib:${ASCEND_HOME_PATH}/lib64:${PROJECT_ROOT}/install/memfabric_hybrid/lib:$LD_LIBRARY_PATH -for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - INPUT_PATH=${CURRENT_DIR}/utils/test_data/ ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$FIRST_NPU" "$M" "$K" "$N" "$expertPerRank" "$dataType" "$weightNz" "$transB" & -done +# export LD_LIBRARY_PATH=${PROJECT_ROOT}/install/shmem/lib:${ASCEND_HOME_PATH}/lib64:${PROJECT_ROOT}/install/memfabric_hybrid/lib:$LD_LIBRARY_PATH +# for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do +# INPUT_PATH=${CURRENT_DIR}/utils/test_data/ ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$FIRST_NPU" "$M" "$K" "$N" "$expertPerRank" "$dataType" "$weightNz" "$transB" & +# done +export INPUT_PATH=${CURRENT_DIR}/utils/test_data/ +mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$FIRST_NPU" "$M" "$K" "$N" "$expertPerRank" "$dataType" "$weightNz" "$transB" # Wait until all process exit wait diff --git a/examples/kv_shuffle/main.cpp b/examples/kv_shuffle/main.cpp index e00c11d5..6bd29506 100644 --- a/examples/kv_shuffle/main.cpp +++ b/examples/kv_shuffle/main.cpp @@ -30,6 +30,8 @@ using fp16_t = op::fp16_t; using bfloat16 = op::bfloat16; +#include + #include "acl/acl.h" #include "shmem_api.h" #include "kv_shuffle_kernel.h" @@ -46,21 +48,13 @@ constexpr int64_t max_block_nums = MAX_SEQLEN * MAX_BATCH / page_size; constexpr int64_t kv_head_num = 8; constexpr int64_t head_dim = 128; -int test_shmem_kv_shuffle(int rank_id, int n_ranks, uint64_t local_mem_size) +int test_shmem_kv_shuffle(int rank_id, int n_ranks) { - // 初始化ACL和SHMEM - int32_t device_id = rank_id % g_npus + f_npu; + // ACLStream init int status = 0; aclrtStream stream = nullptr; - - status = aclInit(nullptr); - status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(attributes); - uint32_t BLOCK_NUM = 16; int64_t kv_cache_size = max_block_nums * kv_head_num * page_size * head_dim * sizeof(int8_t); @@ -179,26 +173,51 @@ int test_shmem_kv_shuffle(int rank_id, int n_ranks, uint64_t local_mem_size) status = aclrtFreeHost(k_output_host); status = aclrtFreeHost(v_output_host); - status = shmem_finalize(); status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); - return 0; + return status; } int main(int argc, char *argv[]) { int status = 0; - int n_ranks = atoi(argv[1]); - int rank_id = atoi(argv[2]); - ipport = argv[3]; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + int32_t device_id = rank_id % g_npus + f_npu; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + + shmem_init_attr_t *attributes; uint64_t local_mem_size = 1024UL * 1024UL * 1024; - int32_t ret = shmem_set_conf_store_tls(false, nullptr, 0); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); - status = test_shmem_kv_shuffle(rank_id, n_ranks, local_mem_size); + // int32_t ret = shmem_set_conf_store_tls(false, nullptr, 0); + + status = test_shmem_kv_shuffle(rank_id, n_ranks); + + status = shmem_finalize(); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } diff --git a/examples/kv_shuffle/scripts/run.sh b/examples/kv_shuffle/scripts/run.sh index a2b3d3b0..32a002aa 100644 --- a/examples/kv_shuffle/scripts/run.sh +++ b/examples/kv_shuffle/scripts/run.sh @@ -23,10 +23,12 @@ rm -rf scripts/output/*.bin python3 scripts/golden.py $RANK_SIZE # Start Process -for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - APP="$EXEC_BIN $RANK_SIZE $idx $IPPORT" - ${APP}& -done +mpirun -np ${RANK_SIZE} ${PROJECT_ROOT}/build/bin/kv_shuffle + +# for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do +# APP="$EXEC_BIN $RANK_SIZE $idx $IPPORT" +# ${APP}& +# done # Wait until all process exit wait diff --git a/examples/matmul_allreduce/main.cpp b/examples/matmul_allreduce/main.cpp index 11f297aa..bb245c98 100644 --- a/examples/matmul_allreduce/main.cpp +++ b/examples/matmul_allreduce/main.cpp @@ -9,6 +9,8 @@ */ #include +#include + #include #include #include @@ -215,30 +217,48 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + int32_t device_id = rank_id % n_ranks; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + + shmem_init_attr_t *attributes; + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; + // int rankSize = options.rankSize; + // int rankId = options.rankId; + // std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; - - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << "\n"; - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, NPU_MALLOC_SPACE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << "\n"; size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); @@ -248,14 +268,14 @@ int main(int argc, char **argv) ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *dDevice; @@ -281,7 +301,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtSynchronizeStream(stream)); std::cout << "After calling MM_AR kernel " << std::endl; - if (rankId == 0) { + if (rank_id == 0) { ACL_CHECK(aclrtMemcpy(dHost, dSize, dDevice, dSize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSize); std::printf("test finished\n"); @@ -296,11 +316,17 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(dDevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } diff --git a/examples/matmul_allreduce/scripts/run.sh b/examples/matmul_allreduce/scripts/run.sh index 75ff4ce4..e4f3f6cc 100644 --- a/examples/matmul_allreduce/scripts/run.sh +++ b/examples/matmul_allreduce/scripts/run.sh @@ -39,9 +39,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - done + # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & + # done + mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit wait -- Gitee From 61974f9324ac7405174729c38c7777b124de37e8 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Sat, 29 Nov 2025 11:15:21 +0800 Subject: [PATCH 59/74] examples support 10/14 2.0 default backend --- examples/CMakeLists.txt | 6 +- examples/allgather_matmul/main.cpp | 26 ++--- examples/allgather_matmul_padding/main.cpp | 26 ++--- .../main.cpp | 26 ++--- examples/dynamic_tiling/CMakeLists.txt | 10 +- examples/dynamic_tiling/main.cpp | 104 ++++++++++++------ examples/dynamic_tiling/scripts/run.sh | 19 ++-- examples/matmul_allreduce/main.cpp | 26 ++--- examples/matmul_reduce_scatter/main.cpp | 68 ++++++++---- examples/matmul_reduce_scatter/scripts/run.sh | 7 +- .../matmul_reduce_scatter_padding/main.cpp | 68 ++++++++---- .../scripts/run.sh | 7 +- 12 files changed, 235 insertions(+), 158 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index cbc14a7c..e72c2383 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -69,9 +69,9 @@ foreach(EXAMPLE allgather_matmul_padding dispatch_gmm_combine matmul_allreduce - # matmul_reduce_scatter - # matmul_reduce_scatter_padding - # dynamic_tiling + matmul_reduce_scatter + matmul_reduce_scatter_padding + dynamic_tiling # rdma_perftest # rdma_demo # rdma_handlewait_test/unuse_handlewait diff --git a/examples/allgather_matmul/main.cpp b/examples/allgather_matmul/main.cpp index 2369bce8..806c8540 100644 --- a/examples/allgather_matmul/main.cpp +++ b/examples/allgather_matmul/main.cpp @@ -225,6 +225,18 @@ int main(int argc, char **argv) MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + // Kernel-need params parse + Options options; + if (options.Parse(argc, argv) != 0) { + std::cerr << "Invalid arguments\n"; + return 1; + } + + uint32_t m = options.m; + uint32_t n = options.n; + uint32_t k = options.k; + int32_t device_id = options.deviceIdList[rank_id]; + // Shmem uid pre-init shmemx_uniqueid_t uid; if (rank_id == 0) { @@ -234,7 +246,6 @@ int main(int argc, char **argv) MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); // Acl && Shmem init - int32_t device_id = rank_id % n_ranks; status = aclInit(nullptr); status = aclrtSetDevice(device_id); @@ -247,19 +258,6 @@ int main(int argc, char **argv) aclrtStream stream = nullptr; status = aclrtCreateStream(&stream); - // Kernel-need params parse - Options options; - if (options.Parse(argc, argv) != 0) { - std::cerr << "Invalid arguments\n"; - return 1; - } - // int rankSize = options.rankSize; - // int rankId = options.rankId; - // std::string ipPort = options.ipPort; - uint32_t m = options.m; - uint32_t n = options.n; - uint32_t k = options.k; - std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; // status = shmem_set_conf_store_tls(false, nullptr, 0); diff --git a/examples/allgather_matmul_padding/main.cpp b/examples/allgather_matmul_padding/main.cpp index d1f7da2b..96326422 100644 --- a/examples/allgather_matmul_padding/main.cpp +++ b/examples/allgather_matmul_padding/main.cpp @@ -227,6 +227,18 @@ int main(int argc, char **argv) MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + // Kernel-need params parse + Options options; + if (options.Parse(argc, argv) != 0) { + std::cerr << "Invalid arguments\n"; + return 1; + } + + uint32_t m = options.m; + uint32_t n = options.n; + uint32_t k = options.k; + int32_t device_id = options.deviceIdList[rank_id]; + // Shmem uid pre-init shmemx_uniqueid_t uid; if (rank_id == 0) { @@ -236,7 +248,6 @@ int main(int argc, char **argv) MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); // Acl && Shmem init - int32_t device_id = rank_id % n_ranks; status = aclInit(nullptr); status = aclrtSetDevice(device_id); @@ -249,19 +260,6 @@ int main(int argc, char **argv) aclrtStream stream = nullptr; status = aclrtCreateStream(&stream); - // Kernel-need params parse - Options options; - if (options.Parse(argc, argv) != 0) { - std::cerr << "Invalid arguments\n"; - return 1; - } - // int rankSize = options.rankSize; - // int rankId = options.rankId; - // std::string ipPort = options.ipPort; - uint32_t m = options.m; - uint32_t n = options.n; - uint32_t k = options.k; - std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; LayoutB layoutB{k, n}; diff --git a/examples/allgather_matmul_with_gather_result/main.cpp b/examples/allgather_matmul_with_gather_result/main.cpp index 22d0fa9a..15aebe21 100644 --- a/examples/allgather_matmul_with_gather_result/main.cpp +++ b/examples/allgather_matmul_with_gather_result/main.cpp @@ -236,6 +236,18 @@ int main(int argc, char **argv) MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + // Kernel-need params parse + Options options; + if (options.Parse(argc, argv) != 0) { + std::cerr << "Invalid arguments\n"; + return 1; + } + + uint32_t m = options.m; + uint32_t n = options.n; + uint32_t k = options.k; + int32_t device_id = options.deviceIdList[rank_id]; + // Shmem uid pre-init shmemx_uniqueid_t uid; if (rank_id == 0) { @@ -245,7 +257,6 @@ int main(int argc, char **argv) MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); // Acl && Shmem init - int32_t device_id = rank_id % n_ranks; status = aclInit(nullptr); status = aclrtSetDevice(device_id); @@ -258,19 +269,6 @@ int main(int argc, char **argv) aclrtStream stream = nullptr; status = aclrtCreateStream(&stream); - // Kernel-need params parse - Options options; - if (options.Parse(argc, argv) != 0) { - std::cerr << "Invalid arguments\n"; - return 1; - } - // int rankSize = options.rankSize; - // int rankId = options.rankId; - // std::string ipPort = options.ipPort; - uint32_t m = options.m; - uint32_t n = options.n; - uint32_t k = options.k; - std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; size_t aSize = static_cast(m) * k * sizeof(__fp16); diff --git a/examples/dynamic_tiling/CMakeLists.txt b/examples/dynamic_tiling/CMakeLists.txt index 4f5d6fbd..e8effa44 100644 --- a/examples/dynamic_tiling/CMakeLists.txt +++ b/examples/dynamic_tiling/CMakeLists.txt @@ -2,7 +2,7 @@ add_custom_target(lib_impl) function(add_impl_share_lib NAME) add_library(${NAME} SHARED ${ARGN}) - target_compile_options(${NAME} PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220) + target_compile_options(${NAME} PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} ${MPI_CXX_COMPILE_FLAGS} --cce-aicore-arch=dav-c220) target_include_directories(${NAME} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src/memfabric_hybrid/src/smem/include/host/ @@ -15,7 +15,7 @@ function(add_impl_share_lib NAME) ${PROJECT_SOURCE_DIR}/examples/utils ) target_link_options(${NAME} PRIVATE --cce-fatobj-link) - target_link_libraries(${NAME} PRIVATE shmem) + target_link_libraries(${NAME} PRIVATE shmem MPI::MPI_CXX) add_dependencies(lib_impl ${NAME}) endfunction() @@ -29,7 +29,7 @@ set(TILING_SOURCES ) add_library(tiling_lib SHARED ${TILING_SOURCES}) target_include_directories(tiling_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/tiling ${CMAKE_CURRENT_SOURCE_DIR}/include) -target_compile_options(tiling_lib PRIVATE -O3) +target_compile_options(tiling_lib PRIVATE -O3 ${MPI_CXX_COMPILE_FLAGS}) add_executable(dynamic_tiling main.cpp) @@ -45,7 +45,7 @@ target_include_directories(dynamic_tiling PRIVATE ${PROJECT_SOURCE_DIR}/examples/utils ) target_link_options(dynamic_tiling PRIVATE --cce-fatobj-link) -target_link_libraries(dynamic_tiling PRIVATE tiling_lib shmem ${SHARE_LIB_LINK}) -target_compile_options(dynamic_tiling PRIVATE -O3) +target_link_libraries(dynamic_tiling PRIVATE tiling_lib shmem ${SHARE_LIB_LINK} MPI::MPI_CXX) +target_compile_options(dynamic_tiling PRIVATE -O3 ${MPI_CXX_COMPILE_FLAGS}) add_dependencies(dynamic_tiling lib_impl tiling_lib) diff --git a/examples/dynamic_tiling/main.cpp b/examples/dynamic_tiling/main.cpp index 0a07caab..db55405d 100644 --- a/examples/dynamic_tiling/main.cpp +++ b/examples/dynamic_tiling/main.cpp @@ -12,6 +12,8 @@ #include #include +#include + #include #include #include @@ -189,28 +191,51 @@ std::string GetCurrentTime() int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Kernel-need params parse Options options; - options.Parse(argc, argv); + if (options.Parse(argc, argv) != 0) { + std::cerr << "Invalid arguments\n"; + return 1; + } + + uint32_t m = options.m; + uint32_t n = options.n; + uint32_t k = options.k; CocCommType commType = options.commType; CocDataType dataType = options.dataType; - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; - int32_t deviceId = options.deviceIdList[rankId]; - std::string data_file = options.data_file; - const std::vector> shapes = InitTestShapes(options); + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id: " << rankId << " input_ip: " << ipPort << "\n"; + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, SHMEM_MALLOC_MAX_SIZE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + std::string data_file = options.data_file; + const std::vector> shapes = InitTestShapes(options); + + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id: " << rank_id << "\n"; uint64_t fftsAddr{0}; uint32_t fftsLen{0}; @@ -219,7 +244,7 @@ int main(int argc, char **argv) std::string currentTime = GetCurrentTime(); std::string currentDir = options.parentPath; std::string tilingFileName = currentDir + "/output/tiling/tilingData_" + currentTime + ".csv"; - if (rankId == 0) { + if (rank_id == 0) { CreateTilingFile(tilingFileName); } @@ -244,19 +269,19 @@ int main(int argc, char **argv) cocTiling.commNpuSplit = 1; cocTiling.commDataSplit = 16; cocTiling.commBlockM = 64; - cocTiling.rankSize = rankSize; + cocTiling.rankSize = n_ranks; size_t aSize = static_cast(m) * k * sizeof(half); size_t bSize = static_cast(k) * n * sizeof(half); size_t cSize = static_cast(m) * n * sizeof(half); size_t cSizePerRank; - size_t gatherASize = aSize * rankSize; + size_t gatherASize = aSize * n_ranks; size_t wASize = 0; size_t wBSize = 0; if (commType == MATMUL_REDUCE_SCATTER) { - cSizePerRank = cSize / rankSize; + cSizePerRank = cSize / n_ranks; } else if (commType == MATMUL_REDUCE_SCATTER_PADDING) { - cSizePerRank = cSize / rankSize; + cSizePerRank = cSize / n_ranks; bool isNeedPaddingA = IsNeedPadding(m, k, transA); bool isNeedPaddingB = IsNeedPadding(k, n, transB); @@ -274,9 +299,9 @@ int main(int argc, char **argv) kernelType = MATMUL_REDUCE_SCATTER; } } else if (commType == ALLGATHER_MATMUL || commType == ALLGATHER_MATMUL_WITH_GATHER_RESULT) { - cSizePerRank = cSize * rankSize; + cSizePerRank = cSize * n_ranks; } else if (commType == ALLGATHER_MATMUL_PADDING) { - cSizePerRank = cSize * rankSize; + cSizePerRank = cSize * n_ranks; bool isNeedPaddingB = IsNeedPadding(k, n, transB); if (isNeedPaddingB) { @@ -296,7 +321,7 @@ int main(int argc, char **argv) uint8_t *aHost; if (data_file != "") { ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(data_file + "/rank_" + std::to_string(rankId) + "_a.bin", aHost, aSize); + ReadFile(data_file + "/rank_" + std::to_string(rank_id) + "_a.bin", aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); } else { std::vector matrixA(m * k, 1); @@ -308,7 +333,7 @@ int main(int argc, char **argv) uint8_t *bHost; if (data_file != "") { ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(data_file + "/rank_" + std::to_string(rankId) + "_b.bin", bHost, bSize); + ReadFile(data_file + "/rank_" + std::to_string(rank_id) + "_b.bin", bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); } else { std::vector matrixB(k * n, 1); @@ -355,11 +380,11 @@ int main(int argc, char **argv) } else { if (searchparams == 1) { // 搜索 tiling - GetTilings(cocTilings, cocTiling, commType, rankSize); + GetTilings(cocTilings, cocTiling, commType, n_ranks); } else { - bool ok = ApplyLookupTable(info, commType, rankSize, cocTiling); + bool ok = ApplyLookupTable(info, commType, n_ranks, cocTiling); if (!ok) { - std::cerr << "[LUT] no table for (" << opName << "," << rankSize << "), using defaults\n"; + std::cerr << "[LUT] no table for (" << opName << "," << n_ranks << "), using defaults\n"; } cocTilings.push_back(cocTiling); } @@ -393,22 +418,22 @@ int main(int argc, char **argv) if (data_file != "") { if (commType == MATMUL_ALLREDUCE) { - if (rankId == 0) { + if (rank_id == 0) { WriteFile(data_file + "/output.bin", cHost, cSizePerRank); } } else if (commType == ALLGATHER_MATMUL || commType == ALLGATHER_MATMUL_PADDING || commType == ALLGATHER_MATMUL_WITH_GATHER_RESULT) { - if (rankId == 0) { + if (rank_id == 0) { WriteFile(data_file + "/output.bin", cHost, cSizePerRank); if (commType == ALLGATHER_MATMUL_WITH_GATHER_RESULT) { WriteFile(data_file + "/output_gather_a.bin", gatherAHost, gatherASize); } } } else if (commType == MATMUL_REDUCE_SCATTER || commType == MATMUL_REDUCE_SCATTER_PADDING) { - WriteFile(data_file + "/output.bin", cHost, cSizePerRank, rankId * cSizePerRank); + WriteFile(data_file + "/output.bin", cHost, cSizePerRank, rank_id * cSizePerRank); } } - if (rankId == 0) { + if (rank_id == 0) { WriteTilingInfos(opName, cocTilings, tilingFileName, transA, transB); std::printf("M: %d, K: %d, N: %d aclrtSynchronizeStream success!\n", cocTiling.m, cocTiling.k, cocTiling.n); } @@ -430,11 +455,18 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(cDevice)); } - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } \ No newline at end of file diff --git a/examples/dynamic_tiling/scripts/run.sh b/examples/dynamic_tiling/scripts/run.sh index aab6e16c..554bbc62 100644 --- a/examples/dynamic_tiling/scripts/run.sh +++ b/examples/dynamic_tiling/scripts/run.sh @@ -82,10 +82,11 @@ if [ "$TEST_TYPE" = "0" ]; then IPPORT="tcp://127.0.0.1:27008" # Start Process - for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR $DATA_PATH" - ${APP}& - done + # for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR $DATA_PATH" + # ${APP}& + # done + mpirun -np ${RANK_SIZE} $EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE 0 $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR $DATA_PATH # Wait until all process exit wait @@ -117,10 +118,12 @@ else OUTPUT_PATH="./output/msprof/start_line${IDX}_run_rows${TEST_COLLECT_ROWS}/" # Start Process - for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR" - msprof --application="${APP}" --output="${OUTPUT_PATH}"& - done + # for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR" + # msprof --application="${APP}" --output="${OUTPUT_PATH}"& + # done + APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE 0 $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR" + mpirun -np ${RANK_SIZE} msprof --application="${APP}" --output="${OUTPUT_PATH}" # Wait until all process exit wait diff --git a/examples/matmul_allreduce/main.cpp b/examples/matmul_allreduce/main.cpp index bb245c98..94a9e171 100644 --- a/examples/matmul_allreduce/main.cpp +++ b/examples/matmul_allreduce/main.cpp @@ -223,6 +223,18 @@ int main(int argc, char **argv) MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + // Kernel-need params parse + Options options; + if (options.Parse(argc, argv) != 0) { + std::cerr << "Invalid arguments\n"; + return 1; + } + + uint32_t m = options.m; + uint32_t n = options.n; + uint32_t k = options.k; + int32_t device_id = options.deviceIdList[rank_id]; + // Shmem uid pre-init shmemx_uniqueid_t uid; if (rank_id == 0) { @@ -232,7 +244,6 @@ int main(int argc, char **argv) MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); // Acl && Shmem init - int32_t device_id = rank_id % n_ranks; status = aclInit(nullptr); status = aclrtSetDevice(device_id); @@ -245,19 +256,6 @@ int main(int argc, char **argv) aclrtStream stream = nullptr; status = aclrtCreateStream(&stream); - // Kernel-need params parse - Options options; - if (options.Parse(argc, argv) != 0) { - std::cerr << "Invalid arguments\n"; - return 1; - } - // int rankSize = options.rankSize; - // int rankId = options.rankId; - // std::string ipPort = options.ipPort; - uint32_t m = options.m; - uint32_t n = options.n; - uint32_t k = options.k; - std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << "\n"; size_t aSize = static_cast(m) * k * sizeof(__fp16); diff --git a/examples/matmul_reduce_scatter/main.cpp b/examples/matmul_reduce_scatter/main.cpp index 602d9917..eb3eea3d 100644 --- a/examples/matmul_reduce_scatter/main.cpp +++ b/examples/matmul_reduce_scatter/main.cpp @@ -9,6 +9,8 @@ */ #include +#include + #include #include #include @@ -204,48 +206,64 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; + uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << "\n"; + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, NPU_MALLOC_SPACE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << "\n"; size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); size_t dSize = static_cast(m) * n * sizeof(__fp16); - size_t dSizeScatter = dSize / options.rankSize; + size_t dSizeScatter = dSize / n_ranks; uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *dDevice; @@ -270,8 +288,8 @@ int main(int argc, char **argv) std::cout << "After calling MM_RS kernel " << std::endl; ACL_CHECK(aclrtMemcpy(dHost, dSizeScatter, dDevice, dSizeScatter, ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rankId * dSizeScatter); - if (rankId == 0) { + WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rank_id * dSizeScatter); + if (rank_id == 0) { std::printf("test finished\n"); } @@ -284,11 +302,17 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(dDevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } diff --git a/examples/matmul_reduce_scatter/scripts/run.sh b/examples/matmul_reduce_scatter/scripts/run.sh index 893ba7d2..a64af5cf 100644 --- a/examples/matmul_reduce_scatter/scripts/run.sh +++ b/examples/matmul_reduce_scatter/scripts/run.sh @@ -39,9 +39,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - done + # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & + # done + mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit wait diff --git a/examples/matmul_reduce_scatter_padding/main.cpp b/examples/matmul_reduce_scatter_padding/main.cpp index 57586818..09717ceb 100644 --- a/examples/matmul_reduce_scatter_padding/main.cpp +++ b/examples/matmul_reduce_scatter_padding/main.cpp @@ -9,6 +9,8 @@ */ #include +#include + #include #include #include @@ -227,30 +229,46 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; + uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << "\n"; + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, NPU_MALLOC_SPACE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << "\n"; LayoutA layoutA{m, k}; LayoutB layoutB{k, n}; @@ -267,20 +285,20 @@ int main(int argc, char **argv) size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); size_t dSize = static_cast(m) * n * sizeof(__fp16); - size_t dSizeScatter = dSize / options.rankSize; + size_t dSizeScatter = dSize / n_ranks; uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *dDevice; @@ -345,8 +363,8 @@ int main(int argc, char **argv) std::cout << "After calling MM_RS kernel " << std::endl; ACL_CHECK(aclrtMemcpy(dHost, dSizeScatter, dDevice, dSizeScatter, ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rankId * dSizeScatter); - if (rankId == 0) { + WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rank_id * dSizeScatter); + if (rank_id == 0) { std::printf("test finished\n"); } @@ -365,11 +383,17 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(wbDevice)); } - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + MPI_Finalize(); return 0; } diff --git a/examples/matmul_reduce_scatter_padding/scripts/run.sh b/examples/matmul_reduce_scatter_padding/scripts/run.sh index fb7ef375..30871b22 100644 --- a/examples/matmul_reduce_scatter_padding/scripts/run.sh +++ b/examples/matmul_reduce_scatter_padding/scripts/run.sh @@ -39,9 +39,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - done + # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do + # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & + # done + mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" 0 "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit wait -- Gitee From 00c2fd1e1010663b013762e307fd0e00b0b709e6 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Sat, 29 Nov 2025 15:36:43 +0800 Subject: [PATCH 60/74] utfix and mf adapt --- CMakeLists.txt | 2 +- include/host/shmem_host_init.h | 3 + include/internal/host/shmemi_host_def.h | 29 +- .../init/init_backends/mf/shmemi_init_mf.cpp | 403 +++++++++++++++++- .../init/init_backends/mf/shmemi_init_mf.h | 5 + src/host/init/shmem_init.cpp | 76 +++- src/modules/transport/shmemi_mte.cpp | 2 +- tests/fuzz/device/mem/shmem_ptr_kernel.cpp | 2 +- tests/fuzz/device/team/team_kernel.cpp | 2 +- .../mem/atomic_add/atomic_add_kernel.cpp | 2 +- tests/unittest/host/init/init_host_test.cpp | 31 ++ tests/unittest/host/main_test.cpp | 16 +- .../host/mem/shmem_host_heap_test.cpp | 5 + 13 files changed, 533 insertions(+), 45 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4dabc368..69ac48cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,7 +94,7 @@ link_directories( link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread) # MF_BACKEND -set(USE_MF "0") +set(USE_MF "1") if ("${USE_MF}" STREQUAL "1") add_compile_definitions(BACKEND_MF=1) diff --git a/include/host/shmem_host_init.h b/include/host/shmem_host_init.h index 803a3e17..651c0e37 100644 --- a/include/host/shmem_host_init.h +++ b/include/host/shmem_host_init.h @@ -131,6 +131,9 @@ SHMEM_HOST_API int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ shmem_init_attr_t **shmem_attr); SHMEM_HOST_API int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid); + +SHMEM_HOST_API int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len); + #ifdef __cplusplus } #endif diff --git a/include/internal/host/shmemi_host_def.h b/include/internal/host/shmemi_host_def.h index 619e315c..012d32ce 100644 --- a/include/internal/host/shmemi_host_def.h +++ b/include/internal/host/shmemi_host_def.h @@ -20,20 +20,20 @@ typedef enum { } addr_type_t; // mf unique id -typedef struct { - union { - struct sockaddr_in addr4; - struct sockaddr_in6 addr6; - } addr; - addr_type_t type; -} shmem_sockaddr_t; +// typedef struct { +// union { +// struct sockaddr_in addr4; +// struct sockaddr_in6 addr6; +// } addr; +// addr_type_t type; +// } shmem_sockaddr_t; -typedef struct { - int32_t version; - int32_t inner_sockFd; - shmem_sockaddr_t addr; - uint64_t magic; -} shmem_uniqueid_inner_t; +// typedef struct { +// int32_t version; +// int32_t inner_sockFd; +// shmem_sockaddr_t addr; +// uint64_t magic; +// } shmem_uniqueid_inner_t; // shmem unique id typedef struct { @@ -46,7 +46,8 @@ typedef struct { } sockaddr_t; typedef struct { - int32_t version; + int32_t version; + int32_t inner_sockFd; // for mf backend sockaddr_t addr; // 动态传入的地址(含端口) uint64_t magic; int rank; diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index fd5a7749..c37bb2d4 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -7,18 +7,32 @@ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ +#include #include "shmemi_init_mf.h" #ifdef BACKEND_MF // smem api -#include -#include -#include -#include -#include -#include -#include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +#include +#include +#include +#include +#include "internal/host/shmemi_host_def.h" +// #include + +constexpr int MIN_PORT = 1024; +constexpr int MAX_PORT = 65536; +constexpr int MAX_ATTEMPTS = 1000; +constexpr int MAX_IFCONFIG_LENGTH = 23; +constexpr int MAX_IP = 48; +constexpr int DEFAULT_IFNAME_LNEGTH = 4; constexpr int DEFAULT_FLAG = 0; constexpr int DEFAULT_ID = 0; @@ -28,7 +42,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; // smem need static smem_shm_t g_smem_handle = nullptr; -static char *g_ipport = nullptr; +static char g_ipport[SHMEM_MAX_IP_PORT_LEN] = {0}; shmemi_init_mf::shmemi_init_mf(shmem_init_attr_t *attr, char *ipport) { @@ -118,13 +132,12 @@ int shmemi_init_mf::reserve_heap(shmemi_device_host_state_t &g_state) g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; } } - if (g_ipport != nullptr) { - delete[] g_ipport; - g_ipport = nullptr; - attributes->ip_port = nullptr; + if (g_ipport[0] != '\0') { + g_ipport[0] = '\0'; + bzero(attributes->ip_port, sizeof(attributes->ip_port)); } else { SHM_LOG_WARN("my_rank:" << attributes->my_rank << " g_ipport is released in advance!"); - attributes->ip_port = nullptr; + bzero(attributes->ip_port, sizeof(attributes->ip_port)); } g_state.is_shmem_created = true; return status; @@ -167,4 +180,368 @@ int shmemi_init_mf::transport_finalize() return SHMEM_SUCCESS; } +int32_t shmem_get_uid_magic(shmemx_bootstrap_uid_state_t *innerUId) +{ + std::ifstream urandom("/dev/urandom", std::ios::binary); + if (!urandom) { + SHM_LOG_ERROR("open random failed"); + return SHMEM_INNER_ERROR; + } + + urandom.read(reinterpret_cast(&innerUId->magic), sizeof(innerUId->magic)); + if (urandom.fail()) { + SHM_LOG_ERROR("read random failed."); + return SHMEM_INNER_ERROR; + } + SHM_LOG_DEBUG("init magic id to " << innerUId->magic); + return SHMEM_SUCCESS; +} + +int32_t bind_tcp_port_v4(int &sockfd, int port, shmemx_bootstrap_uid_state_t *innerUId, char *ip_str) +{ + sockfd = ::socket(AF_INET, SOCK_STREAM, 0); + if (sockfd != -1) { + int on_v4 = 1; + if (::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on_v4, sizeof(on_v4)) == 0) { + innerUId->addr.addr.addr4.sin_port = htons(port); + sockaddr *cur_addr = reinterpret_cast(&innerUId->addr.addr.addr4); + if (::bind(sockfd, cur_addr, sizeof(innerUId->addr.addr.addr4)) == 0) { + SHM_LOG_INFO("bind ipv4 success " << ", fd:" << sockfd << ", " << ip_str << ":" << port); + return 0; + } else { + SHM_LOG_ERROR("bind socket fail:" << errno << "," << ip_str << ":" << port); + } + } else { + SHM_LOG_ERROR("set socket opt fail:" << errno << "," << ip_str << ":" << port); + } + close(sockfd); + sockfd = -1; + } else { + SHM_LOG_ERROR("create socket fail:" << errno << ", " << ip_str << ":" << port); + } + return -1; +} + +int32_t bind_tcp_port_v6(int &sockfd, int port, shmemx_bootstrap_uid_state_t *innerUId, char *ip_str) +{ + sockfd = ::socket(AF_INET6, SOCK_STREAM, 0); + if (sockfd != -1) { + int on_v6 = 1; + if (::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on_v6, sizeof(on_v6)) == 0) { + innerUId->addr.addr.addr6.sin6_port = htons(port); + sockaddr *cur_addr = reinterpret_cast(&innerUId->addr.addr.addr6); + if (::bind(sockfd, cur_addr, sizeof(innerUId->addr.addr.addr6)) == 0) { + SHM_LOG_INFO("bind ipv6 success " << ", fd:" << sockfd << ", " << ip_str << ":" << port); + return 0; + } else { + SHM_LOG_ERROR("bind socket6 fail:" << errno << "," << ip_str << ":" << port); + } + } else { + SHM_LOG_ERROR("set socket6 opt fail:" << errno << "," << ip_str << ":" << port); + } + close(sockfd); + sockfd = -1; + } else { + SHM_LOG_ERROR("create socket6 fail:" << errno << "," << ip_str << ":" << port); + } + return -1; +} + +int32_t shmem_get_port_magic(shmemx_bootstrap_uid_state_t *innerUId, char *ip_str) +{ + static std::random_device rd; + const int min_port = MIN_PORT; + const int max_port = MAX_PORT; + const int max_attempts = MAX_ATTEMPTS; + const int offset_bit = 32; + uint64_t seed = 1; + seed |= static_cast(getpid()) << offset_bit; + seed |= static_cast(std::chrono::system_clock::now().time_since_epoch().count() & 0xFFFFFFFF); + static std::mt19937_64 gen(seed); + std::uniform_int_distribution<> dis(min_port, max_port); + + int sockfd = -1; + int32_t ret; + for (int attempt = 0; attempt < max_attempts; ++attempt) { + int port = dis(gen); + if (innerUId->addr.type == ADDR_IPv4) { + ret = bind_tcp_port_v4(sockfd, port, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } else { + ret = bind_tcp_port_v6(sockfd, port, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } + } + SHM_LOG_ERROR("Not find a available tcp port"); + return -1; +} + +int32_t shmem_using_env_port(shmemx_bootstrap_uid_state_t *innerUId, char *ip_str, uint16_t envPort) +{ + if (envPort < MIN_PORT) { // envPort > MAX_PORT always false + SHM_LOG_ERROR("env port is invalid. " << envPort); + return SHMEM_INVALID_PARAM; + } + + int sockfd = -1; + int32_t ret; + if (innerUId->addr.type == ADDR_IPv4) { + ret = bind_tcp_port_v4(sockfd, envPort, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } else { + ret = bind_tcp_port_v6(sockfd, envPort, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } + SHM_LOG_ERROR("init with env port fialed " << envPort << ", ret=" << ret); + return ret; +} + +int32_t ParseInterfaceWithType(const char *ipInfo, char *IP, sa_family_t &sockType, bool &flag) +{ + const char *delim = ":"; + const char *sep = strchr(ipInfo, delim[0]); + if (sep != nullptr) { + size_t leftLen = sep - ipInfo; + if (leftLen >= MAX_IFCONFIG_LENGTH - 1 || leftLen == 0) { + return SHMEM_INVALID_VALUE; + } + strncpy(IP, ipInfo, leftLen); + IP[leftLen] = '\0'; + sockType = (strcmp(sep + 1, "inet6") != 0) ? AF_INET : AF_INET6; + flag = true; + } + return SHMEM_SUCCESS; +} + +int32_t shmem_auto_get_ip(struct sockaddr *ifaAddr, char *local, sa_family_t &sockType) +{ + sockType = ifaAddr->sa_family; + if (sockType == AF_INET) { + auto localIp = reinterpret_cast(ifaAddr)->sin_addr; + if (inet_ntop(sockType, &localIp, local, MAX_IP) == nullptr) { + SHM_LOG_ERROR("convert local ipv4 to string failed. "); + return SHMEM_INVALID_PARAM; + } + return SHMEM_SUCCESS; + } else if (sockType == AF_INET6) { + auto localIp = reinterpret_cast(ifaAddr)->sin6_addr; + if (inet_ntop(sockType, &localIp, local, MAX_IP) == nullptr) { + SHM_LOG_ERROR("convert local ipv6 to string failed. "); + return SHMEM_INVALID_PARAM; + } + return SHMEM_SUCCESS; + } + return SHMEM_INVALID_PARAM; +} + +bool shmem_check_ifa(struct ifaddrs *ifa, sa_family_t sockType, bool flag, char *ifaName, size_t ifaLen) +{ + if (ifa->ifa_addr == nullptr || ifa->ifa_netmask == nullptr || ifa->ifa_name == nullptr) { + SHM_LOG_DEBUG("loop ifa_addr/ifa_netmask/ifa_name is nullptr"); + return false; + } + + // socket type match and input env ifa valid + if (ifa->ifa_addr->sa_family != sockType && flag) { + SHM_LOG_DEBUG("sa family is not match, get " << ifa->ifa_addr->sa_family << ", expect " << sockType); + return false; + } + + // prefix match with input ifa name + if (strncmp(ifa->ifa_name, ifaName, ifaLen) != 0) { + SHM_LOG_DEBUG("ifa name prefix un-match, get " << ifa->ifa_name << ", expect " << ifaName); + return false; + } + + // ignore ifa which is down or loopback or not running + if ((ifa->ifa_flags & IFF_LOOPBACK) || !(ifa->ifa_flags & IFF_RUNNING) || !(ifa->ifa_flags & IFF_UP)) { + SHM_LOG_DEBUG("ifa flag un-match, flag=" << ifa->ifa_flags); + return false; + } + + if (sockType == AF_INET6) { + struct sockaddr_in6 *sa6 = reinterpret_cast(ifa->ifa_addr); + if (IN6_IS_ADDR_LINKLOCAL(&sa6->sin6_addr)) { + SHM_LOG_DEBUG("ifa is scope link addr " << ifaName); + return false; + } + } + return true; +} + +int32_t shmem_get_ip_from_ifa(char *local, sa_family_t &sockType, const char *ipInfo) +{ + struct ifaddrs *ifaddr; + char ifaName[MAX_IFCONFIG_LENGTH]; + sockType = AF_INET; + bool flag = false; + if (ipInfo == nullptr) { + strncpy(ifaName, "eth", DEFAULT_IFNAME_LNEGTH); + ifaName[DEFAULT_IFNAME_LNEGTH - 1] = '\0'; + SHM_LOG_INFO("use default if to find IP:" << ifaName); + } else if (ParseInterfaceWithType(ipInfo, ifaName, sockType, flag) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("IP size set in SHMEM_CONF_STORE_MASTER_IF format has wrong length"); + return SHMEM_INVALID_PARAM; + } + if (getifaddrs(&ifaddr) == -1) { + SHM_LOG_ERROR("get local net interfaces failed: " << errno); + return SHMEM_INVALID_PARAM; + } + int32_t result = SHMEM_INVALID_PARAM; + for (auto ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { + if (!shmem_check_ifa(ifa, sockType, flag, ifaName, strlen(ifaName))) { + continue; + } + if (sockType == AF_INET && flag) { + auto localIp = reinterpret_cast(ifa->ifa_addr)->sin_addr; + if (inet_ntop(sockType, &localIp, local, 64) == nullptr) { + SHM_LOG_ERROR("convert local ipv4 to string failed. "); + continue; + } + result = SHMEM_SUCCESS; + break; + } else if (sockType == AF_INET6 && flag) { + auto localIp = reinterpret_cast(ifa->ifa_addr)->sin6_addr; + if (inet_ntop(sockType, &localIp, local, 64) == nullptr) { + SHM_LOG_ERROR("convert local ipv6 to string failed. "); + continue; + } + result = SHMEM_SUCCESS; + break; + } else { + auto ret = shmem_auto_get_ip(ifa->ifa_addr, local, sockType); + if (ret != SHMEM_SUCCESS) { + continue; + } + result = SHMEM_SUCCESS; + break; + } + } + freeifaddrs(ifaddr); + return result; +} + +int32_t shmem_get_ip_from_env(char *ip, uint16_t &port, sa_family_t &sockType, const char *ipPort) +{ + if (ipPort != nullptr) { + SHM_LOG_DEBUG("get env SHMEM_UID_SESSION_ID value:" << ipPort); + std::string ipPortStr = ipPort; + + if (ipPort[0] == '[') { + sockType = AF_INET6; + size_t found = ipPortStr.find_last_of(']'); + if (found == std::string::npos || ipPortStr.length() - found <= 1) { + SHM_LOG_ERROR("get env SHMEM_UID_SESSION_ID is invalid"); + return SHMEM_INVALID_PARAM; + } + std::string ipStr = ipPortStr.substr(1, found - 1); + std::string portStr = ipPortStr.substr(found + 2); + + std::snprintf(ip, MAX_IP, "%s", ipStr.c_str()); + + port = std::stoi(portStr); + } else { + sockType = AF_INET; + size_t found = ipPortStr.find_last_of(':'); + if (found == std::string::npos || ipPortStr.length() - found <= 1) { + SHM_LOG_ERROR("get env SHMEM_UID_SESSION_ID is invalid"); + return SHMEM_INVALID_PARAM; + } + std::string ipStr = ipPortStr.substr(0, found); + std::string portStr = ipPortStr.substr(found + 1); + + std::snprintf(ip, MAX_IP, "%s", ipStr.c_str()); + + port = std::stoi(portStr); + } + return SHMEM_SUCCESS; + } + return SHMEM_INVALID_PARAM; +} + +int32_t shmem_set_ip_info(shmemx_uniqueid_t *uid, sa_family_t &sockType, char *pta_env_ip, uint16_t pta_env_port, + bool is_from_ifa) +{ + // init default uid + SHM_ASSERT_RETURN(uid != nullptr, SHMEM_INVALID_PARAM); + *uid = SHMEM_UNIQUEID_INITIALIZER; + shmemx_bootstrap_uid_state_t *innerUID = reinterpret_cast(uid); + if (sockType == AF_INET) { + innerUID->addr.addr.addr4.sin_family = AF_INET; + if (inet_pton(AF_INET, pta_env_ip, &(innerUID->addr.addr.addr4.sin_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv4 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.type = ADDR_IPv4; + } else if (sockType == AF_INET6) { + innerUID->addr.addr.addr6.sin6_family = AF_INET6; + if (inet_pton(AF_INET6, pta_env_ip, &(innerUID->addr.addr.addr6.sin6_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv6 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.type = ADDR_IPv6; + } else { + SHM_LOG_ERROR("IP Type is not IPv4 or IPv6"); + return SHMEM_INVALID_PARAM; + } + + // fill ip port as part of uid + if (is_from_ifa) { + int32_t ret = shmem_get_port_magic(innerUID, pta_env_ip); + if (ret != 0) { + SHM_LOG_ERROR("get available port failed."); + return SHMEM_INVALID_PARAM; + } + } else { + int32_t ret = shmem_using_env_port(innerUID, pta_env_ip, pta_env_port); + if (ret != 0) { + SHM_LOG_ERROR("using env port failed."); + return SHMEM_INVALID_PARAM; + } + } + + SHM_LOG_INFO("gen unique id success."); + return SHMEM_SUCCESS; +} + +int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid) +{ + if (shmem_set_log_level(shm::WARN_LEVEL) != 0) { + SHM_LOG_ERROR("failed to set log level"); + return SHMEM_INNER_ERROR; + } + char pta_env_ip[MAX_IP]; + uint16_t pta_env_port; + sa_family_t sockType; + const char *ipPort = std::getenv("SHMEM_UID_SESSION_ID"); + const char *ipInfo = std::getenv("SHMEM_UID_SOCK_IFNAM"); + bool is_from_ifa = false; + if (ipPort != nullptr) { + if (shmem_get_ip_from_env(pta_env_ip, pta_env_port, sockType, ipPort) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("cant get pta master addr."); + return SHMEM_INVALID_PARAM; + } + } else { + is_from_ifa = true; + if (shmem_get_ip_from_ifa(pta_env_ip, sockType, ipInfo) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("cant get available ip port."); + return SHMEM_INVALID_PARAM; + } + } + SHM_LOG_INFO("get master IP value:" << pta_env_ip); + return shmem_set_ip_info(uid, sockType, pta_env_ip, pta_env_port, is_from_ifa); +} + #endif \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index b2c97dec..6226b1d3 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -16,6 +16,9 @@ #include "init/init_backends/shmemi_init_base.h" #include "shmemi_host_common.h" #include "internal/host_device/shmemi_types.h" +#include +#include +#include class shmemi_init_mf: public shmemi_init_base { public: @@ -40,4 +43,6 @@ private: char *g_ipport = nullptr; }; +int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid); + #endif // SHMEMI_INIT_MF_H \ No newline at end of file diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index 0fbe1f56..ab47665c 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -226,7 +226,32 @@ int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const in g_attr.my_rank = my_rank; g_attr.n_ranks = n_ranks; g_attr.local_mem_size = local_mem_size; - +#ifdef BACKEND_MF + std::string ipPort; + if (uid_args->addr.type == ADDR_IPv6) { + char ipStr[INET6_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET6, &(uid_args->addr.addr.addr6.sin6_addr), ipStr, sizeof(ipStr)) == nullptr) { + SHM_LOG_ERROR("inet_ntop failed for IPv6"); + return SHMEM_INNER_ERROR; + } + uint16_t port = ntohs(uid_args->addr.addr.addr6.sin6_port); + ipPort = "tcp6://[" + std::string(ipStr) + "]:" + std::to_string(port); + } else { + char ipStr[INET_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET, &(uid_args->addr.addr.addr4.sin_addr), ipStr, sizeof(ipStr)) == nullptr) { + SHM_LOG_ERROR("inet_ntop failed for IPv4"); + return SHMEM_INNER_ERROR; + } + uint16_t port = ntohs(uid_args->addr.addr.addr4.sin_port); + ipPort = "tcp://" + std::string(ipStr) + ":" + std::to_string(port); + } + std::copy(ipPort.begin(), ipPort.end(), g_ipport); + std::copy(ipPort.begin(), ipPort.end(), g_attr.ip_port); + g_ipport[ipPort.size()] = '\0'; + g_attr.ip_port[ipPort.size()] = '\0'; + g_attr.option_attr.sockFd = uid_args->inner_sockFd; + SHM_LOG_INFO("extract ip port:" << ipPort); +#endif return SHMEM_SUCCESS; } @@ -241,14 +266,13 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a SHMEM_CHECK_RET(version_compatible()); SHMEM_CHECK_RET(shmemi_options_init()); - // bootstrap init - shmemi_bootstrap_attr_t attr = {}; - SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, attributes)); - // shmem basic init #ifdef BACKEND_MF init_manager = new shmemi_init_mf(attributes, g_ipport); #else + // bootstrap init + // shmemi_bootstrap_attr_t attr = {}; + SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, attributes)); init_manager = new shmemi_init_default(attributes); #endif SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); @@ -296,12 +320,6 @@ void shmem_info_get_name(char *name) name[i] = '\0'; } -int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid) -{ - return SHMEM_SUCCESS; -} - - int32_t shmem_get_uniqueid_default(shmemx_uniqueid_t *uid) { int status = 0; @@ -369,3 +387,39 @@ int32_t shmem_set_log_level(int level) return shm::shm_out_logger::Instance().set_log_level(static_cast(level)); } + +int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len) +{ +#ifdef BACKEND_MF + return smem_set_conf_store_tls(enable, tls_info, tls_info_len); +#else + return SHMEM_SUCCESS; +#endif +} + +void shmem_rank_exit(int status) +{ + SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); + exit(status); +} + +int32_t shmem_set_config_store_tls_key(const char *tls_pk, const uint32_t tls_pk_len, + const char *tls_pk_pw, const uint32_t tls_pk_pw_len, const shmem_decrypt_handler decrypt_handler) +{ +#ifdef BACKEND_MF + return smem_set_config_store_tls_key(tls_pk, tls_pk_len, tls_pk_pw, tls_pk_pw_len, decrypt_handler); +#else + return SHMEM_SUCCESS; +#endif +} + +int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)) +{ +#ifdef BACKEND_MF + SHM_ASSERT_RETURN(func != nullptr, SHMEM_INVALID_PARAM); + shm::shm_out_logger::Instance().set_extern_log_func(func, true); + return smem_set_extern_logger(func); +#else + return SHMEM_SUCCESS; +#endif +} \ No newline at end of file diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index 911dbd68..c6e21136 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -12,7 +12,7 @@ #include #include #include - +#include "mem/shmemi_heap.h" #include "shmemi_host_common.h" #include "internal/host_device/shmemi_types.h" #include "transport/shmemi_transport.h" diff --git a/tests/fuzz/device/mem/shmem_ptr_kernel.cpp b/tests/fuzz/device/mem/shmem_ptr_kernel.cpp index b4812499..7ca9f7ec 100644 --- a/tests/fuzz/device/mem/shmem_ptr_kernel.cpp +++ b/tests/fuzz/device/mem/shmem_ptr_kernel.cpp @@ -14,7 +14,7 @@ public: { gva_gm = (__gm__ int *)gva; int64_t rank = shmem_my_pe(); - int64_t rank_size = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/fuzz/device/team/team_kernel.cpp b/tests/fuzz/device/team/team_kernel.cpp index 8dd5821c..64e9de46 100644 --- a/tests/fuzz/device/team/team_kernel.cpp +++ b/tests/fuzz/device/team/team_kernel.cpp @@ -19,7 +19,7 @@ public: team_idx= team_id; int64_t rank = shmem_my_pe(); - int64_t rank_size = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp b/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp index 29c27274..3f67f4bd 100644 --- a/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp +++ b/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp @@ -18,7 +18,7 @@ constexpr uint64_t MESSAGE_SIZE = 64; { \ shmemx_set_ffts_config(config); \ int64_t rank = shmem_my_pe(); \ - int64_t rank_size = shmem_my_pe(); \ + int64_t rank_size = shmem_n_pes(); \ GM_ADDR dst_addr; \ \ for (int64_t peer = 0; peer < rank_size; peer++) \ diff --git a/tests/unittest/host/init/init_host_test.cpp b/tests/unittest/host/init/init_host_test.cpp index cc557b6d..a524e849 100644 --- a/tests/unittest/host/init/init_host_test.cpp +++ b/tests/unittest/host/init/init_host_test.cpp @@ -29,6 +29,11 @@ void test_shmem_init(int rank_id, int n_ranks, uint64_t local_mem_size) int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); +#ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); +#else shmemx_uniqueid_t uid; if (rank_id == 0) { status = shmemi_get_uniqueid_static_magic(&uid, true); @@ -41,6 +46,7 @@ void test_shmem_init(int rank_id, int n_ranks, uint64_t local_mem_size) &uid, &attributes); EXPECT_EQ(status, SHMEM_SUCCESS); +#endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_SUCCESS); EXPECT_EQ(g_state.mype, rank_id); @@ -68,6 +74,11 @@ void test_shmem_init_invalid_rank_id(int rank_id, int n_ranks, uint64_t local_me int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); +#ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(erank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); +#else shmemx_uniqueid_t uid; if (rank_id == 0) { shmemi_get_uniqueid_static_magic(&uid, true); @@ -78,7 +89,9 @@ void test_shmem_init_invalid_rank_id(int rank_id, int n_ranks, uint64_t local_me shmemx_set_attr_uniqueid_args(erank_id, n_ranks, local_mem_size, &uid, &attributes); +#endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -96,7 +109,12 @@ void test_shmem_init_invalid_n_ranks(int rank_id, int n_ranks, uint64_t local_me int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); shmemx_uniqueid_t uid; +#ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(rank_id, en_ranks, local_mem_size, test_global_ipport, &attributes); +#else if (rank_id == 0) { shmemi_get_uniqueid_static_magic(&uid, true); } else { @@ -106,6 +124,7 @@ void test_shmem_init_invalid_n_ranks(int rank_id, int n_ranks, uint64_t local_me shmemx_set_attr_uniqueid_args(rank_id, en_ranks, local_mem_size, &uid, &attributes); +#endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_INVALID_VALUE); @@ -124,6 +143,11 @@ void test_shmem_init_rank_id_over_size(int rank_id, int n_ranks, uint64_t local_ int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); +#ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(rank_id + n_ranks, n_ranks, local_mem_size, test_global_ipport, &attributes); +#else shmemx_uniqueid_t uid; if (rank_id == 0) { shmemi_get_uniqueid_static_magic(&uid, true); @@ -134,6 +158,7 @@ void test_shmem_init_rank_id_over_size(int rank_id, int n_ranks, uint64_t local_ shmemx_set_attr_uniqueid_args(rank_id + n_ranks, n_ranks, local_mem_size, &uid, &attributes); +#endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_INVALID_PARAM); status = shmem_init_status(); @@ -152,6 +177,11 @@ void test_shmem_init_zero_mem(int rank_id, int n_ranks, uint64_t local_mem_size) int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); +#ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); +#else shmemx_uniqueid_t uid; if (rank_id == 0) { shmemi_get_uniqueid_static_magic(&uid, true); @@ -162,6 +192,7 @@ void test_shmem_init_zero_mem(int rank_id, int n_ranks, uint64_t local_mem_size) shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); +#endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index 002ec244..18634974 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -35,7 +35,11 @@ void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *s EXPECT_EQ(status = aclrtSetDevice(device_id), 0); aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); - + EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); +#ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); +#else shmemx_uniqueid_t uid; if (rank_id == 0) { status = shmemi_get_uniqueid_static_magic(&uid, true); @@ -48,7 +52,9 @@ void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *s &uid, &attributes); EXPECT_EQ(status, SHMEM_SUCCESS); +#endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + EXPECT_EQ(status, 0); *st = stream; } @@ -67,7 +73,11 @@ int32_t test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtS EXPECT_EQ(status = aclrtSetDevice(device_id), 0); aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); - + EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); +#ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); +#else shmemx_uniqueid_t uid; if (rank_id == 0) { status = shmemi_get_uniqueid_static_magic(&uid, true); @@ -81,7 +91,9 @@ int32_t test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtS &attributes); EXPECT_EQ(status, SHMEM_SUCCESS); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; +#endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + EXPECT_EQ(status, 0); *st = stream; return status; diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index f3b4d039..2bddbb66 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -27,6 +27,10 @@ protected: int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + #ifdef BACKEND_MF + shmem_init_attr_t *attributes; + shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); + #else shmemx_uniqueid_t uid; if (rank_id == 0) { status = shmemi_get_uniqueid_static_magic(&uid, true); @@ -39,6 +43,7 @@ protected: &uid, &attributes); EXPECT_EQ(status, SHMEM_SUCCESS); + #endif status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, SHMEM_SUCCESS); EXPECT_EQ(g_state.mype, rank_id); -- Gitee From 0ede9702d7f4f5fd080282863fb0bf1ff01fcf75 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Sat, 29 Nov 2025 15:49:46 +0800 Subject: [PATCH 61/74] utfix and mf adapt --- CMakeLists.txt | 2 +- include/internal/host/shmemi_host_def.h | 16 ---------------- .../init/init_backends/mf/shmemi_init_mf.cpp | 11 +---------- src/host/init/init_backends/mf/shmemi_init_mf.h | 3 +++ src/host/init/shmem_init.cpp | 1 - 5 files changed, 5 insertions(+), 28 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 69ac48cd..4dabc368 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,7 +94,7 @@ link_directories( link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread) # MF_BACKEND -set(USE_MF "1") +set(USE_MF "0") if ("${USE_MF}" STREQUAL "1") add_compile_definitions(BACKEND_MF=1) diff --git a/include/internal/host/shmemi_host_def.h b/include/internal/host/shmemi_host_def.h index 012d32ce..05922d22 100644 --- a/include/internal/host/shmemi_host_def.h +++ b/include/internal/host/shmemi_host_def.h @@ -19,22 +19,6 @@ typedef enum { ADDR_IPv6 } addr_type_t; -// mf unique id -// typedef struct { -// union { -// struct sockaddr_in addr4; -// struct sockaddr_in6 addr6; -// } addr; -// addr_type_t type; -// } shmem_sockaddr_t; - -// typedef struct { -// int32_t version; -// int32_t inner_sockFd; -// shmem_sockaddr_t addr; -// uint64_t magic; -// } shmem_uniqueid_inner_t; - // shmem unique id typedef struct { union { diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index c37bb2d4..26073e55 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -11,21 +11,12 @@ #include "shmemi_init_mf.h" #ifdef BACKEND_MF - -// smem api -// #include -// #include -// #include -// #include -// #include -// #include -// #include #include #include #include #include #include "internal/host/shmemi_host_def.h" -// #include + constexpr int MIN_PORT = 1024; constexpr int MAX_PORT = 65536; diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index 6226b1d3..7fca3049 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -16,9 +16,12 @@ #include "init/init_backends/shmemi_init_base.h" #include "shmemi_host_common.h" #include "internal/host_device/shmemi_types.h" +#ifdef BACKEND_MF +// smem api #include #include #include +#endif class shmemi_init_mf: public shmemi_init_base { public: diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index ab47665c..07df5494 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -271,7 +271,6 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a init_manager = new shmemi_init_mf(attributes, g_ipport); #else // bootstrap init - // shmemi_bootstrap_attr_t attr = {}; SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, attributes)); init_manager = new shmemi_init_default(attributes); #endif -- Gitee From 045fb1c755c9d6cadf4223fa1ec52f69e2e76e45 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Sat, 29 Nov 2025 16:26:21 +0800 Subject: [PATCH 62/74] examples support 14/14 3.0 default backend --- examples/CMakeLists.txt | 6 +- examples/rdma_demo/main.cpp | 23 ++++-- .../unuse_handlewait/main.cpp | 32 +++++--- .../use_handlewait/main.cpp | 31 ++++--- examples/rdma_perftest/main.cpp | 80 ++++++++++--------- 5 files changed, 103 insertions(+), 69 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e72c2383..2c751d65 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -72,10 +72,6 @@ foreach(EXAMPLE matmul_reduce_scatter matmul_reduce_scatter_padding dynamic_tiling - # rdma_perftest - # rdma_demo - # rdma_handlewait_test/unuse_handlewait - # rdma_handlewait_test/use_handlewait ) add_subdirectory(${EXAMPLE}) endforeach() @@ -84,6 +80,8 @@ if(SHMEM_RDMA_SUPPORT) foreach(EXAMPLE rdma_perftest rdma_demo + rdma_handlewait_test/unuse_handlewait + rdma_handlewait_test/use_handlewait ) add_subdirectory(${EXAMPLE}) endforeach() diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index b34d26bb..5b9e9b5d 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -36,10 +36,18 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); uint8_t *ptr = static_cast(shmem_malloc(1024)); @@ -90,21 +98,20 @@ int main(int argc, char *argv[]) { int argIdx = 1; int status = 0; + // MPI Init MPI_Init(&argc, &argv); - - // 获取当前进程的编号(rank) - int n_ranks; - MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); - int rank_id; + int rank_id, n_ranks; MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + ipport = argv[1]; g_npus = atoi(argv[2]); f_rank = atoi(argv[3]); f_npu = atoi(argv[4]); uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; MPI_Finalize(); - return 0; } \ No newline at end of file diff --git a/examples/rdma_handlewait_test/unuse_handlewait/main.cpp b/examples/rdma_handlewait_test/unuse_handlewait/main.cpp index 3f31a420..10ce0a30 100644 --- a/examples/rdma_handlewait_test/unuse_handlewait/main.cpp +++ b/examples/rdma_handlewait_test/unuse_handlewait/main.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "acl/acl.h" #include "shmem_api.h" @@ -39,11 +40,18 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); uint8_t *ptr = static_cast(shmem_malloc(mem_size)); uint8_t *ptr_A = ptr + half_mem_size; @@ -97,14 +105,18 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size int main(int argc, char *argv[]) { - int argIdx = 1; int status = 0; - int n_ranks = atoi(argv[argIdx++]); - int rank_id = atoi(argv[argIdx++]); - ipport = argv[argIdx++]; - g_npus = atoi(argv[argIdx++]); - f_rank = atoi(argv[argIdx++]); - f_npu = atoi(argv[argIdx++]); + + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + ipport = argv[1]; + g_npus = atoi(argv[2]); + f_rank = atoi(argv[3]); + f_npu = atoi(argv[4]); uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; diff --git a/examples/rdma_handlewait_test/use_handlewait/main.cpp b/examples/rdma_handlewait_test/use_handlewait/main.cpp index 81766bf4..6e2c6ec3 100644 --- a/examples/rdma_handlewait_test/use_handlewait/main.cpp +++ b/examples/rdma_handlewait_test/use_handlewait/main.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "acl/acl.h" #include "shmem_api.h" @@ -39,11 +40,18 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); uint8_t *ptr = static_cast(shmem_malloc(mem_size)); uint8_t *ptr_A = ptr + half_mem_size; @@ -101,14 +109,17 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size int main(int argc, char *argv[]) { - int argIdx = 1; int status = 0; - int n_ranks = atoi(argv[argIdx++]); - int rank_id = atoi(argv[argIdx++]); - ipport = argv[argIdx++]; - g_npus = atoi(argv[argIdx++]); - f_rank = atoi(argv[argIdx++]); - f_npu = atoi(argv[argIdx++]); + // MPI Init + MPI_Init(&argc, &argv); + int rank_id, n_ranks; + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); + MPI_Comm_size(MPI_COMM_WORLD, &n_ranks); + + ipport = argv[1]; + g_npus = atoi(argv[2]); + f_rank = atoi(argv[3]); + f_npu = atoi(argv[4]); uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp index b86cbbd5..2ed26be1 100644 --- a/examples/rdma_perftest/main.cpp +++ b/examples/rdma_perftest/main.cpp @@ -32,7 +32,7 @@ extern void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fft extern void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t cfg, uint8_t* gva, int len); extern void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t cfg, uint8_t* gva, int len, int64_t iter); -int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uint64_t mem_size, int message_length) +int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length) { uint32_t iteration = 1; int32_t device_id = rank_id % g_npus + f_npu; @@ -47,10 +47,18 @@ int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uin status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, mem_size, ipport, &attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t *gva = static_cast(shmem_malloc(size6M)); @@ -99,10 +107,18 @@ int test_shmem_rdma_postsend_cost(int rank_id, int n_ranks, uint64_t local_mem_s status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t *gva = static_cast(shmem_malloc(size6M)); @@ -148,10 +164,18 @@ int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_me status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t *gva = static_cast(shmem_malloc(size6M)); @@ -194,11 +218,19 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + // Shmem uid pre-init + shmemx_uniqueid_t uid; + if (rank_id == 0) { + shmem_get_uniqueid(&uid); + } + std::cout << "MPI_Bcast!" << std::endl; + MPI_Bcast(&uid, sizeof(shmemx_uniqueid_t), MPI_UINT8_T, 0, MPI_COMM_WORLD); + + shmem_init_attr_t *attributes; + status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, &uid, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + shmem_mte_set_ub_params(0, 128 * 1024, 0); uint64_t fftsConfig = shmemx_get_ffts_config(); @@ -254,23 +286,13 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size int main(int argc, char *argv[]) { -<<<<<<< HEAD - const int expected_argc = 9; - if (argc != expected_argc) { -======= if (argc != 7) { ->>>>>>> 030043274ef343b23b19afb659ffa26346e1f9b7 std::cout << "[ERROR] Paramater number mismatch." << std::endl; std::cout << "[USAGE] ./rdma_perftest " << " . See README for more details." << std::endl; } int sub = 1; int status = 0; -<<<<<<< HEAD - int n_ranks = atoi(argv[sub++]); - const int rank_max = 2; - if (n_ranks != rank_max) { -======= MPI_Init(&argc, &argv); // 获取当前进程的编号(rank) @@ -279,34 +301,19 @@ int main(int argc, char *argv[]) int rank_id; MPI_Comm_rank(MPI_COMM_WORLD, &rank_id); if (n_ranks != 2) { ->>>>>>> 030043274ef343b23b19afb659ffa26346e1f9b7 std::cout << "[ERROR] Error number of ranks! Only support 2 ranks!" << std::endl; return -1; } -<<<<<<< HEAD - int rank_id = atoi(argv[sub++]); - if (rank_id >= rank_max) { -======= if (rank_id >= 2) { ->>>>>>> 030043274ef343b23b19afb659ffa26346e1f9b7 std::cout << "[ERROR] Error rank ID! Only support 2 ranks!" << std::endl; return -1; } -<<<<<<< HEAD - ipport = argv[sub++]; - g_npus = atoi(argv[sub++]); - f_rank = atoi(argv[sub++]); - f_npu = atoi(argv[sub++]); - test_type = argv[sub++]; - int msg_len = atoi(argv[sub++]); -======= ipport = argv[1]; g_npus = atoi(argv[2]); f_rank = atoi(argv[3]); f_npu = atoi(argv[4]); test_type = argv[5]; int msg_len = atoi(argv[6]); ->>>>>>> 030043274ef343b23b19afb659ffa26346e1f9b7 uint64_t local_mem_size = 1024UL * 1024UL * 64; if (std::string(test_type) == "highlevel_put_pingpong_latency") { test_shmem_rdma_highlevel_put_pingpong_latency(rank_id, n_ranks, local_mem_size, msg_len); @@ -320,6 +327,5 @@ int main(int argc, char *argv[]) std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; MPI_Finalize(); - return 0; } \ No newline at end of file -- Gitee From 00129f4fa391eeab8ef91536bf56c515e569f891 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Mon, 1 Dec 2025 11:53:24 +0800 Subject: [PATCH 63/74] mf uid init adapt part 1 --- src/host/init/init_backends/mf/shmemi_init_mf.cpp | 13 ++++++++++++- src/host/init/init_backends/mf/shmemi_init_mf.h | 1 + src/host/init/shmem_init.cpp | 8 ++++++++ tests/unittest/host/main_test.cpp | 2 +- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index 26073e55..e92a5172 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -65,7 +65,7 @@ int shmemi_init_mf::init_device_state() SHM_LOG_ERROR("smem_shm_config_init Failed"); return SHMEM_SMEM_ERROR; } - + config.sockFd = attributes->option_attr.sockFd; status = smem_shm_init(attributes->ip_port, attributes->n_ranks, attributes->my_rank, device_id, &config); if (status != SHMEM_SUCCESS) { SHM_LOG_ERROR("smem_shm_init Failed"); @@ -535,4 +535,15 @@ int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid) return shmem_set_ip_info(uid, sockType, pta_env_ip, pta_env_port, is_from_ifa); } +int32_t shmemi_control_barrier_all_mf() +{ + SHM_ASSERT_RETURN(g_smem_handle != nullptr, SHMEM_INVALID_PARAM); + auto ret = smem_shm_control_barrier(g_smem_handle); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("Barrier failed"); + return ret; + } + return SHMEM_SUCCESS; +} + #endif \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index 7fca3049..4ddb5c16 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -47,5 +47,6 @@ private: }; int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid); +int32_t shmemi_control_barrier_all_mf(); #endif // SHMEMI_INIT_MF_H \ No newline at end of file diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index 07df5494..bb0899e6 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -144,7 +144,11 @@ shmemi_init_base* init_manager; int32_t shmemi_control_barrier_all() { +#ifdef BACKEND_MF + return shmemi_control_barrier_all_mf(); +#else return g_boot_handle.barrier(&g_boot_handle); +#endif } int32_t update_device_state() @@ -249,9 +253,13 @@ int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const in std::copy(ipPort.begin(), ipPort.end(), g_attr.ip_port); g_ipport[ipPort.size()] = '\0'; g_attr.ip_port[ipPort.size()] = '\0'; + int attr_version = static_cast((1 << 16) + sizeof(shmem_init_attr_t)); + g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, DEFAULT_TIMEOUT, + DEFAULT_TIMEOUT, DEFAULT_TIMEOUT, 0}; g_attr.option_attr.sockFd = uid_args->inner_sockFd; SHM_LOG_INFO("extract ip port:" << ipPort); #endif + g_attr_init = true; return SHMEM_SUCCESS; } diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index 18634974..dd8b488c 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -90,8 +90,8 @@ int32_t test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtS &uid, &attributes); EXPECT_EQ(status, SHMEM_SUCCESS); - attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; #endif + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); EXPECT_EQ(status, 0); -- Gitee From 05ca4aa99c3f9aba5997d91be45486b21d1e6548 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Mon, 1 Dec 2025 21:08:03 +0800 Subject: [PATCH 64/74] MF && Default Backend Examples OK(Besides Rdma) --- examples/allgather/run.sh | 2 + examples/allgather_matmul/scripts/run.sh | 4 +- .../allgather_matmul_padding/scripts/run.sh | 4 +- .../scripts/run.sh | 4 +- examples/dispatch_gmm_combine/scripts/run.sh | 5 +- examples/dynamic_tiling/scripts/run.sh | 10 +-- examples/kv_shuffle/scripts/run.sh | 6 +- examples/matmul_allreduce/scripts/run.sh | 4 +- examples/matmul_reduce_scatter/scripts/run.sh | 4 +- .../scripts/run.sh | 4 +- .../low_level/shmem_device_low_level_rma.h | 4 +- include/internal/host_device/shmemi_types.h | 19 ++--- .../default/shmemi_init_default.cpp | 59 +++++++++++----- .../default/shmemi_init_default.h | 9 +-- .../init/init_backends/mf/shmemi_init_mf.cpp | 70 ++++++++++++++----- .../init/init_backends/mf/shmemi_init_mf.h | 9 +-- .../init/init_backends/shmemi_init_base.h | 7 +- src/host/init/shmem_init.cpp | 30 +++++--- src/host/mem/shmem_rma.cpp | 4 +- src/host/transport/shmemi_transport.cpp | 58 +++++++-------- src/host/transport/shmemi_transport.h | 8 +-- src/modules/transport/shmemi_rdma.cpp | 2 +- tests/unittest/CMakeLists.txt | 2 +- 23 files changed, 187 insertions(+), 141 deletions(-) diff --git a/examples/allgather/run.sh b/examples/allgather/run.sh index 58212798..ae984841 100644 --- a/examples/allgather/run.sh +++ b/examples/allgather/run.sh @@ -95,6 +95,8 @@ python3 ./scripts/data_gen.py $RANK_SIZE $TEST_TYPE # Kernel test rm -rf ./output + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${GNPU_NUM} msprof --application="${PROJECT_ROOT}/build/bin/allgather" --output="${PROJECT_ROOT}/examples/allgather/output/" # Profiling data statistic diff --git a/examples/allgather_matmul/scripts/run.sh b/examples/allgather_matmul/scripts/run.sh index 79d2c769..5b4674c5 100644 --- a/examples/allgather_matmul/scripts/run.sh +++ b/examples/allgather_matmul/scripts/run.sh @@ -40,9 +40,7 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit diff --git a/examples/allgather_matmul_padding/scripts/run.sh b/examples/allgather_matmul_padding/scripts/run.sh index 1d8b9541..71150f16 100644 --- a/examples/allgather_matmul_padding/scripts/run.sh +++ b/examples/allgather_matmul_padding/scripts/run.sh @@ -40,9 +40,7 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit diff --git a/examples/allgather_matmul_with_gather_result/scripts/run.sh b/examples/allgather_matmul_with_gather_result/scripts/run.sh index 6a9d8543..95eaebad 100644 --- a/examples/allgather_matmul_with_gather_result/scripts/run.sh +++ b/examples/allgather_matmul_with_gather_result/scripts/run.sh @@ -40,9 +40,7 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit diff --git a/examples/dispatch_gmm_combine/scripts/run.sh b/examples/dispatch_gmm_combine/scripts/run.sh index 2274a287..d2d2ec73 100644 --- a/examples/dispatch_gmm_combine/scripts/run.sh +++ b/examples/dispatch_gmm_combine/scripts/run.sh @@ -115,10 +115,7 @@ EXEC_BIN=${PROJECT_ROOT}/build/bin/dispatch_gmm_combine cd ${PROJECT_ROOT}/examples/dispatch_gmm_combine/ echo "Test Case, M: ${M}, K: ${K}, N: ${N}, expertPerRank: ${expertPerRank}" -# export LD_LIBRARY_PATH=${PROJECT_ROOT}/install/shmem/lib:${ASCEND_HOME_PATH}/lib64:${PROJECT_ROOT}/install/memfabric_hybrid/lib:$LD_LIBRARY_PATH -# for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do -# INPUT_PATH=${CURRENT_DIR}/utils/test_data/ ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$FIRST_NPU" "$M" "$K" "$N" "$expertPerRank" "$dataType" "$weightNz" "$transB" & -# done +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 export INPUT_PATH=${CURRENT_DIR}/utils/test_data/ mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$FIRST_NPU" "$M" "$K" "$N" "$expertPerRank" "$dataType" "$weightNz" "$transB" diff --git a/examples/dynamic_tiling/scripts/run.sh b/examples/dynamic_tiling/scripts/run.sh index 554bbc62..70d21eba 100644 --- a/examples/dynamic_tiling/scripts/run.sh +++ b/examples/dynamic_tiling/scripts/run.sh @@ -82,10 +82,7 @@ if [ "$TEST_TYPE" = "0" ]; then IPPORT="tcp://127.0.0.1:27008" # Start Process - # for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR $DATA_PATH" - # ${APP}& - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} $EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE 0 $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR $DATA_PATH # Wait until all process exit @@ -118,10 +115,7 @@ else OUTPUT_PATH="./output/msprof/start_line${IDX}_run_rows${TEST_COLLECT_ROWS}/" # Start Process - # for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR" - # msprof --application="${APP}" --output="${OUTPUT_PATH}"& - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE 0 $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR" mpirun -np ${RANK_SIZE} msprof --application="${APP}" --output="${OUTPUT_PATH}" diff --git a/examples/kv_shuffle/scripts/run.sh b/examples/kv_shuffle/scripts/run.sh index 32a002aa..73cd8808 100644 --- a/examples/kv_shuffle/scripts/run.sh +++ b/examples/kv_shuffle/scripts/run.sh @@ -23,13 +23,9 @@ rm -rf scripts/output/*.bin python3 scripts/golden.py $RANK_SIZE # Start Process +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} ${PROJECT_ROOT}/build/bin/kv_shuffle -# for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do -# APP="$EXEC_BIN $RANK_SIZE $idx $IPPORT" -# ${APP}& -# done - # Wait until all process exit wait diff --git a/examples/matmul_allreduce/scripts/run.sh b/examples/matmul_allreduce/scripts/run.sh index e4f3f6cc..7248520d 100644 --- a/examples/matmul_allreduce/scripts/run.sh +++ b/examples/matmul_allreduce/scripts/run.sh @@ -39,9 +39,7 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit diff --git a/examples/matmul_reduce_scatter/scripts/run.sh b/examples/matmul_reduce_scatter/scripts/run.sh index a64af5cf..ad4009a2 100644 --- a/examples/matmul_reduce_scatter/scripts/run.sh +++ b/examples/matmul_reduce_scatter/scripts/run.sh @@ -39,9 +39,7 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit diff --git a/examples/matmul_reduce_scatter_padding/scripts/run.sh b/examples/matmul_reduce_scatter_padding/scripts/run.sh index 30871b22..cb06d74d 100644 --- a/examples/matmul_reduce_scatter_padding/scripts/run.sh +++ b/examples/matmul_reduce_scatter_padding/scripts/run.sh @@ -39,9 +39,7 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do IPPORT="tcp://127.0.0.1:8788" # Start Process - # for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do - # ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & - # done + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 mpirun -np ${RANK_SIZE} ${EXEC_BIN} "$RANK_SIZE" 0 "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" # Wait until all process exit diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index 92afdbce..f6abc8da 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -34,7 +34,7 @@ SHMEM_DEVICE __gm__ void *shmem_ptr(__gm__ void *ptr, int pe) uint64_t offset = reinterpret_cast(ptr) - reinterpret_cast(device_state->heap_base); // Address translate - uint64_t remote_ptr = reinterpret_cast(device_state->p2p_heap_base[pe]) + offset; + uint64_t remote_ptr = reinterpret_cast(device_state->device_p2p_heap_base[pe]) + offset; return reinterpret_cast<__gm__ void *>(remote_ptr); } @@ -209,7 +209,7 @@ SHMEM_DEVICE __gm__ void *shmem_roce_ptr(__gm__ void *ptr, int pe) uint64_t offset = reinterpret_cast(ptr) - reinterpret_cast(device_state->heap_base); // Address translate - uint64_t remote_ptr = reinterpret_cast(device_state->rdma_heap_base[pe]) + offset; + uint64_t remote_ptr = reinterpret_cast(device_state->device_rdma_heap_base[pe]) + offset; return reinterpret_cast<__gm__ void *>(remote_ptr); } diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index 20568911..2a5f60bb 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -78,15 +78,16 @@ typedef struct { int npes; void *heap_base; - void **p2p_heap_host_base; - void **sdma_heap_host_base; - void **roce_heap_host_base; - void **p2p_heap_device_base; - void **sdma_heap_device_base; - void **roce_heap_device_base; - void *p2p_heap_base[SHMEM_MAX_RANKS]; - void *rdma_heap_base[SHMEM_MAX_RANKS]; - void *sdma_heap_base[SHMEM_MAX_RANKS]; + // Store All Devices' heap_base in Host. + void **host_p2p_heap_base; + void **host_rdma_heap_base; + void **host_sdma_heap_base; + + // Store All Devices' heap_base in Device. + void **device_p2p_heap_base; + void **device_rdma_heap_base; + void **device_sdma_heap_base; + uint8_t topo_list[SHMEM_MAX_RANKS]; size_t heap_size; diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index 2d92ebb6..b92fc580 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -10,11 +10,13 @@ #include "shmemi_init_default.h" #include "common/shmemi_logger.h" -shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr) +shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr, shmemi_device_host_state_t *global_state) { mype = attr->my_rank; npes = attr->n_ranks; option_attr_ = attr->option_attr; + g_state = global_state; + auto status = aclrtGetDevice(&device_id); if (status != 0) { SHM_LOG_ERROR("Get Device_id error"); @@ -22,12 +24,7 @@ shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr) } shmemi_init_default::~shmemi_init_default() -{ - finalize_device_state(); - remove_heap(); - release_heap(); - transport_finalize(); -} +{} int shmemi_init_default::init_device_state() { @@ -46,29 +43,41 @@ int shmemi_init_default::finalize_device_state() int shmemi_init_default::update_device_state(void* host_ptr, size_t size) { + int32_t ptr_size = npes * sizeof(void *); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_p2p_heap_base, ptr_size, g_state->host_p2p_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_rdma_heap_base, ptr_size, g_state->host_rdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_sdma_heap_base, ptr_size, g_state->host_sdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(global_state_d->get_ptr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE)); return SHMEM_SUCCESS; } -int shmemi_init_default::reserve_heap(shmemi_device_host_state_t &g_state) +int shmemi_init_default::reserve_heap() { heap_obj = new shmem_symmetric_heap(mype, npes, device_id); - SHMEM_CHECK_RET(heap_obj->reserve_heap(g_state.heap_size)); + SHMEM_CHECK_RET(heap_obj->reserve_heap(g_state->heap_size)); + + g_state->heap_base = heap_obj->get_heap_base(); - g_state.heap_base = heap_obj->get_heap_base(); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_p2p_heap_base, npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_rdma_heap_base, npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_sdma_heap_base, npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_p2p_heap_base, npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_rdma_heap_base, npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_sdma_heap_base, npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); return SHMEM_SUCCESS; } -int shmemi_init_default::setup_heap(shmemi_device_host_state_t &g_state) +int shmemi_init_default::setup_heap() { SHMEM_CHECK_RET(heap_obj->setup_heap()); - for (int32_t i = 0; i < g_state.npes; i++) { - g_state.p2p_heap_base[i] = heap_obj->get_peer_heap_base_p2p(i); + for (int32_t i = 0; i < g_state->npes; i++) { + g_state->host_p2p_heap_base[i] = heap_obj->get_peer_heap_base_p2p(i); } - g_state.is_shmem_created = true; + g_state->is_shmem_created = true; return SHMEM_SUCCESS; } @@ -81,11 +90,29 @@ int shmemi_init_default::remove_heap() int shmemi_init_default::release_heap() { + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_p2p_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_rdma_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_sdma_heap_base)); + } + if (g_state->device_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_p2p_heap_base)); + } + if (g_state->device_rdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_rdma_heap_base)); + } + if (g_state->device_sdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_sdma_heap_base)); + } SHMEM_CHECK_RET(heap_obj->unreserve_heap()); return SHMEM_SUCCESS; } -int shmemi_init_default::transport_init(shmemi_device_host_state_t &g_state) +int shmemi_init_default::transport_init() { SHMEM_CHECK_RET(shmemi_transport_init(g_state, option_attr_)); // mte init && rdma init SHMEM_CHECK_RET(shmemi_build_transport_map(g_state)); // build transport_map @@ -95,6 +122,6 @@ int shmemi_init_default::transport_init(shmemi_device_host_state_t &g_state) int shmemi_init_default::transport_finalize() { - SHMEM_CHECK_RET(shmemi_transport_finalize()); + SHMEM_CHECK_RET(shmemi_transport_finalize(g_state)); return SHMEM_SUCCESS; } \ No newline at end of file diff --git a/src/host/init/init_backends/default/shmemi_init_default.h b/src/host/init/init_backends/default/shmemi_init_default.h index d4a2f72c..07fadb86 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.h +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -27,24 +27,25 @@ class shmemi_init_default: public shmemi_init_base { public: - shmemi_init_default(shmem_init_attr_t *attr); + shmemi_init_default(shmem_init_attr_t *attr, shmemi_device_host_state_t *global_state); ~shmemi_init_default(); int init_device_state() override; int finalize_device_state() override; int update_device_state(void* host_ptr, size_t size) override; - int reserve_heap(shmemi_device_host_state_t &g_state) override; - int setup_heap(shmemi_device_host_state_t &g_state) override; + int reserve_heap() override; + int setup_heap() override; int remove_heap() override; int release_heap() override; - int transport_init(shmemi_device_host_state_t &g_state) override; + int transport_init() override; int transport_finalize() override; private: int mype; int npes; int device_id; + shmemi_device_host_state_t *g_state; // global_state global_state_reigister *global_state_d = nullptr; diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index e92a5172..aa30eccb 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -35,10 +35,11 @@ constexpr int DEFAULT_BLOCK_NUM = 1; static smem_shm_t g_smem_handle = nullptr; static char g_ipport[SHMEM_MAX_IP_PORT_LEN] = {0}; -shmemi_init_mf::shmemi_init_mf(shmem_init_attr_t *attr, char *ipport) +shmemi_init_mf::shmemi_init_mf(shmem_init_attr_t *attr, char *ipport, shmemi_device_host_state_t *global_state) { attributes = attr; g_ipport = ipport; + g_state = global_state; aclrtGetDevice(&device_id); smem_set_conf_store_tls(false, nullptr, 0); @@ -50,11 +51,7 @@ shmemi_init_mf::shmemi_init_mf(shmem_init_attr_t *attr, char *ipport) } shmemi_init_mf::~shmemi_init_mf() -{ - finalize_device_state(); - remove_heap(); - release_heap(); -} +{} int shmemi_init_mf::init_device_state() { @@ -85,7 +82,13 @@ int shmemi_init_mf::update_device_state(void* host_ptr, size_t size) SHM_LOG_ERROR("smem_shm_create Not Success, update_device_state Failed"); return SHMEM_SMEM_ERROR; } - return smem_shm_set_extra_context(g_smem_handle, host_ptr, size); + int32_t ptr_size = g_state->npes * sizeof(void *); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_p2p_heap_base, ptr_size, g_state->host_p2p_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_rdma_heap_base, ptr_size, g_state->host_rdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_sdma_heap_base, ptr_size, g_state->host_sdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + + SHMEM_CHECK_RET(smem_shm_set_extra_context(g_smem_handle, host_ptr, size)); + return SHMEM_SUCCESS; } int shmemi_init_mf::finalize_device_state() @@ -94,11 +97,12 @@ int shmemi_init_mf::finalize_device_state() return SHMEM_SUCCESS; } -int shmemi_init_mf::reserve_heap(shmemi_device_host_state_t &g_state) +int shmemi_init_mf::reserve_heap() { int32_t status = SHMEM_SUCCESS; void *gva = nullptr; - g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state.heap_size, + + g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state->heap_size, static_cast(attributes->option_attr.data_op_engine_type), DEFAULT_FLAG, &gva); @@ -106,21 +110,30 @@ int shmemi_init_mf::reserve_heap(shmemi_device_host_state_t &g_state) SHM_LOG_ERROR("smem_shm_create Failed"); return SHMEM_SMEM_ERROR; } - g_state.heap_base = (void *)((uintptr_t)gva + g_state.heap_size * attributes->my_rank); + g_state->heap_base = (void *)((uintptr_t)gva + g_state->heap_size * attributes->my_rank); + + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_p2p_heap_base, g_state->npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_rdma_heap_base, g_state->npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_sdma_heap_base, g_state->npes * sizeof(void *))); + + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_p2p_heap_base, g_state->npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_rdma_heap_base, g_state->npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_sdma_heap_base, g_state->npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + uint32_t reach_info = 0; - for (int32_t i = 0; i < g_state.npes; i++) { + for (int32_t i = 0; i < g_state->npes; i++) { status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); - g_state.p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); + g_state->host_p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state->heap_size * i); if (reach_info & SMEMS_DATA_OP_MTE) { - g_state.topo_list[i] |= SHMEM_TRANSPORT_MTE; + g_state->topo_list[i] |= SHMEM_TRANSPORT_MTE; } if (reach_info & SMEMS_DATA_OP_SDMA) { - g_state.sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * i); + g_state->host_sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state->heap_size * i); } else { - g_state.sdma_heap_base[i] = NULL; + g_state->host_sdma_heap_base[i] = NULL; } if (reach_info & SMEMS_DATA_OP_RDMA) { - g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; + g_state->topo_list[i] |= SHMEM_TRANSPORT_ROCE; } } if (g_ipport[0] != '\0') { @@ -130,11 +143,11 @@ int shmemi_init_mf::reserve_heap(shmemi_device_host_state_t &g_state) SHM_LOG_WARN("my_rank:" << attributes->my_rank << " g_ipport is released in advance!"); bzero(attributes->ip_port, sizeof(attributes->ip_port)); } - g_state.is_shmem_created = true; + g_state->is_shmem_created = true; return status; } -int shmemi_init_mf::setup_heap(shmemi_device_host_state_t &g_state) +int shmemi_init_mf::setup_heap() { int32_t status = SHMEM_SUCCESS; return status; @@ -148,6 +161,25 @@ int shmemi_init_mf::remove_heap() int shmemi_init_mf::release_heap() { + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_p2p_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_rdma_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_sdma_heap_base)); + } + if (g_state->device_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_p2p_heap_base)); + } + if (g_state->device_rdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_rdma_heap_base)); + } + if (g_state->device_sdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_sdma_heap_base)); + } + if (g_smem_handle != nullptr) { int32_t status = smem_shm_destroy(g_smem_handle, 0); if (status != SHMEM_SUCCESS) { @@ -161,7 +193,7 @@ int shmemi_init_mf::release_heap() return SHMEM_SUCCESS; } -int shmemi_init_mf::transport_init(shmemi_device_host_state_t &g_state) +int shmemi_init_mf::transport_init() { return SHMEM_SUCCESS; } diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index 4ddb5c16..830ea943 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -25,25 +25,26 @@ class shmemi_init_mf: public shmemi_init_base { public: - shmemi_init_mf(shmem_init_attr_t *attr, char *ipport); + shmemi_init_mf(shmem_init_attr_t *attr, char *ipport, shmemi_device_host_state_t *g_state); ~shmemi_init_mf(); int init_device_state() override; int finalize_device_state() override; int update_device_state(void* host_ptr, size_t size) override; - int reserve_heap(shmemi_device_host_state_t &g_state) override; - int setup_heap(shmemi_device_host_state_t &g_state) override; + int reserve_heap() override; + int setup_heap() override; int remove_heap() override; int release_heap() override; - int transport_init(shmemi_device_host_state_t &g_state) override; + int transport_init() override; int transport_finalize() override; private: int32_t device_id; shmem_init_attr_t *attributes; char *g_ipport = nullptr; + shmemi_device_host_state_t *g_state; }; int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid); diff --git a/src/host/init/init_backends/shmemi_init_base.h b/src/host/init/init_backends/shmemi_init_base.h index 4b8cdbb7..f1599cb8 100644 --- a/src/host/init/init_backends/shmemi_init_base.h +++ b/src/host/init/init_backends/shmemi_init_base.h @@ -21,16 +21,15 @@ public: virtual int finalize_device_state() = 0; virtual int update_device_state(void* host_ptr, size_t size) = 0; - virtual int reserve_heap(shmemi_device_host_state_t &g_state) = 0; - virtual int setup_heap(shmemi_device_host_state_t &g_state) = 0; + virtual int reserve_heap() = 0; + virtual int setup_heap() = 0; virtual int remove_heap() = 0; virtual int release_heap() = 0; - virtual int transport_init(shmemi_device_host_state_t &g_state) = 0; + virtual int transport_init() = 0; virtual int transport_finalize() = 0; virtual ~shmemi_init_base() {} - }; #endif // SHMEMI_INIT_BASE_H \ No newline at end of file diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index bb0899e6..b624ff93 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -39,15 +39,12 @@ constexpr int DEFAULT_BLOCK_NUM = 1; (DEFAULT_MY_PE), /* mype */ \ (DEFAULT_N_PES), /* npes */ \ NULL, /* heap_base */ \ - NULL, /* p2p_heap_host_base */ \ - NULL, /* sdma_heap_host_base */ \ - NULL, /* roce_heap_host_base */ \ + NULL, /* host_p2p_heap_base */ \ + NULL, /* host_rdma_heap_base */ \ + NULL, /* host_sdma_heap_base */ \ NULL, /* p2p_heap_device_base */ \ NULL, /* sdma_heap_device_base */ \ NULL, /* roce_heap_device_base */ \ - {NULL}, /* p2p_heap_base */ \ - {NULL}, /* rdma_heap_base */ \ - {NULL}, /* sdma_heap_base */ \ {}, /* topo_list */ \ SIZE_MAX, /* heap_size */ \ {NULL}, /* team_pools */ \ @@ -276,17 +273,17 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a // shmem basic init #ifdef BACKEND_MF - init_manager = new shmemi_init_mf(attributes, g_ipport); + init_manager = new shmemi_init_mf(attributes, g_ipport, &g_state); #else // bootstrap init SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, attributes)); - init_manager = new shmemi_init_default(attributes); + init_manager = new shmemi_init_default(attributes, &g_state); #endif SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); SHMEM_CHECK_RET(init_manager->init_device_state()); - SHMEM_CHECK_RET(init_manager->reserve_heap(g_state)); - SHMEM_CHECK_RET(init_manager->transport_init(g_state)); - SHMEM_CHECK_RET(init_manager->setup_heap(g_state)); + SHMEM_CHECK_RET(init_manager->reserve_heap()); + SHMEM_CHECK_RET(init_manager->transport_init()); + SHMEM_CHECK_RET(init_manager->setup_heap()); // shmem submodules init SHMEM_CHECK_RET(memory_manager_initialize(g_state.heap_base, g_state.heap_size)); @@ -300,10 +297,21 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a int32_t shmem_finalize() { + // shmem submodules finalize SHMEM_CHECK_RET(shmemi_team_finalize()); + + // shmem basic finalize + SHMEM_CHECK_RET(init_manager->remove_heap()); + SHMEM_CHECK_RET(init_manager->transport_finalize()); + SHMEM_CHECK_RET(init_manager->release_heap()); + SHMEM_CHECK_RET(init_manager->finalize_device_state()); delete init_manager; +#ifdef BACKEND_MF + +#else shmemi_bootstrap_finalize(); +#endif return SHMEM_SUCCESS; } diff --git a/src/host/mem/shmem_rma.cpp b/src/host/mem/shmem_rma.cpp index 2491705c..ec670f16 100644 --- a/src/host/mem/shmem_rma.cpp +++ b/src/host/mem/shmem_rma.cpp @@ -29,13 +29,13 @@ void *shmem_ptr(void *ptr, int32_t pe) } uint64_t offset = (uint64_t)ptr - (uint64_t)g_state.heap_base; - void *symm_ptr = g_state.p2p_heap_host_base[pe]; + void *symm_ptr = g_state.host_p2p_heap_base[pe]; if (symm_ptr != nullptr) { symm_ptr = reinterpret_cast(reinterpret_cast(symm_ptr) + offset); return symm_ptr; } SHM_LOG_ERROR("shmem_ptr Failed. PE: " << shmem_my_pe() - << " g_state.p2p_heap_host_base contains nullptr, Please Check Init Status!!"); + << " g_state.host_p2p_heap_base contains nullptr, Please Check Init Status!!"); return nullptr; } diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index c9738a1d..d7541882 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -19,11 +19,11 @@ uint64_t *host_hash_list; shmemi_host_trans_state_t g_host_state; -int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_optional_attr_t& option_attr) { +int32_t shmemi_transport_init(shmemi_device_host_state_t *g_state, shmem_init_optional_attr_t& option_attr) { // Initialize MTE by default g_host_state.num_choosen_transport = 1; - g_host_state.transport_map = (int *)calloc(g_state.npes * g_state.npes, sizeof(int)); - g_host_state.pe_info = (shmemi_transport_pe_info *)calloc(g_state.npes, sizeof(shmemi_transport_pe_info)); + g_host_state.transport_map = (int *)calloc(g_state->npes * g_state->npes, sizeof(int)); + g_host_state.pe_info = (shmemi_transport_pe_info *)calloc(g_state->npes, sizeof(shmemi_transport_pe_info)); transport_mte_lib = dlopen("shmem_transport_mte.so", RTLD_NOW); if (!transport_mte_lib) { @@ -45,13 +45,13 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); shmemi_transport_pe_info_t my_info; - my_info.pe = g_state.mype; + my_info.pe = g_state->mype; my_info.dev_id = device_id; - my_info.host_hash = g_state.host_hash; + my_info.host_hash = g_state->host_hash; // AllGather All pe's host info g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); - SHMEM_CHECK_RET(init_mte_fn(&g_host_state.choosen_transports[0], &g_state)); + SHMEM_CHECK_RET(init_mte_fn(&g_host_state.choosen_transports[0], g_state)); // If enable RDMA if (option_attr.data_op_engine_type & SHMEM_DATA_OP_ROCE) { @@ -79,15 +79,15 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_op SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_rdma.so" << "."); return SHMEM_INVALID_VALUE; } - SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[1], &g_state)); + SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[1], g_state)); } return SHMEM_SUCCESS; } -int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { +int32_t shmemi_build_transport_map(shmemi_device_host_state_t *g_state) { int *local_map = NULL; - local_map = (int *)calloc(g_state.npes, sizeof(int)); + local_map = (int *)calloc(g_state->npes, sizeof(int)); shmemi_transport_t t; @@ -95,10 +95,10 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { for (int j = 0; j < g_host_state.num_choosen_transport; j++) { t = g_host_state.choosen_transports[j]; - for (int i = 0; i < g_state.npes; i++) { + for (int i = 0; i < g_state->npes; i++) { int reach = 0; - SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, &t, &g_state)); + SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, &t, g_state)); if (reach) { int m = 1 << j; @@ -107,50 +107,50 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state) { } } - for (int i = 0; i < g_state.npes; i++) { - g_state.topo_list[i] = static_cast(local_map[i]); + for (int i = 0; i < g_state->npes; i++) { + g_state->topo_list[i] = static_cast(local_map[i]); } - g_boot_handle.allgather(local_map, g_host_state.transport_map, g_state.npes * sizeof(int), &g_boot_handle); + g_boot_handle.allgather(local_map, g_host_state.transport_map, g_state->npes * sizeof(int), &g_boot_handle); if (local_map) free(local_map); return SHMEM_SUCCESS; } -int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) { +int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t *g_state) { shmemi_transport_t t; // MTE t = g_host_state.choosen_transports[0]; int *mte_peer_list; int mte_peer_num = 0; - mte_peer_list = (int *)calloc(g_state.npes, sizeof(int)); + mte_peer_list = (int *)calloc(g_state->npes, sizeof(int)); - int local_offset = g_state.mype * g_state.npes; - for (int i = 0; i < g_state.npes; i++) { - if (i == g_state.mype) + int local_offset = g_state->mype * g_state->npes; + for (int i = 0; i < g_state->npes; i++) { + if (i == g_state->mype) continue; /* Check if MTE connected. */ if (g_host_state.transport_map[local_offset + i] & 0x1) { shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); // Only PEs in the same Node need to build up MTE connection. - if (g_state.host_hash == peer_info->host_hash) { + if (g_state->host_hash == peer_info->host_hash) { mte_peer_list[mte_peer_num] = peer_info->dev_id; ++mte_peer_num; } } } - t.connect_peers(&t, mte_peer_list, mte_peer_num, &g_state); + t.connect_peers(&t, mte_peer_list, mte_peer_num, g_state); if (g_host_state.num_choosen_transport > 1) { int *rdma_peer_list; int rdma_peer_num = 0; - rdma_peer_list = (int *)calloc(g_state.npes, sizeof(int)); + rdma_peer_list = (int *)calloc(g_state->npes, sizeof(int)); - int local_offset = g_state.mype * g_state.npes; - for (int i = 0; i < g_state.npes; i++) { - if (i == g_state.mype) + int local_offset = g_state->mype * g_state->npes; + for (int i = 0; i < g_state->npes; i++) { + if (i == g_state->mype) continue; if (g_host_state.transport_map[local_offset + i] & 2) { shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); @@ -159,17 +159,17 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state) } } t = g_host_state.choosen_transports[1]; - t.connect_peers(&t, rdma_peer_list, rdma_peer_num, &g_state); + t.connect_peers(&t, rdma_peer_list, rdma_peer_num, g_state); } return 0; } -int32_t shmemi_transport_finalize() { +int32_t shmemi_transport_finalize(shmemi_device_host_state_t *g_state) { shmemi_transport_t t; // MTE t = g_host_state.choosen_transports[0]; - t.finalize(&t, &g_state); + t.finalize(&t, g_state); if (transport_mte_lib != NULL) { dlclose(transport_mte_lib); @@ -178,7 +178,7 @@ int32_t shmemi_transport_finalize() { if (g_host_state.num_choosen_transport > 1) { t = g_host_state.choosen_transports[1]; - t.finalize(&t, &g_state); + t.finalize(&t, g_state); if (transport_rdma_lib != NULL) { dlclose(transport_rdma_lib); diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h index b958c908..e89bfaf1 100644 --- a/src/host/transport/shmemi_transport.h +++ b/src/host/transport/shmemi_transport.h @@ -12,12 +12,12 @@ typedef int(*transport_init_func)(shmemi_transport_t *transport, shmemi_device_host_state_t *g_state); -int32_t shmemi_transport_init(shmemi_device_host_state_t &g_state, shmem_init_optional_attr_t &option_attr); +int32_t shmemi_transport_init(shmemi_device_host_state_t *g_state, shmem_init_optional_attr_t &option_attr); -int32_t shmemi_build_transport_map(shmemi_device_host_state_t &g_state); +int32_t shmemi_build_transport_map(shmemi_device_host_state_t *g_state); -int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t &g_state); +int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t *g_state); -int32_t shmemi_transport_finalize(); +int32_t shmemi_transport_finalize(shmemi_device_host_state_t *g_state); #endif // SHMEMI_TRANSPORT_H \ No newline at end of file diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index 7a711143..476ed633 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -46,7 +46,7 @@ int shmemi_rdma_connect_peers(shmemi_transport *t, int *selected_dev_ids, int nu std::vector mrs(state->npes); g_boot_handle.allgather(&local_mr, mrs.data(), sizeof(RegMemResult), &g_boot_handle); for (int i = 0; i < state->npes; i++) { - state->rdma_heap_base[i] = reinterpret_cast(mrs[i].address); + state->host_rdma_heap_base[i] = reinterpret_cast(mrs[i].address); SHM_LOG_INFO("get rank " << i << ", mr info = " << mrs[i]); } diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index cd370614..60bbec87 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -19,4 +19,4 @@ target_link_directories(shmem_unittest PRIVATE ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ) -target_link_libraries(shmem_unittest PRIVATE shmem_unittest_device gtest gcov mf_smem shmem_unittest_include) +target_link_libraries(shmem_unittest PRIVATE shmem_unittest_device gtest gcov mf_smem shmem_unittest_include shmem) -- Gitee From 3794f72207203f031d86d50fa8fc9ee6b0079444 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Tue, 2 Dec 2025 20:39:24 +0800 Subject: [PATCH 65/74] Adjust the log module && fix ut --- scripts/run.sh | 1 + src/host/bootstrap/shmemi_bootstrap.cpp | 3 +- src/host/common/shmemi_host_types.h | 1 + src/host/common/shmemi_logger.h | 84 ++++++++----------- .../default/shmemi_init_default.cpp | 6 ++ .../default/shmemi_init_default.h | 2 + .../init/init_backends/mf/shmemi_init_mf.cpp | 3 + src/host/init/shmem_init.cpp | 17 ++-- src/host/mem/shmemi_heap.cpp | 3 + src/host/team/shmem_team.cpp | 8 +- src/host/transport/shmemi_transport.cpp | 3 +- .../bootstrap/shmemi_bootstrap_uid.cpp | 1 + src/modules/transport/shmemi_rdma.cpp | 1 + tests/unittest/host/init/init_host_test.cpp | 3 +- tests/unittest/host/main_test.cpp | 2 + .../host/mem/shmem_host_heap_test.cpp | 13 ++- .../unittest/host/mem/shmem_ptr_host_test.cpp | 2 +- .../host/team/team/team_host_test.cpp | 2 +- 18 files changed, 80 insertions(+), 75 deletions(-) diff --git a/scripts/run.sh b/scripts/run.sh index 7db18eab..1216d002 100644 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -129,6 +129,7 @@ export SMEM_CONF_STORE_TLS_ENABLE=0 export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/install/memfabric_hybrid/lib:${ASCEND_HOME_PATH}/lib64:$LD_LIBRARY_PATH # Run unit test +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 cd "$BUILD_PATH" ./bin/shmem_unittest "$RANK_SIZE" "$IPPORT" "$GNPU_NUM" "$FIRST_RANK" "$FIRST_NPU" --gtest_output=xml:test_detail.xml --gtest_filter=${TEST_FILTER} diff --git a/src/host/bootstrap/shmemi_bootstrap.cpp b/src/host/bootstrap/shmemi_bootstrap.cpp index 9ff0c5fa..9f2dc10b 100644 --- a/src/host/bootstrap/shmemi_bootstrap.cpp +++ b/src/host/bootstrap/shmemi_bootstrap.cpp @@ -131,11 +131,12 @@ int32_t shmemi_bootstrap_init(int flags, shmem_init_attr_t *attr) { shmemi_bootstrap_free(); return SHMEM_INNER_ERROR; } + g_boot_handle.is_bootstraped = true; return status; } void shmemi_bootstrap_finalize() { g_boot_handle.finalize(&g_boot_handle); - + g_boot_handle.is_bootstraped = false; dlclose(plugin_hdl); } diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index 280855b1..98073b0d 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -39,6 +39,7 @@ typedef struct shmemi_bootstrap_handle { int (*alltoall)(const void *sendbuf, void *recvbuf, int size, shmemi_bootstrap_handle *boot_handle); void (*global_exit)(int status); shmemi_bootstrap_init_ops_t *pre_init_ops; + bool is_bootstraped = false; } shmemi_bootstrap_handle_t; typedef struct shmemi_bootstrap_mpi_options { diff --git a/src/host/common/shmemi_logger.h b/src/host/common/shmemi_logger.h index ae5d7e94..0dae3da3 100644 --- a/src/host/common/shmemi_logger.h +++ b/src/host/common/shmemi_logger.h @@ -118,11 +118,11 @@ private: #ifndef SHM_LOG_FILENAME_SHORT #define SHM_LOG_FILENAME_SHORT (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) #endif -#define SHM_OUT_LOG(LEVEL, ARGS) \ - do { \ - std::ostringstream oss; \ - oss << "[SHM_SHMEM " << SHM_LOG_FILENAME_SHORT << ":" << __LINE__ << "] " << ARGS; \ - shm::shm_out_logger::Instance().log(LEVEL, oss); \ +#define SHM_OUT_LOG(LEVEL, ARGS) \ + do { \ + std::ostringstream oss; \ + oss << "[SHM_SHMEM " << SHM_LOG_FILENAME_SHORT << ":" << __LINE__ << "] " << ARGS; \ + shm::shm_out_logger::Instance().log(LEVEL, oss); \ } while (0) #define SHM_LOG_DEBUG(ARGS) SHM_OUT_LOG(shm::DEBUG_LEVEL, ARGS) @@ -172,61 +172,43 @@ private: } \ } while (0) -#define SHMEM_CHECK_RET(x, ...) \ - do \ - { \ - int32_t check_ret = x; \ - if (check_ret != 0) \ - { \ - if (sizeof(#__VA_ARGS__) > 1) \ - { \ - SHM_LOG_ERROR(" return shmem error: " << check_ret << " - " << #__VA_ARGS__ << " failed. More error information can be found in plog"); \ - } \ - else \ - { \ - SHM_LOG_ERROR(" return shmem error: " << check_ret); \ - } \ - return check_ret; \ - } \ - } while (0) - -#define SHMEM_CHECK(x) \ - do { \ - int32_t check_ret = x; \ - if (check_ret != 0) { \ - SHM_LOG_ERROR(" return shmem error: " << check_ret); \ - return ; \ - } \ +#define SHMEM_CHECK(x) \ + do { \ + int32_t check_ret = x; \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" return shmem error: " << check_ret << " - " << #x << " failed."); \ + return ; \ + } \ } while (0) #define SHMEM_CHECK_RET(...) \ _SHMEM_CHECK_RET_HELPER(__VA_ARGS__, _SHMEM_CHECK_RET_WITH_LOG_AND_ERR_CODE, _SHMEM_CHECK_RET_WITH_LOG, _SHMEM_CHECK_RET)(__VA_ARGS__) -#define _SHMEM_CHECK_RET(x) \ - do { \ - int32_t check_ret = (x); \ - if (check_ret != 0) { \ - SHM_LOG_ERROR(" return shmem error: " << check_ret); \ - return check_ret; \ - } \ +#define _SHMEM_CHECK_RET(x) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" return shmem error: " << check_ret << " - " << #x << " failed."); \ + return check_ret; \ + } \ } while (0) -#define _SHMEM_CHECK_RET_WITH_LOG(x, log_str) \ - do { \ - int32_t check_ret = (x); \ - if (check_ret != 0) { \ - SHM_LOG_ERROR(" " << log_str << " return shmem error: " << check_ret); \ - return check_ret; \ - } \ +#define _SHMEM_CHECK_RET_WITH_LOG(x, LOG_STR) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << LOG_STR << " return shmem error: " << check_ret); \ + return check_ret; \ + } \ } while (0) -#define _SHMEM_CHECK_RET_WITH_LOG_AND_ERR_CODE(x, log_str, error_code) \ - do { \ - int32_t check_ret = (x); \ - if (check_ret != 0) { \ - SHM_LOG_ERROR(" " << log_str << " return shmem error: " << error_code); \ - return error_code; \ - } \ +#define _SHMEM_CHECK_RET_WITH_LOG_AND_ERR_CODE(x, LOG_STR, ERR_CODE) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << LOG_STR << " return shmem error: " << ERR_CODE); \ + return ERR_CODE; \ + } \ } while (0) #define _SHMEM_CHECK_RET_HELPER(_1, _2, _3, FUNC, ...) FUNC diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp index b92fc580..7241aad9 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.cpp +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -124,4 +124,10 @@ int shmemi_init_default::transport_finalize() { SHMEM_CHECK_RET(shmemi_transport_finalize(g_state)); return SHMEM_SUCCESS; +} + +int32_t shmemi_control_barrier_all_default(shmemi_bootstrap_handle_t boot_handle) +{ + SHMEM_CHECK_RET((boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + return boot_handle.barrier(&boot_handle); } \ No newline at end of file diff --git a/src/host/init/init_backends/default/shmemi_init_default.h b/src/host/init/init_backends/default/shmemi_init_default.h index 07fadb86..16b29451 100644 --- a/src/host/init/init_backends/default/shmemi_init_default.h +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -55,4 +55,6 @@ private: shmem_init_optional_attr_t option_attr_; }; +int32_t shmemi_control_barrier_all_default(shmemi_bootstrap_handle_t boot_handle); + #endif // SHMEMI_INIT_NORMAL_H \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index aa30eccb..72a75952 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -123,6 +123,9 @@ int shmemi_init_mf::reserve_heap() uint32_t reach_info = 0; for (int32_t i = 0; i < g_state->npes; i++) { status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_topology_can_reach failed"); + } g_state->host_p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state->heap_size * i); if (reach_info & SMEMS_DATA_OP_MTE) { g_state->topo_list[i] |= SHMEM_TRANSPORT_MTE; diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index b624ff93..a62f81ba 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -111,7 +111,7 @@ int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) g_state.host_hash = shmemi_get_host_hash(); aclrtStream stream = nullptr; - SHMEM_CHECK_RET(aclrtCreateStream(&stream), aclrtCreateStream); + SHMEM_CHECK_RET(aclrtCreateStream(&stream)); g_state_host.default_stream = stream; g_state_host.default_event_id = DEFAULT_TEVENT; g_state_host.default_block_num = DEFAULT_BLOCK_NUM; @@ -144,7 +144,7 @@ int32_t shmemi_control_barrier_all() #ifdef BACKEND_MF return shmemi_control_barrier_all_mf(); #else - return g_boot_handle.barrier(&g_boot_handle); + return shmemi_control_barrier_all_default(g_boot_handle); #endif } @@ -263,18 +263,20 @@ int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const in int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *attributes) { int32_t ret; - shmem_set_log_level(shm::ERROR_LEVEL); + SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); // config init SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); - SHMEM_CHECK_RET(check_attr(attributes)); - SHMEM_CHECK_RET(version_compatible()); + SHMEM_CHECK_RET(check_attr(attributes), "An error occurred while checking the initialization attributes. Please check the initialization parameters."); + SHMEM_CHECK_RET(version_compatible(), "SHMEM Version mismatch."); SHMEM_CHECK_RET(shmemi_options_init()); // shmem basic init #ifdef BACKEND_MF + SHM_LOG_INFO("The current backend is MF."); init_manager = new shmemi_init_mf(attributes, g_ipport, &g_state); #else + SHM_LOG_INFO("The current backend is SHMEM default."); // bootstrap init SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, attributes)); init_manager = new shmemi_init_default(attributes, &g_state); @@ -292,11 +294,13 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a g_state.is_shmem_initialized = true; SHMEM_CHECK_RET(update_device_state()); SHMEM_CHECK_RET(shmemi_control_barrier_all()); + SHM_LOG_INFO("SHMEM init success."); return SHMEM_SUCCESS; } int32_t shmem_finalize() { + SHM_LOG_INFO("The pe: " << shmem_my_pe() << " begins to finalize."); // shmem submodules finalize SHMEM_CHECK_RET(shmemi_team_finalize()); @@ -312,6 +316,7 @@ int32_t shmem_finalize() #else shmemi_bootstrap_finalize(); #endif + SHM_LOG_INFO("The pe: " << shmem_my_pe() << " finalize success."); return SHMEM_SUCCESS; } @@ -363,7 +368,7 @@ int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid){ } int32_t shmemi_get_uniqueid_static_magic(shmemx_uniqueid_t *uid, bool is_root) { - shmem_set_log_level(shm::INFO_LEVEL); + shmem_set_log_level(shm::ERROR_LEVEL); int status = 0; SHMEM_CHECK_RET(shmemi_options_init(), "Bootstrap failed during the preloading step."); SHMEM_CHECK_RET(shmemi_bootstrap_pre_init(SHMEMX_INIT_WITH_UNIQUEID, &g_boot_handle), "Get uniqueid failed during the bootstrap preloading step."); diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp index 9c30b860..d5c17566 100644 --- a/src/host/mem/shmemi_heap.cpp +++ b/src/host/mem/shmemi_heap.cpp @@ -73,6 +73,7 @@ int shmem_symmetric_heap::export_pid() int shmem_symmetric_heap::import_pid() { // Get all pids + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); g_boot_handle.allgather(&my_pid, pid_list.data(), 1 * sizeof(int), &g_boot_handle); // Get all sdids @@ -92,6 +93,7 @@ int shmem_symmetric_heap::import_pid() int shmem_symmetric_heap::import_memory() { + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); g_boot_handle.allgather(memory_name.c_str(), names, IPC_NAME_SIZE, &g_boot_handle); static std::mutex mut; @@ -129,6 +131,7 @@ int shmem_symmetric_heap::remove_heap() } // This barrier is necessary, otherwise Unmap will fail. + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); g_boot_handle.barrier(&g_boot_handle); SHMEM_CHECK_RET(rtIpcDestroyMemoryName(memory_name.c_str())); diff --git a/src/host/team/shmem_team.cpp b/src/host/team/shmem_team.cpp index 4e06a269..00d28e95 100644 --- a/src/host/team/shmem_team.cpp +++ b/src/host/team/shmem_team.cpp @@ -57,12 +57,12 @@ inline int32_t device_team_update(int team_idx, shmemi_team_t *host_team_ptr) { // device_ptr Malloc void *team_ptr = nullptr; - SHMEM_CHECK_RET(aclrtMalloc(&team_ptr, sizeof(shmemi_team_t), ACL_MEM_MALLOC_NORMAL_ONLY), aclrtMalloc); + SHMEM_CHECK_RET(aclrtMalloc(&team_ptr, sizeof(shmemi_team_t), ACL_MEM_MALLOC_NORMAL_ONLY)); auto ret = aclrtMemcpy((shmemi_team_t *)team_ptr, sizeof(shmemi_team_t), host_team_ptr, sizeof(shmemi_team_t), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != 0) { SHM_LOG_ERROR("memcpy device team info failed, ret: " << ret); - SHMEM_CHECK_RET(aclrtFree(team_ptr), aclrtFree); + SHMEM_CHECK_RET(aclrtFree(team_ptr)); return SHMEM_INNER_ERROR; } g_state.team_pools[team_idx] = (shmemi_team_t *)team_ptr; @@ -210,11 +210,11 @@ int32_t shmemi_team_finalize() g_state.sync_pool = 0; } if (g_state.core_sync_counter != 0) { - SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_counter)), aclrtFree); + SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_counter))); g_state.core_sync_counter = 0; } if (g_state.core_sync_pool != 0) { - SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_pool)), aclrtFree); + SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_pool))); g_state.core_sync_pool = 0; } if (g_shmem_team_pool != nullptr) { diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index d7541882..2192a249 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -50,6 +50,7 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t *g_state, shmem_init_op my_info.host_hash = g_state->host_hash; // AllGather All pe's host info + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); SHMEM_CHECK_RET(init_mte_fn(&g_host_state.choosen_transports[0], g_state)); @@ -110,7 +111,7 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t *g_state) { for (int i = 0; i < g_state->npes; i++) { g_state->topo_list[i] = static_cast(local_map[i]); } - + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); g_boot_handle.allgather(local_map, g_host_state.transport_map, g_state->npes * sizeof(int), &g_boot_handle); if (local_map) free(local_map); diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp index 17b261d0..833f907a 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -399,6 +399,7 @@ static int shmemi_bootstrap_uid_allgather(const void *in, void *out, int len, sh static int shmemi_bootstrap_uid_barrier(shmemi_bootstrap_handle_t *handle) { + SHM_LOG_INFO("shmemi_bootstrap_uid_barrier"); if (!handle || !handle->bootstrap_state) { SHM_LOG_ERROR("bootstrap barrier: invalid arguments"); return SHMEM_BOOTSTRAP_ERROR; diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index 476ed633..4e43b2a6 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -35,6 +35,7 @@ int shmemi_rdma_connect_peers(shmemi_transport *t, int *selected_dev_ids, int nu auto local_device_ip = manager->GetDeviceIP(); SHM_LOG_INFO("local ip = " << inet_ntoa(local_device_ip)); std::vector device_ips(state->npes); + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); g_boot_handle.allgather(&local_device_ip, device_ips.data(), sizeof(in_addr), &g_boot_handle); g_boot_handle.barrier(&g_boot_handle); for (int i = 0; i < state->npes; i++) { diff --git a/tests/unittest/host/init/init_host_test.cpp b/tests/unittest/host/init/init_host_test.cpp index a524e849..cdb05fde 100644 --- a/tests/unittest/host/init/init_host_test.cpp +++ b/tests/unittest/host/init/init_host_test.cpp @@ -52,8 +52,7 @@ void test_shmem_init(int rank_id, int n_ranks, uint64_t local_mem_size) EXPECT_EQ(g_state.mype, rank_id); EXPECT_EQ(g_state.npes, n_ranks); EXPECT_NE(g_state.heap_base, nullptr); - EXPECT_NE(g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(g_state.p2p_heap_device_base[rank_id], nullptr); + EXPECT_NE(g_state.host_p2p_heap_base[rank_id], nullptr); EXPECT_EQ(g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); EXPECT_NE(g_state.team_pools[0], nullptr); status = shmem_init_status(); diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index dd8b488c..994a0416 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -125,6 +125,8 @@ void test_mutil_task(std::function func, uint64_t loca waitpid(pids[i], &status[i], 0); if (WIFEXITED(status[i]) && WEXITSTATUS(status[i]) != 0) { FAIL(); + } else if (WIFSIGNALED(status[i])) { + FAIL(); } } } diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index 2bddbb66..09bc2c54 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -49,8 +49,7 @@ protected: EXPECT_EQ(g_state.mype, rank_id); EXPECT_EQ(g_state.npes, n_ranks); EXPECT_NE(g_state.heap_base, nullptr); - EXPECT_NE(g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(g_state.p2p_heap_device_base[rank_id], nullptr); + EXPECT_NE(g_state.host_p2p_heap_base[rank_id], nullptr); EXPECT_EQ(g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); EXPECT_NE(g_state.team_pools[0], nullptr); status = shmem_init_status(); @@ -118,7 +117,7 @@ TEST_F(ShareMemoryManagerTest, allocate_large_memory_failed) int32_t device_id = rank_id % test_gnpu_num + test_first_npu; aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); - auto ptr = shmem_malloc(heap_memory_size + 1UL); + auto ptr = shmem_malloc(heap_memory_size + SHMEM_EXTRA_SIZE + 1UL); EXPECT_EQ(nullptr, ptr); test_finalize(stream, device_id); }, @@ -201,7 +200,7 @@ TEST_F(ShareMemoryManagerTest, calloc_large_memory_failed) aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); const size_t nmemb = 16; - auto ptr = shmem_calloc(nmemb, heap_memory_size / nmemb + 1UL); + auto ptr = shmem_calloc(nmemb, (heap_memory_size + SHMEM_EXTRA_SIZE) / nmemb + 1UL); EXPECT_EQ(nullptr, ptr); test_finalize(stream, device_id); }, @@ -298,7 +297,7 @@ TEST_F(ShareMemoryManagerTest, align_large_memory_failed) aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); const size_t alignment = 16; - auto ptr = shmem_align(alignment, heap_memory_size + 1UL); + auto ptr = shmem_align(alignment, heap_memory_size + SHMEM_EXTRA_SIZE + 1UL); EXPECT_EQ(nullptr, ptr); test_finalize(stream, device_id); }, @@ -390,13 +389,11 @@ TEST_F(ShareMemoryManagerTest, calls_before_init_and_after_finalize) aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); - void *ok = shmem_malloc(2048UL); EXPECT_NE(nullptr, ok); shmem_free(ok); - test_finalize(stream, device_id); - + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); EXPECT_EQ(nullptr, shmem_calloc(2, 512)); EXPECT_EQ(nullptr, shmem_align(32, 1024UL)); diff --git a/tests/unittest/host/mem/shmem_ptr_host_test.cpp b/tests/unittest/host/mem/shmem_ptr_host_test.cpp index c091f72a..f061766b 100644 --- a/tests/unittest/host/mem/shmem_ptr_host_test.cpp +++ b/tests/unittest/host/mem/shmem_ptr_host_test.cpp @@ -24,7 +24,7 @@ static int32_t test_get_device_ptr(aclrtStream stream, uint8_t *ptr, int rank_id uint32_t block_dim = 1; int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id), aclrtGetDevice); + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); get_device_ptr(block_dim, stream, ptr); EXPECT_EQ(aclrtSynchronizeStream(stream), 0); diff --git a/tests/unittest/host/team/team/team_host_test.cpp b/tests/unittest/host/team/team/team_host_test.cpp index eb72efff..833586bb 100644 --- a/tests/unittest/host/team/team/team_host_test.cpp +++ b/tests/unittest/host/team/team/team_host_test.cpp @@ -32,7 +32,7 @@ static int32_t test_get_device_state(aclrtStream stream, uint8_t *gva, uint32_t uint32_t block_dim = 1; void *ptr = shmem_malloc(1024); int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id), aclrtGetDevice); + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); get_device_state(block_dim, stream, (uint8_t *)ptr, team_id); EXPECT_EQ(aclrtSynchronizeStream(stream), 0); sleep(1); -- Gitee From 1395ec8fca9b83114a4e8130bcf3e278a3acc0d5 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Tue, 2 Dec 2025 20:51:25 +0800 Subject: [PATCH 66/74] MF && Default Backend Examples OK --- examples/rdma_demo/README.md | 3 +- examples/rdma_demo/run.sh | 5 ++- .../unuse_handlewait/README.md | 4 +- .../unuse_handlewait/run.sh | 9 +++++ .../use_handlewait/README.md | 4 +- .../use_handlewait/run.sh | 9 +++++ examples/rdma_perftest/README.md | 3 +- examples/rdma_perftest/run.sh | 9 +++++ .../low_level/shmem_device_low_level_roce.h | 37 ++++++++++++------- .../init/init_backends/mf/shmemi_init_mf.cpp | 1 + 10 files changed, 63 insertions(+), 21 deletions(-) create mode 100644 examples/rdma_handlewait_test/unuse_handlewait/run.sh create mode 100644 examples/rdma_handlewait_test/use_handlewait/run.sh create mode 100644 examples/rdma_perftest/run.sh diff --git a/examples/rdma_demo/README.md b/examples/rdma_demo/README.md index effc8a35..a6eeba74 100644 --- a/examples/rdma_demo/README.md +++ b/examples/rdma_demo/README.md @@ -7,7 +7,8 @@ bash scripts/build.sh ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH -mpirun -np 2 ./build/bin/rdma_demo tcp://127.0.0.1:8765 2 0 0 +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/rdma_demo tcp://127.0.0.1:8899 2 0 0 ``` 3.命令行参数说明 diff --git a/examples/rdma_demo/run.sh b/examples/rdma_demo/run.sh index 873d82ce..b9443786 100644 --- a/examples/rdma_demo/run.sh +++ b/examples/rdma_demo/run.sh @@ -4,5 +4,6 @@ script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" project_root="$(cd ${script_dir}/../../ && pwd)" export PROJECT_ROOT=${project_root} export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH -./build/bin/rdma_demo 2 0 tcp://127.0.0.1:8765 2 0 0 & # rank 0 -./build/bin/rdma_demo 2 1 tcp://127.0.0.1:8765 2 0 0 & # rank 1 \ No newline at end of file + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/rdma_demo tcp://127.0.0.1:8899 2 0 0 \ No newline at end of file diff --git a/examples/rdma_handlewait_test/unuse_handlewait/README.md b/examples/rdma_handlewait_test/unuse_handlewait/README.md index 065cb996..63424274 100644 --- a/examples/rdma_handlewait_test/unuse_handlewait/README.md +++ b/examples/rdma_handlewait_test/unuse_handlewait/README.md @@ -7,8 +7,8 @@ bash scripts/build.sh -examples ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH -./build/bin/unuse_handlewait 2 0 tcp://127.0.0.1:8765 2 0 0 # rank 0 -./build/bin/unuse_handlewait 2 1 tcp://127.0.0.1:8765 2 0 0 # rank 1 +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/unuse_handlewait tcp://127.0.0.1:8899 2 0 0 ``` 3.命令行参数说明 diff --git a/examples/rdma_handlewait_test/unuse_handlewait/run.sh b/examples/rdma_handlewait_test/unuse_handlewait/run.sh new file mode 100644 index 00000000..fc19564b --- /dev/null +++ b/examples/rdma_handlewait_test/unuse_handlewait/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +project_root="$(cd ${script_dir}/../../../ && pwd)" +export PROJECT_ROOT=${project_root} +export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/unuse_handlewait tcp://127.0.0.1:8899 2 0 0 \ No newline at end of file diff --git a/examples/rdma_handlewait_test/use_handlewait/README.md b/examples/rdma_handlewait_test/use_handlewait/README.md index e0bdcd8c..18dc98be 100644 --- a/examples/rdma_handlewait_test/use_handlewait/README.md +++ b/examples/rdma_handlewait_test/use_handlewait/README.md @@ -7,8 +7,8 @@ bash scripts/build.sh -examples ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH -./build/bin/use_handlewait 2 0 tcp://127.0.0.1:8765 2 0 0 # rank 0 -./build/bin/use_handlewait 2 1 tcp://127.0.0.1:8765 2 0 0 # rank 1 +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/use_handlewait tcp://127.0.0.1:8899 2 0 0 ``` 3.命令行参数说明 diff --git a/examples/rdma_handlewait_test/use_handlewait/run.sh b/examples/rdma_handlewait_test/use_handlewait/run.sh new file mode 100644 index 00000000..1bbf49c4 --- /dev/null +++ b/examples/rdma_handlewait_test/use_handlewait/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +project_root="$(cd ${script_dir}/../../../ && pwd)" +export PROJECT_ROOT=${project_root} +export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/use_handlewait tcp://127.0.0.1:8899 2 0 0 \ No newline at end of file diff --git a/examples/rdma_perftest/README.md b/examples/rdma_perftest/README.md index cdf91185..3d861e07 100644 --- a/examples/rdma_perftest/README.md +++ b/examples/rdma_perftest/README.md @@ -7,7 +7,8 @@ bash scripts/build.sh ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH -mpirun -np 2 ./build/bin/rdma_perftest tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/rdma_perftest tcp://127.0.0.1:8899 2 0 0 highlevel_put_pingpong_latency 64 ``` 3.命令行参数说明 diff --git a/examples/rdma_perftest/run.sh b/examples/rdma_perftest/run.sh new file mode 100644 index 00000000..b138bfde --- /dev/null +++ b/examples/rdma_perftest/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +project_root="$(cd ${script_dir}/../../ && pwd)" +export PROJECT_ROOT=${project_root} +export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +mpirun -np 2 ${PROJECT_ROOT}/build/bin/rdma_perftest tcp://127.0.0.1:8899 2 0 0 highlevel_put_pingpong_latency 64 \ No newline at end of file diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h index efaf9463..9e4eb32b 100644 --- a/include/device/low_level/shmem_device_low_level_roce.h +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -102,6 +102,19 @@ struct SHMEMHybmDeviceMeta { uint64_t reserved[12]; // total 128B, equal HYBM_DEVICE_PRE_META_SIZE }; +SHMEM_DEVICE __gm__ SHMEMAIVRDMAInfo* shmemi_qp_info_fetch() +{ +#ifdef BACKEND_MF + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)( + SMEM_SHM_DEVICE_META_ADDR + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); +#else + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); +#endif + return RDMAInfo; +} + SHMEM_DEVICE void shmemi_roce_poll_cq_update_info(AscendC::LocalTensor &ubLocal64, AscendC::LocalTensor &ubLocal32, uint32_t &curTail, uint32_t &rRankId, uint32_t &qpIdx); SHMEM_DEVICE void shmemi_rdma_post_send_update_info(AscendC::LocalTensor &ubLocal64, @@ -123,9 +136,8 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, if (idx == 0) { return 0; } + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); @@ -171,9 +183,8 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, SHMEM_DEVICE void shmemi_roce_poll_cq_update_info(AscendC::LocalTensor &ubLocal64, AscendC::LocalTensor &ubLocal32, uint32_t &curTail, uint32_t &remoteRankId, uint32_t &qpIdx) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); @@ -235,8 +246,8 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); @@ -375,8 +386,8 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); @@ -388,8 +399,8 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, SHMEM_DEVICE void shmemi_roce_qpinfo_test(__gm__ uint8_t* gva, uint32_t destRankId, uint32_t qpIdx) { - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + *(__gm__ uint64_t*)(gva) = (uint64_t)RDMAInfo; uint32_t qpNum = RDMAInfo->qpNum; *(__gm__ uint64_t*)(gva + 8) = (uint64_t)qpNum; @@ -440,8 +451,8 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm shmemi_rdma_post_send(destDmaAddr, srcDmaAddr, destRankId, qpIdx, SHMEMAIVOPCODE::OP_RDMA_WRITE, messageLen, ubLocal64, ubLocal32); uint32_t idx = 1; - __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); *(__gm__ uint64_t*)(gva) = (uint64_t)cqCtxEntry; diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index 72a75952..b8a9a3ba 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -136,6 +136,7 @@ int shmemi_init_mf::reserve_heap() g_state->host_sdma_heap_base[i] = NULL; } if (reach_info & SMEMS_DATA_OP_RDMA) { + g_state->host_rdma_heap_base[i] = g_state->host_p2p_heap_base[i]; g_state->topo_list[i] |= SHMEM_TRANSPORT_ROCE; } } -- Gitee From d7c3dcba5aebe036878199e349fe7fd625098cda Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Wed, 3 Dec 2025 11:37:27 +0800 Subject: [PATCH 67/74] Rdma bug fix --- include/internal/device/sync/shmemi_device_p2p.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/internal/device/sync/shmemi_device_p2p.h b/include/internal/device/sync/shmemi_device_p2p.h index 8039afc3..ebc629c7 100644 --- a/include/internal/device/sync/shmemi_device_p2p.h +++ b/include/internal/device/sync/shmemi_device_p2p.h @@ -36,7 +36,7 @@ SHMEM_DEVICE void shmemi_highlevel_signal_set(__gm__ int32_t *dst, __gm__ int32_ ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE); ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - shmemi_roce_write((__gm__ uint8_t*)shmem_ptr(dst, pe), (__gm__ uint8_t*)src, pe, 0, sizeof(int32_t), + shmemi_roce_write((__gm__ uint8_t*)shmem_roce_ptr(dst, pe), (__gm__ uint8_t*)src, pe, 0, sizeof(int32_t), ub_tensor_64, ub_tensor_32); shmemi_roce_quiet(pe, 0, ub_tensor_64, ub_tensor_32); } -- Gitee From 13c3f6e59ad10a4d7b5bd0f2e74b33b4f8bd1d35 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Wed, 3 Dec 2025 15:18:03 +0800 Subject: [PATCH 68/74] ut support cross-machine --- scripts/run.sh | 6 ++++-- src/host/init/shmem_init.cpp | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/run.sh b/scripts/run.sh index 1216d002..8372b582 100644 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -23,6 +23,7 @@ rm -rf "$COVERAGE_PATH" set -e RANK_SIZE="8" IPPORT="tcp://127.0.0.1:8666" +SESSION_ID="127.0.0.1:8766" GNPU_NUM="8" FIRST_NPU="0" FIRST_RANK="0" @@ -67,7 +68,9 @@ while [[ $# -gt 0 ]]; do -ipport) if [ -n "$2" ]; then if [[ "$2" =~ ^[a-zA-z0-9.:/_-]+$ ]]; then - IPPORT="$2" + IPPORT="tcp://${2}" + SESSION_ID="${2}" + export SHMEM_UID_SESSION_ID=$SESSION_ID shift 2 else echo "Error: Invalid -ipport format, only alphanumeric and :/_- allowed" @@ -129,7 +132,6 @@ export SMEM_CONF_STORE_TLS_ENABLE=0 export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/install/memfabric_hybrid/lib:${ASCEND_HOME_PATH}/lib64:$LD_LIBRARY_PATH # Run unit test -export SHMEM_UID_SESSION_ID=127.0.0.1:8899 cd "$BUILD_PATH" ./bin/shmem_unittest "$RANK_SIZE" "$IPPORT" "$GNPU_NUM" "$FIRST_RANK" "$FIRST_NPU" --gtest_output=xml:test_detail.xml --gtest_filter=${TEST_FILTER} diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index a62f81ba..c28ec6ec 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -358,6 +358,7 @@ int32_t shmem_get_uniqueid_default(shmemx_uniqueid_t *uid) int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid){ shmem_set_log_level(shm::ERROR_LEVEL); + *uid = SHMEM_UNIQUEID_INITIALIZER; #ifdef BACKEND_MF SHMEM_CHECK_RET(shmem_get_uniqueid_mf(uid), "shmem_get_uniqueid failed, backend: mf"); return SHMEM_SUCCESS; @@ -369,6 +370,7 @@ int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid){ int32_t shmemi_get_uniqueid_static_magic(shmemx_uniqueid_t *uid, bool is_root) { shmem_set_log_level(shm::ERROR_LEVEL); + *uid = SHMEM_UNIQUEID_INITIALIZER; int status = 0; SHMEM_CHECK_RET(shmemi_options_init(), "Bootstrap failed during the preloading step."); SHMEM_CHECK_RET(shmemi_bootstrap_pre_init(SHMEMX_INIT_WITH_UNIQUEID, &g_boot_handle), "Get uniqueid failed during the bootstrap preloading step."); -- Gitee From 9a33512c60eec774ef07b1dbbf2e09bd2a729529 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Wed, 3 Dec 2025 16:33:13 +0800 Subject: [PATCH 69/74] Fix 910B CrossNode Bug && Improve CrossNode Judgement --- include/internal/host_device/shmemi_types.h | 1 - src/host/common/shmemi_host_types.h | 7 +++--- src/host/init/shmem_init.cpp | 26 --------------------- src/host/transport/shmemi_transport.cpp | 15 +++++++++--- src/modules/transport/shmemi_mte.cpp | 25 +++++++++++++++++--- src/modules/transport/shmemi_rdma.cpp | 4 ++-- 6 files changed, 40 insertions(+), 38 deletions(-) diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index 2a5f60bb..b5c01427 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -100,7 +100,6 @@ typedef struct { uint64_t sync_counter; uint64_t core_sync_pool; uint64_t core_sync_counter; - uint64_t host_hash; bool is_shmem_initialized; bool is_shmem_created; diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index 98073b0d..31897db7 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -53,13 +53,14 @@ typedef struct shmemi_bootstrap_uid_options { typedef struct shmemi_transport_pe_info { int32_t pe; int32_t dev_id; - uint64_t host_hash; + int64_t server_id; + int64_t superpod_id; } shmemi_transport_pe_info_t; typedef struct shmemi_transport { // control plane - int (*can_access_peer)(int *access, shmemi_transport_pe_info_t *peer, - struct shmemi_transport *t, shmemi_device_host_state_t *g_state); + int (*can_access_peer)(int *access, shmemi_transport_pe_info_t *peer_info, + shmemi_transport_pe_info_t *my_info, struct shmemi_transport *t); int (*connect_peers)(struct shmemi_transport *t, int *selected_dev_ids, int num_selected_devs, shmemi_device_host_state_t *g_state); int (*finalize)(struct shmemi_transport *t, diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index c28ec6ec..77a65fe0 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -52,7 +52,6 @@ constexpr int DEFAULT_BLOCK_NUM = 1; 0, /* sync_counter */ \ 0, /* core_sync_pool */ \ 0, /* core_sync_counter */ \ - 0, /* host_hash */ \ false, /* shmem_is_shmem_initialized */\ false, /* shmem_is_shmem_created */ \ {0, 16 * 1024, 0}, /* shmem_mte_config */ \ @@ -78,37 +77,12 @@ int32_t shmemi_options_init() return status; } -uint64_t shmemi_get_host_hash() -{ - char hostname[128]; - struct hostent *he; - - if (gethostname(hostname, sizeof(hostname)) != 0) { - perror("gethostname"); - return 0; - } - - if ((he = gethostbyname(hostname)) == NULL) { - perror("gethostbyname"); - return 0; - } - - // Host IP Address - for (int i = 0; he->h_addr_list[i] != NULL; i++) { - char *ip = inet_ntoa(*(struct in_addr*)he->h_addr_list[i]); - } - - std::size_t host_hash = std::hash{}(hostname); - return host_hash; -} - int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) { int32_t status = SHMEM_SUCCESS; g_state.mype = attributes->my_rank; g_state.npes = attributes->n_ranks; g_state.heap_size = attributes->local_mem_size + SHMEM_EXTRA_SIZE; - g_state.host_hash = shmemi_get_host_hash(); aclrtStream stream = nullptr; SHMEM_CHECK_RET(aclrtCreateStream(&stream)); diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index 2192a249..cb81c09f 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -7,6 +7,7 @@ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ +#include "mem/shmemi_heap.h" #include "shmemi_host_common.h" #include "dlfcn.h" @@ -42,12 +43,19 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t *g_state, shmem_init_op // Package my_info int32_t device_id; + int64_t server_id; + int64_t superpod_id; SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); + const int infoTypeServerId = 27; + SHMEM_CHECK_RET(rtGetDeviceInfo(device_id, 0, infoTypeServerId, &server_id)); + const int infoTypeSuperpodId = 29; + SHMEM_CHECK_RET(rtGetDeviceInfo(device_id, 0, infoTypeSuperpodId, &superpod_id)); shmemi_transport_pe_info_t my_info; my_info.pe = g_state->mype; my_info.dev_id = device_id; - my_info.host_hash = g_state->host_hash; + my_info.server_id = server_id; + my_info.superpod_id = superpod_id; // AllGather All pe's host info SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); @@ -99,7 +107,7 @@ int32_t shmemi_build_transport_map(shmemi_device_host_state_t *g_state) { for (int i = 0; i < g_state->npes; i++) { int reach = 0; - SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, &t, g_state)); + SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, g_host_state.pe_info + g_state->mype, &t)); if (reach) { int m = 1 << j; @@ -134,8 +142,9 @@ int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t *g_state) /* Check if MTE connected. */ if (g_host_state.transport_map[local_offset + i] & 0x1) { shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); + shmemi_transport_pe_info_t *my_info = (g_host_state.pe_info + g_state->mype); // Only PEs in the same Node need to build up MTE connection. - if (g_state->host_hash == peer_info->host_hash) { + if (my_info->server_id == peer_info->server_id) { mte_peer_list[mte_peer_num] = peer_info->dev_id; ++mte_peer_num; } diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index c6e21136..6bb0d76e 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -21,15 +21,22 @@ extern "C" { #endif -int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport *t, shmemi_device_host_state_t *g_state) { +int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport_pe_info_t *my_info, shmemi_transport *t) { // origin access set to 0. *access = 0; auto sName = aclrtGetSocName(); std::string socName{sName}; if (socName.find("Ascend910B") != std::string::npos) { // Ascend910B Topo + // Check Server ID. + if (my_info->server_id != peer_info->server_id) { + *access = 0; + return 0; + } + + // In same node, Check HCCS Connectivity. int64_t hccs_connected = -1; - SHMEM_CHECK_RET(rtGetPairDevicesInfo(g_state->mype, peer_info->dev_id, 0, &hccs_connected)); + SHMEM_CHECK_RET(rtGetPairDevicesInfo(my_info->pe, peer_info->dev_id, 0, &hccs_connected)); // In 910B, Flag 0 -> HCCS. const static int SELF_FLAG = 0; @@ -37,9 +44,21 @@ int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_inf *access = 1; } } else if (socName.find("Ascend910_93") != std::string::npos) { // Ascend910_93 Topo + // Check SuperPod ID. + if (my_info->superpod_id != peer_info->superpod_id) { + *access = 0; + return 0; + } + + if (my_info->superpod_id == 0xFFFFFFFFU || peer_info->superpod_id == 0xFFFFFFFFU) { // Invalid + *access = 0; + return 0; + } + + // In same node, Check HCCS Connectivity. int64_t hccs_connected = -1; /* TODO: This func now doesn't support 910_93 multiNode HCCS Check. Only Check in the same Node. */ - SHMEM_CHECK_RET(rtGetPairDevicesInfo(g_state->mype, peer_info->dev_id, 0, &hccs_connected)); + SHMEM_CHECK_RET(rtGetPairDevicesInfo(my_info->pe, peer_info->dev_id, 0, &hccs_connected)); // In 910_93, Flag 0 -> SELF, 5 -> SIO, 6 -> HCCS. const static int SELF_FLAG = 0; diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp index 4e43b2a6..2792506e 100644 --- a/src/modules/transport/shmemi_rdma.cpp +++ b/src/modules/transport/shmemi_rdma.cpp @@ -22,8 +22,8 @@ extern "C" { #endif static rdma_manager* manager; -int shmemi_rdma_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport *t, shmemi_device_host_state_t *state) { - if (peer_info->pe == state->mype) { +int shmemi_rdma_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport_pe_info_t *my_info, shmemi_transport *t) { + if (peer_info->pe == my_info->pe) { *access = 0; } else { *access = 1; -- Gitee From 81fbde27bd84a642c1b9ff2bccd3815240911887 Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Wed, 3 Dec 2025 19:46:11 +0800 Subject: [PATCH 70/74] Enable 910B CrossNode case --- src/host/transport/shmemi_transport.cpp | 16 ++++++++++++++++ src/modules/transport/shmemi_mte.cpp | 11 ----------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp index cb81c09f..0271e9c0 100644 --- a/src/host/transport/shmemi_transport.cpp +++ b/src/host/transport/shmemi_transport.cpp @@ -7,6 +7,7 @@ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ +#include #include "mem/shmemi_heap.h" #include "shmemi_host_common.h" #include "dlfcn.h" @@ -56,6 +57,21 @@ int32_t shmemi_transport_init(shmemi_device_host_state_t *g_state, shmem_init_op my_info.dev_id = device_id; my_info.server_id = server_id; my_info.superpod_id = superpod_id; + + // server_id invalid + if (server_id == 0x3FFU) { + static uint32_t bootIdHead; + static std::string sysBootId; + + std::string bootIdPath("/proc/sys/kernel/random/boot_id"); + std::ifstream input(bootIdPath); + input >> sysBootId; + + std::stringstream ss(sysBootId); + ss >> std::hex >> bootIdHead; + + my_info.server_id = bootIdHead; + } // AllGather All pe's host info SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp index 6bb0d76e..9e952b6c 100644 --- a/src/modules/transport/shmemi_mte.cpp +++ b/src/modules/transport/shmemi_mte.cpp @@ -44,17 +44,6 @@ int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_inf *access = 1; } } else if (socName.find("Ascend910_93") != std::string::npos) { // Ascend910_93 Topo - // Check SuperPod ID. - if (my_info->superpod_id != peer_info->superpod_id) { - *access = 0; - return 0; - } - - if (my_info->superpod_id == 0xFFFFFFFFU || peer_info->superpod_id == 0xFFFFFFFFU) { // Invalid - *access = 0; - return 0; - } - // In same node, Check HCCS Connectivity. int64_t hccs_connected = -1; /* TODO: This func now doesn't support 910_93 multiNode HCCS Check. Only Check in the same Node. */ -- Gitee From 68cbe24b90cf4e4b0252b25d5526ba0e684110f6 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Thu, 4 Dec 2025 17:26:15 +0800 Subject: [PATCH 71/74] use master_addr for uid --- .../bootstrap/shmemi_bootstrap_uid.cpp | 215 +++++++++++++----- 1 file changed, 155 insertions(+), 60 deletions(-) diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp index 833f907a..319c5d75 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -47,6 +47,26 @@ bool is_ipv6_loopback(const struct in6_addr *addr6) { return memcmp(addr6, &loopback6, sizeof(struct in6_addr)) == 0; } +// fe80 check +bool is_ipv6_link_local(const struct in6_addr *addr6) { + if (addr6 == nullptr) { + return false; + } + + const uint8_t* bytes = addr6->s6_addr; + + if (bytes[0] != 0xfe) { + return false; + } + + if ((bytes[1] & 0xc0) != 0x80) { + return false; + } + + SHM_LOG_DEBUG("It is fe80 address."); + return true; +} + bool is_ipv4_loopback(const struct in_addr *addr4) { return ((ntohl(addr4->s_addr) >> 24) & 0xFF) == IN_LOOPBACKNET; } @@ -95,10 +115,16 @@ int32_t shmemi_traverse_ifa( const char **prefixes, bool exclude, shmemx_bootstrap_uid_state_t *uid_args, - bool skipStateCheck = false + bool skipStateCheck = false, + bool allow_local = false ) { for (struct ifaddrs *ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { if (ifa->ifa_addr == nullptr) continue; + const char* ifname = ifa->ifa_name; + if (!allow_local && (strstr(ifname, "lo") != nullptr || strstr(ifname, "docker") != nullptr)) { + SHM_LOG_DEBUG("Skip interface: " << ifname << " (lo/docker, allow_local=false)"); + continue; + } bool match = false; const char **p = prefixes; @@ -131,6 +157,28 @@ int32_t shmemi_traverse_ifa( continue; } } + bool is_invalid_addr = false; + if (!allow_local) { + if (ifa->ifa_addr->sa_family == AF_INET) { + struct sockaddr_in *addr4 = (struct sockaddr_in *)ifa->ifa_addr; + if (is_ipv4_loopback(&addr4->sin_addr)) { + is_invalid_addr = true; + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)ifa->ifa_addr; + if (is_ipv6_link_local(&addr6->sin6_addr)) { + SHM_LOG_INFO("Blocked fe80 address."); + continue; + } + if (is_ipv6_loopback(&addr6->sin6_addr)) { + is_invalid_addr = true; + } + } + } + if (is_invalid_addr) { + SHM_LOG_DEBUG("Skip invalid address (lo/fe80, allow_local=false) on interface: " << ifname); + continue; + } if (ifa->ifa_addr->sa_family == AF_INET && (sockType == AF_UNSPEC || sockType == AF_INET)) { memset(&uid_args->addr.addr.addr4, 0, sizeof(struct sockaddr_in)); @@ -156,66 +204,31 @@ int32_t shmemi_traverse_ifa( } return SHMEM_INVALID_PARAM; } -int32_t shmemi_get_ip_from_ifa(shmemx_bootstrap_uid_state_t *uid_args, const char *ipInfo) { - if (uid_args == nullptr) { - SHM_LOG_ERROR("uid_args is nullptr"); - return SHMEM_INVALID_PARAM; - } - struct ifaddrs *ifaddr = nullptr; - char ifaPrefix[MAX_IFCONFIG_LENGTH] = {0}; - bool flag = false; - sa_family_t sockType = AF_INET; - bool foundValidIp = false; - - shmemi_get_uid_magic(uid_args); - - bool isIpInfoConfigured = (ipInfo != nullptr && strlen(ipInfo) > 0); - if (isIpInfoConfigured) { - int32_t ret = shmemi_uid_parse_interface_with_type(ipInfo, ifaPrefix, sockType, flag); - if (ret != SHMEM_SUCCESS) { - SHM_LOG_ERROR("Parse ipInfo failed, ret: " << ret); - return ret; - } +char* format_master_addr(const char* raw_addr) { + if (raw_addr == nullptr || strlen(raw_addr) == 0) { + return nullptr; } - - if (getifaddrs(&ifaddr) == -1) { - SHM_LOG_ERROR("getifaddrs failed: " << strerror(errno)); - return SHMEM_INVALID_PARAM; + bool is_ipv6 = (strchr(raw_addr, ':') != nullptr); + int raw_len = strlen(raw_addr); + int alloc_len = is_ipv6 ? (raw_len + 5) : (raw_len + 3); + + char* formatted_addr = (char*)malloc(alloc_len); + if (formatted_addr == nullptr) { + SHM_LOG_ERROR("malloc formatted addr failed, len=" << alloc_len); + return nullptr; } - if (isIpInfoConfigured) { - const char *specifiedPrefixes[] = {ifaPrefix, nullptr}; - SHM_LOG_INFO("Search interface with specified prefix: " << ifaPrefix); - foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, specifiedPrefixes, false, uid_args) == SHMEM_SUCCESS); + if (is_ipv6) { + strcpy(formatted_addr, "["); + strcat(formatted_addr, raw_addr); + strcat(formatted_addr, "]:0"); } else { - const char *excludePrefixes[] = {"docker", "lo", nullptr}; - const char *dockerPrefixes[] = {"docker", nullptr}; - const char *loPrefixes[] = {"lo", nullptr}; - - SHM_LOG_INFO("Step 1: Search interfaces exclude 'docker' and 'lo'"); - foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, excludePrefixes, true, uid_args) == SHMEM_SUCCESS); - - if (!foundValidIp) { - SHM_LOG_WARN("Step 2: Search interfaces match 'docker'"); - foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, dockerPrefixes, false, uid_args) == SHMEM_SUCCESS); - } - - if (!foundValidIp) { - SHM_LOG_WARN("Step 3: Search interfaces match 'lo' (skip state check)"); - foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, loPrefixes, false, uid_args, true) == SHMEM_SUCCESS); - } + strcpy(formatted_addr, raw_addr); + strcat(formatted_addr, ":0"); } - if (!foundValidIp) { - SHM_LOG_ERROR("Failed to get any valid IP address from interfaces"); - freeifaddrs(ifaddr); - return SHMEM_INVALID_PARAM; - } - - freeifaddrs(ifaddr); - SHM_LOG_INFO("Assign IP/Port from interface success"); - return SHMEM_SUCCESS; + return formatted_addr; } int32_t shmemi_get_ip_from_env(shmemx_bootstrap_uid_state_t *uid_args, const char *ipPort) { @@ -296,6 +309,85 @@ int32_t shmemi_get_ip_from_env(shmemx_bootstrap_uid_state_t *uid_args, const cha return SHMEM_SUCCESS; } +int32_t shmemi_get_ip_from_ifa(shmemx_bootstrap_uid_state_t *uid_args, const char *ipInfo) { + if (uid_args == nullptr) { + SHM_LOG_ERROR("uid_args is nullptr"); + return SHMEM_INVALID_PARAM; + } + + struct ifaddrs *ifaddr = nullptr; + char ifaPrefix[MAX_IFCONFIG_LENGTH] = {0}; + bool flag = false; + sa_family_t sockType = AF_INET; + bool foundValidIp = false; + + shmemi_get_uid_magic(uid_args); + + bool isIpInfoConfigured = (ipInfo != nullptr && strlen(ipInfo) > 0); + if (isIpInfoConfigured) { + int32_t ret = shmemi_uid_parse_interface_with_type(ipInfo, ifaPrefix, sockType, flag); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("Parse ipInfo failed, ret: " << ret); + return ret; + } + } + bool allow_local = isIpInfoConfigured; + + if (getifaddrs(&ifaddr) == -1) { + SHM_LOG_ERROR("getifaddrs failed: " << strerror(errno)); + return SHMEM_INVALID_PARAM; + } + + if (isIpInfoConfigured) { + const char *specifiedPrefixes[] = {ifaPrefix, nullptr}; + SHM_LOG_INFO("Search interface with specified prefix: " << ifaPrefix); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, specifiedPrefixes, false, uid_args, false, allow_local) == SHMEM_SUCCESS); + } else { + const char *ethPrefixes[] = {"eth", nullptr}; + const char *excludePrefixes[] = {"docker", "lo", nullptr}; + + SHM_LOG_INFO("Step 1: Search interfaces match 'eth'"); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, ethPrefixes, false, uid_args, false, allow_local) == SHMEM_SUCCESS); + + + char* temp_ip_port = nullptr; + if (!foundValidIp) { + SHM_LOG_WARN("Step 2: Failed to get any valid IP address from eth interfaces, trying to use torch's MASTER_ADDR."); + const char* env_master_addr = std::getenv("MASTER_ADDR"); + if (env_master_addr != nullptr && strlen(env_master_addr) != 0) { + temp_ip_port = format_master_addr(env_master_addr); + if (temp_ip_port == nullptr) { + SHM_LOG_WARN("format_master_addr failed"); + } else { + SHM_LOG_INFO("MASTER_ADDR is: " << temp_ip_port); + foundValidIp = (shmemi_get_ip_from_env(uid_args, temp_ip_port) == SHMEM_SUCCESS); + free((void*)temp_ip_port); + temp_ip_port = nullptr; + } + } else { + SHM_LOG_WARN("MASTER_ADDR is not set."); + } + } + + if (!foundValidIp) { + SHM_LOG_WARN("Step 3: Search interfaces exclude 'docker' and 'lo/fe80'"); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, excludePrefixes, true, uid_args, false, allow_local) == SHMEM_SUCCESS); + } + } + + if (ifaddr != nullptr) { + freeifaddrs(ifaddr); + ifaddr = nullptr; + } + if (!foundValidIp) { + SHM_LOG_ERROR("No valid IP address found from any interface!"); + return SHMEM_INVALID_PARAM; + } + + SHM_LOG_INFO("Assign IP/Port from interface success"); + return SHMEM_SUCCESS; +} + int32_t shmemi_set_ip_info(void *uid, sa_family_t &sockType, char *pta_env_ip, uint16_t pta_env_port, bool is_from_ifa) { @@ -446,7 +538,7 @@ static bool is_loopback_addr(const sockaddr_t* addr) { if (addr->type == ADDR_IPv4) { return is_ipv4_loopback(&addr->addr.addr4.sin_addr); } else if (addr->type == ADDR_IPv6) { - return is_ipv6_loopback(&addr->addr.addr6.sin6_addr) != 0; + return is_ipv6_loopback(&addr->addr.addr6.sin6_addr); } else { return false; } @@ -462,7 +554,7 @@ static bool matchSubnet(struct ifaddrs local_if, sockaddr_t* remote) { } else { return false; } - is_loopback_addr(remote); + SHM_LOG_DEBUG("local_if family: " << local_if.ifa_addr->sa_family << " remote family: " << family); if (family != local_if.ifa_addr->sa_family) { SHM_LOG_DEBUG(" matchSubnet family unmatch."); @@ -531,6 +623,11 @@ static int find_interface_match_subnet(char* ifNames, sockaddr_t* localAddrs, so } bool remote_is_loopback = is_loopback_addr(remoteAddr); + if (remoteAddr->type == ADDR_IPv6){ + SHMEM_CHECK_RET(is_ipv6_link_local(&remoteAddr->addr.addr6.sin6_addr), "Remote address is fe80", SHMEM_BOOTSTRAP_ERROR); + } + SHM_LOG_INFO("Remote address is loopback:" << remote_is_loopback); + if (remote_is_loopback) { SHM_LOG_DEBUG("Remote address is loopback, check lo interface first"); for (interface = interfaces; interface && !found; interface = interface->ifa_next) { @@ -669,8 +766,6 @@ static int shmemi_bootstrap_net_init(shmemx_bootstrap_uid_state_t* uid_args, boo return SHMEM_BOOTSTRAP_ERROR; } - // Find the local interface that matches the remote address - SHM_LOG_INFO("Trying to find interface matching root address."); int find_result = find_interface_match_subnet(priv_info.bootstrap_netifname, &priv_info.bootstrap_netifaddr, &uid_args->addr); @@ -1046,7 +1141,7 @@ int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) char ip_str[INET_ADDRSTRLEN] = {0}; SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_addr_listen ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); uint16_t port = ntohs(ipv4->sin_port); - SHM_LOG_INFO(" Ext_addr_listen socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); + SHM_LOG_INFO(" Ext_addr_listen socket: Type: IPv4, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); } else if (info.ext_addr_listen.type == ADDR_IPv6) { struct sockaddr_in6* ipv6 = &info.ext_addr_listen.addr.addr6; @@ -1064,7 +1159,7 @@ int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) char ip_str[INET_ADDRSTRLEN] = {0}; SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_address_listen_root ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); uint16_t port = ntohs(ipv4->sin_port); - SHM_LOG_INFO(" Ext_address_listen_root socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); + SHM_LOG_INFO(" Ext_address_listen_root socket: Type: IPv4, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); } else if (info.ext_address_listen_root.type == ADDR_IPv6) { struct sockaddr_in6* ipv6 = &info.ext_address_listen_root.addr.addr6; -- Gitee From 87f0f311edeff6aa4148f68df6df52354a8f62ce Mon Sep 17 00:00:00 2001 From: zhu-wangyi Date: Thu, 4 Dec 2025 19:56:13 +0800 Subject: [PATCH 72/74] Suppport pyshmem --- include/host/shmem_host_init.h | 31 ++++++------------- .../init/init_backends/mf/shmemi_init_mf.cpp | 5 +++ .../init/init_backends/mf/shmemi_init_mf.h | 1 + src/host/init/shmem_init.cpp | 8 +++++ src/host/python_wrapper/pyshmem.cpp | 17 +++++----- src/python/test.py | 16 ++++++++-- 6 files changed, 46 insertions(+), 32 deletions(-) diff --git a/include/host/shmem_host_init.h b/include/host/shmem_host_init.h index 651c0e37..acdebfb7 100644 --- a/include/host/shmem_host_init.h +++ b/include/host/shmem_host_init.h @@ -64,25 +64,11 @@ SHMEM_HOST_API int shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, */ SHMEM_HOST_API int shmem_set_timeout(shmem_init_attr_t *attributes, uint32_t value); -// /** -// * @brief get the unique id and return it by intput argument uid. This function need run with PTA. -// * -// * @param uid [out] a ptr to uid generate by shmem -// * @return Returns 0 on success or an error code on failure -// */ -// SHMEM_HOST_API int shmem_get_uniqueid(shmem_uniqueid_t *uid); - -// /** -// * @brief init process with unique id. This function need run with PTA. -// * -// * @param rank_id [in] current rank id -// * @param nranks [in] total ranks -// * @param uid [in] a ptr to uid, generated by shmem_get_uniqueid -// * @param attr [out] a ptr to shmem_init_attr_t -// * @return Returns 0 on success or an error code on failure -// */ -// SHMEM_HOST_API int shmem_set_attr_uniqueid_args(int rank_id, int nranks, const shmem_uniqueid_t *uid, shmem_init_attr_t *attr); +SHMEM_HOST_API int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid); +SHMEM_HOST_API int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const int64_t local_mem_size, + const shmemx_uniqueid_t *uid, + shmem_init_attr_t **shmem_attr); /** * @brief Initialize the resources required for SHMEM task based on attributes. * Attributes can be created by users or obtained by calling shmem_set_attr(). @@ -126,11 +112,12 @@ SHMEM_HOST_API void shmem_info_get_name(char *name); */ SHMEM_HOST_API int32_t shmem_set_log_level(int level); -SHMEM_HOST_API int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const int64_t local_mem_size, - const shmemx_uniqueid_t *uid, - shmem_init_attr_t **shmem_attr); +SHMEM_HOST_API int32_t shmem_set_config_store_tls_key(const char *tls_pk, const uint32_t tls_pk_len, + const char *tls_pk_pw, const uint32_t tls_pk_pw_len, const shmem_decrypt_handler decrypt_handler); -SHMEM_HOST_API int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid); +SHMEM_HOST_API int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)); + +SHMEM_HOST_API void shmem_global_exit(int status); SHMEM_HOST_API int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len); diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp index b8a9a3ba..3652ebfb 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.cpp +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -582,4 +582,9 @@ int32_t shmemi_control_barrier_all_mf() return SHMEM_SUCCESS; } +void shmemi_global_exit_mf(int status) +{ + smem_shm_global_exit(g_smem_handle, status); +} + #endif \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h index 830ea943..f0999771 100644 --- a/src/host/init/init_backends/mf/shmemi_init_mf.h +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -49,5 +49,6 @@ private: int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid); int32_t shmemi_control_barrier_all_mf(); +void shmemi_global_exit_mf(int status); #endif // SHMEMI_INIT_MF_H \ No newline at end of file diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index 77a65fe0..0fd8969a 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -418,4 +418,12 @@ int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)) #else return SHMEM_SUCCESS; #endif +} + +void shmem_global_exit(int status) +{ +#ifdef BACKEND_MF + shmemi_global_exit_mf(status); +#else +#endif } \ No newline at end of file diff --git a/src/host/python_wrapper/pyshmem.cpp b/src/host/python_wrapper/pyshmem.cpp index 80130610..baa807d3 100644 --- a/src/host/python_wrapper/pyshmem.cpp +++ b/src/host/python_wrapper/pyshmem.cpp @@ -56,7 +56,7 @@ inline std::string get_connect_url() int shmem_initialize(shmem_init_attr_t &attributes) { - auto ret = shmem_init_attr(&attributes); + auto ret = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, &attributes); if (ret != 0) { std::cerr << "initialize shmem failed, ret: " << ret; return ret; @@ -67,7 +67,7 @@ int shmem_initialize(shmem_init_attr_t &attributes) py::bytes shmem_get_unique_id() { - shmem_uniqueid_t uid; + shmemx_uniqueid_t uid; auto ret = shmem_get_uniqueid(&uid); if (ret != 0) { std::cerr << "get unique id failed " << ret << std::endl; @@ -77,22 +77,22 @@ py::bytes shmem_get_unique_id() int shmem_initialize_unique_id(int rank, int world_size, int64_t mem_size, const std::string &bytes) { - if (bytes.size() < sizeof(shmem_uniqueid_t)) { + if (bytes.size() < sizeof(shmemx_uniqueid_t)) { std::cerr << "Error: Input bytes size (" << bytes.size() - << ") is smaller than required size (" << sizeof(shmem_uniqueid_t) + << ") is smaller than required size (" << sizeof(shmemx_uniqueid_t) << ")." << std::endl; return -1; } - shmem_uniqueid_t uid; + shmemx_uniqueid_t uid; memcpy(&uid, bytes.data(), sizeof(uid)); shmem_init_attr_t *attr; - auto ret = shmem_set_attr(rank, world_size, mem_size, nullptr, &attr); + auto ret = shmemx_set_attr_uniqueid_args(rank, world_size, mem_size, &uid, &attr); if (ret != 0) { std::cerr << "set attr failed " << ret << std::endl; return ret; } - return shmem_set_attr_uniqueid_args(rank, world_size, &uid, attr); + return shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attr); } int32_t shmem_set_op_engine_type(shmem_init_attr_t &attributes, data_op_engine_type_t value) @@ -202,7 +202,8 @@ void DefineShmemAttr(py::module_ &m) .def_readwrite("data_op_engine_type", &shmem_init_optional_attr_t::data_op_engine_type) .def_readwrite("shm_init_timeout", &shmem_init_optional_attr_t::shm_init_timeout) .def_readwrite("shm_create_timeout", &shmem_init_optional_attr_t::shm_create_timeout) - .def_readwrite("control_operation_timeout", &shmem_init_optional_attr_t::control_operation_timeout); + .def_readwrite("control_operation_timeout", &shmem_init_optional_attr_t::control_operation_timeout) + .def_readwrite("sockFd", &shmem_init_optional_attr_t::sockFd); py::class_(m, "InitAttr") .def(py::init([]() { diff --git a/src/python/test.py b/src/python/test.py index e7307a93..60d8e1ca 100644 --- a/src/python/test.py +++ b/src/python/test.py @@ -121,6 +121,17 @@ def run_tests(): attributes.n_ranks = world_size attributes.local_mem_size = g_ash_size attributes.ip_port = G_IP_PORT + + optional_attr = ash.OptionalAttr() + optional_attr.version = 1 + optional_attr.data_op_engine_type = ash.OpEngineType.MTE + optional_attr.shm_init_timeout = 120 # second + optional_attr.shm_create_timeout = 120 # second + optional_attr.control_operation_timeout = 120 # second + optional_attr.sockFd = 0 + + attributes.option_attr = optional_attr + ret = ash.shmem_init(attributes) if ret != 0: raise ValueError('[ERROR] shmem_init failed') @@ -168,6 +179,7 @@ if __name__ == "__main__": torch.npu.set_device(local_rank) dist.init_process_group(backend="hccl", rank=local_rank) run_tests() - run_register_decrypt_tests() - exit_test() + # Not supported + # run_register_decrypt_tests() + # exit_test() print("test.py running success!") -- Gitee From 1fbf07ab46c426d964daf7c027341d131ed684e2 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Fri, 5 Dec 2025 17:02:36 +0800 Subject: [PATCH 73/74] uid init by ipport --- include/host/shmem_host_def.h | 1 + src/host/bootstrap/shmemi_bootstrap.cpp | 39 ++++- src/host/common/shmemi_host_types.h | 4 +- src/host/init/shmem_init.cpp | 8 +- .../bootstrap/shmemi_bootstrap_uid.cpp | 152 ++++++++++-------- tests/unittest/host/init/init_host_test.cpp | 84 ++-------- tests/unittest/host/main_test.cpp | 38 +---- .../host/mem/shmem_host_heap_test.cpp | 19 +-- 8 files changed, 155 insertions(+), 190 deletions(-) diff --git a/include/host/shmem_host_def.h b/include/host/shmem_host_def.h index 045367b3..eae88743 100644 --- a/include/host/shmem_host_def.h +++ b/include/host/shmem_host_def.h @@ -89,6 +89,7 @@ enum shmem_error_code_t : int { enum shmemx_bootstrap_t : int { SHMEMX_INIT_WITH_UNIQUEID = 1, SHMEMX_INIT_WITH_MPI = 1 << 1, + SHMEMX_INIT_WITH_DEFAULT = 1 << 2, }; /** diff --git a/src/host/bootstrap/shmemi_bootstrap.cpp b/src/host/bootstrap/shmemi_bootstrap.cpp index 9f2dc10b..6e5ed246 100644 --- a/src/host/bootstrap/shmemi_bootstrap.cpp +++ b/src/host/bootstrap/shmemi_bootstrap.cpp @@ -97,13 +97,49 @@ int32_t shmemi_bootstrap_pre_init(int flags, shmemi_bootstrap_handle_t *handle) return status; } +void remove_tcp_prefix_and_copy(const char* input, char* output, size_t output_len) { + memset(output, 0, output_len); + if (output_len == 0) return; + + if (input == nullptr || strlen(input) == 0) { + return; + } + + const char* prefix_tcp = "tcp://"; + const char* prefix_tcp6 = "tcp6://"; + size_t len_tcp = strlen(prefix_tcp); + size_t len_tcp6 = strlen(prefix_tcp6); + const char* result_ptr = input; + + if (strncmp(input, prefix_tcp, len_tcp) == 0) { + result_ptr = input + len_tcp; + } + else if (strncmp(input, prefix_tcp6, len_tcp6) == 0) { + result_ptr = input + len_tcp6; + } + + strncpy(output, result_ptr, output_len - 1); + output[output_len - 1] = '\0'; +} + int32_t shmemi_bootstrap_init(int flags, shmem_init_attr_t *attr) { int32_t status = SHMEM_SUCCESS; void *arg; - if (flags & SHMEMX_INIT_WITH_MPI) { + g_boot_handle.use_attr_ipport= false; + if (flags & SHMEMX_INIT_WITH_DEFAULT){ + SHM_LOG_INFO("SHMEMX_INIT_WITH_DEFAULT"); + g_boot_handle.use_attr_ipport= true; + remove_tcp_prefix_and_copy(attr->ip_port, + g_boot_handle.ipport, + sizeof(g_boot_handle.ipport)); + plugin_name = BOOTSTRAP_MODULE_UID; + arg = (attr != NULL) ? attr->comm_args : NULL; + } else if (flags & SHMEMX_INIT_WITH_MPI) { + SHM_LOG_INFO("SHMEMX_INIT_WITH_MPI"); plugin_name = BOOTSTRAP_MODULE_MPI; arg = (attr != NULL) ? attr->comm_args : NULL; } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { + SHM_LOG_INFO("SHMEMX_INIT_WITH_UNIQUEID"); plugin_name = BOOTSTRAP_MODULE_UID; arg = (attr != NULL) ? attr->comm_args : NULL; } else { @@ -125,6 +161,7 @@ int32_t shmemi_bootstrap_init(int flags, shmem_init_attr_t *attr) { shmemi_bootstrap_free(); return SHMEM_INNER_ERROR; } + SHM_LOG_INFO("plugin_init"); status = plugin_init(arg, &g_boot_handle); if (status != 0) { SHM_LOG_ERROR("Bootstrap plugin init failed for " << plugin_name); diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h index 31897db7..45ef8aeb 100644 --- a/src/host/common/shmemi_host_types.h +++ b/src/host/common/shmemi_host_types.h @@ -13,7 +13,7 @@ #define SHMEM_MAX_TRANSPORT_NUM 16 #include "internal/host_device/shmemi_types.h" - +#define SHMEM_MAX_HANDLE_IP_PORT_LEN 64 typedef struct shmemi_bootstrap_attr { shmemi_bootstrap_attr() : initialize_mf(0), mpi_comm(NULL), uid_args(NULL) {} @@ -40,6 +40,8 @@ typedef struct shmemi_bootstrap_handle { void (*global_exit)(int status); shmemi_bootstrap_init_ops_t *pre_init_ops; bool is_bootstraped = false; + char ipport[SHMEM_MAX_HANDLE_IP_PORT_LEN]; + bool use_attr_ipport = false; } shmemi_bootstrap_handle_t; typedef struct shmemi_bootstrap_mpi_options { diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index 9bd4511c..34bff733 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -61,6 +61,7 @@ constexpr int DEFAULT_BLOCK_NUM = 1; shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; shmemi_host_state_t g_state_host = {nullptr, DEFAULT_TEVENT, DEFAULT_BLOCK_NUM}; shmem_init_attr_t g_attr; +shmemx_uniqueid_t default_flag_uid; static bool g_attr_init = false; static char g_ipport[SHMEM_MAX_IP_PORT_LEN] = {0}; @@ -172,6 +173,10 @@ int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size g_attr.local_mem_size = local_mem_size; g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, DEFAULT_TIMEOUT, DEFAULT_TIMEOUT, DEFAULT_TIMEOUT}; + g_attr.comm_args = reinterpret_cast(&default_flag_uid); + shmemx_bootstrap_uid_state_t *uid_args = (shmemx_bootstrap_uid_state_t *)(g_attr.comm_args); + uid_args->rank = my_rank; + uid_args->nranks = n_ranks; g_attr_init = true; return SHMEM_SUCCESS; } @@ -248,6 +253,7 @@ int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *a // shmem basic init #ifdef BACKEND_MF SHM_LOG_INFO("The current backend is MF."); + SHMEM_CHECK_RET(bootstrap_flags != SHMEMX_INIT_WITH_DEFAULT, "The current backend is MF, and the value of bootstrap_flags only supports SHMEMX_INIT_WITH_DEFAULT.", SHMEM_INVALID_PARAM); init_manager = new shmemi_init_mf(attributes, g_ipport, &g_state); #else SHM_LOG_INFO("The current backend is SHMEM default."); @@ -356,9 +362,9 @@ int32_t shmemi_get_uniqueid_static_magic(shmemx_uniqueid_t *uid, bool is_root) { status = SHMEM_INVALID_PARAM; } return SHMEM_SUCCESS; - } + int32_t shmem_set_log_level(int level) { // use env first, input level secondly, user may change level from env instead call func diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp index 319c5d75..37895cb1 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -63,7 +63,7 @@ bool is_ipv6_link_local(const struct in6_addr *addr6) { return false; } - SHM_LOG_DEBUG("It is fe80 address."); + SHM_LOG_DEBUG("It is IPv6 link-local address (fe80::/10)."); return true; } @@ -71,6 +71,47 @@ bool is_ipv4_loopback(const struct in_addr *addr4) { return ((ntohl(addr4->s_addr) >> 24) & 0xFF) == IN_LOOPBACKNET; } +bool is_ipv4_link_local(const struct in_addr *addr4) { + if (addr4 == nullptr) { + return false; + } + + uint32_t ip_addr = ntohl(addr4->s_addr); + + uint8_t byte1 = (ip_addr >> 24) & 0xff; + uint8_t byte2 = (ip_addr >> 16) & 0xff; + if (byte1 != 169 || byte2 != 254) { + return false; + } + SHM_LOG_DEBUG("It is IPv4 link-local address (169.254.x.x)."); + return true; +} + +static bool is_loopback_addr(const sockaddr_t* addr) { + if (addr == nullptr) { + return false; + } + if (addr->type == ADDR_IPv4) { + return is_ipv4_loopback(&addr->addr.addr4.sin_addr); + } else if (addr->type == ADDR_IPv6) { + return is_ipv6_loopback(&addr->addr.addr6.sin6_addr); + } else { + return false; + } +} + +static bool is_link_local_addr(const sockaddr_t* addr) { + if (addr == nullptr) { + return false; + } + if (addr->type == ADDR_IPv4) { + return is_ipv4_link_local(&addr->addr.addr4.sin_addr); + } else if (addr->type == ADDR_IPv6) { + return is_ipv6_link_local(&addr->addr.addr6.sin6_addr); + } else { + return false; + } +} static int32_t shmemi_get_uid_magic(shmemx_bootstrap_uid_state_t *innerUId) { std::ifstream urandom("/dev/urandom", std::ios::binary); @@ -161,13 +202,17 @@ int32_t shmemi_traverse_ifa( if (!allow_local) { if (ifa->ifa_addr->sa_family == AF_INET) { struct sockaddr_in *addr4 = (struct sockaddr_in *)ifa->ifa_addr; + if (is_ipv4_link_local(&addr4->sin_addr)) { + SHM_LOG_INFO("Blocked ipv4 link local address."); + continue; + } if (is_ipv4_loopback(&addr4->sin_addr)) { is_invalid_addr = true; } } else if (ifa->ifa_addr->sa_family == AF_INET6) { struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)ifa->ifa_addr; if (is_ipv6_link_local(&addr6->sin6_addr)) { - SHM_LOG_INFO("Blocked fe80 address."); + SHM_LOG_INFO("Blocked ipv6 link local address."); continue; } if (is_ipv6_loopback(&addr6->sin6_addr)) { @@ -205,32 +250,6 @@ int32_t shmemi_traverse_ifa( return SHMEM_INVALID_PARAM; } -char* format_master_addr(const char* raw_addr) { - if (raw_addr == nullptr || strlen(raw_addr) == 0) { - return nullptr; - } - bool is_ipv6 = (strchr(raw_addr, ':') != nullptr); - int raw_len = strlen(raw_addr); - int alloc_len = is_ipv6 ? (raw_len + 5) : (raw_len + 3); - - char* formatted_addr = (char*)malloc(alloc_len); - if (formatted_addr == nullptr) { - SHM_LOG_ERROR("malloc formatted addr failed, len=" << alloc_len); - return nullptr; - } - - if (is_ipv6) { - strcpy(formatted_addr, "["); - strcat(formatted_addr, raw_addr); - strcat(formatted_addr, "]:0"); - } else { - strcpy(formatted_addr, raw_addr); - strcat(formatted_addr, ":0"); - } - - return formatted_addr; -} - int32_t shmemi_get_ip_from_env(shmemx_bootstrap_uid_state_t *uid_args, const char *ipPort) { if (uid_args == nullptr || ipPort == nullptr || strlen(ipPort) == 0) { SHM_LOG_ERROR("Invalid param: uid_args is null or ipPort is empty"); @@ -348,26 +367,6 @@ int32_t shmemi_get_ip_from_ifa(shmemx_bootstrap_uid_state_t *uid_args, const cha SHM_LOG_INFO("Step 1: Search interfaces match 'eth'"); foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, ethPrefixes, false, uid_args, false, allow_local) == SHMEM_SUCCESS); - - - char* temp_ip_port = nullptr; - if (!foundValidIp) { - SHM_LOG_WARN("Step 2: Failed to get any valid IP address from eth interfaces, trying to use torch's MASTER_ADDR."); - const char* env_master_addr = std::getenv("MASTER_ADDR"); - if (env_master_addr != nullptr && strlen(env_master_addr) != 0) { - temp_ip_port = format_master_addr(env_master_addr); - if (temp_ip_port == nullptr) { - SHM_LOG_WARN("format_master_addr failed"); - } else { - SHM_LOG_INFO("MASTER_ADDR is: " << temp_ip_port); - foundValidIp = (shmemi_get_ip_from_env(uid_args, temp_ip_port) == SHMEM_SUCCESS); - free((void*)temp_ip_port); - temp_ip_port = nullptr; - } - } else { - SHM_LOG_WARN("MASTER_ADDR is not set."); - } - } if (!foundValidIp) { SHM_LOG_WARN("Step 3: Search interfaces exclude 'docker' and 'lo/fe80'"); @@ -531,19 +530,6 @@ static void shmemi_bootstrap_uid_global_exit(int status) { } -static bool is_loopback_addr(const sockaddr_t* addr) { - if (addr == nullptr) { - return false; - } - if (addr->type == ADDR_IPv4) { - return is_ipv4_loopback(&addr->addr.addr4.sin_addr); - } else if (addr->type == ADDR_IPv6) { - return is_ipv6_loopback(&addr->addr.addr6.sin6_addr); - } else { - return false; - } -} - static bool matchSubnet(struct ifaddrs local_if, sockaddr_t* remote) { int family; bool is_lo_interface = (strncmp(local_if.ifa_name, "lo", 2) == 0); @@ -621,11 +607,8 @@ static int find_interface_match_subnet(char* ifNames, sockaddr_t* localAddrs, so SHM_LOG_ERROR(" remoteAddr is NULL."); return SHMEM_BOOTSTRAP_ERROR; } - + SHMEM_CHECK_RET(is_link_local_addr(remoteAddr), "Remote address is link_local", SHMEM_BOOTSTRAP_ERROR); bool remote_is_loopback = is_loopback_addr(remoteAddr); - if (remoteAddr->type == ADDR_IPv6){ - SHMEM_CHECK_RET(is_ipv6_link_local(&remoteAddr->addr.addr6.sin6_addr), "Remote address is fe80", SHMEM_BOOTSTRAP_ERROR); - } SHM_LOG_INFO("Remote address is loopback:" << remote_is_loopback); if (remote_is_loopback) { @@ -1059,6 +1042,37 @@ int shmemi_bootstrap_get_unique_id_static_magic(void* uid, bool is_root) { return SHMEM_SUCCESS; } +int shmemi_bootstrap_get_unique_id_by_ipport(void* uid, const char *ipport) { + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)uid; + + if (ipport != nullptr) { + env_ip_port = ipport; + SHM_LOG_DEBUG("The ipport param is: " << env_ip_port); + } else { + + SHM_LOG_DEBUG("The ipport param is not set. Try to use SHMEM_UID_SESSION_ID."); + const char* envip = std::getenv("SHMEM_UID_SESSION_ID"); + if (envip != nullptr) { + env_ip_port = envip; + SHM_LOG_DEBUG("SHMEM_UID_SESSION_ID is: " << env_ip_port); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SESSION_ID is not set."); + } + } + if (env_ip_port == nullptr) { + SHM_LOG_ERROR("Using method get_unique_id_by_ipport requires setting ipport or SHMEM_UID_SESSION_ID."); + return SHMEM_BOOTSTRAP_ERROR; + } + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args, false), "rank 0: failed to init bootstrap net."); + uid_args->magic = SOCKET_MAGIC + static_magic_count; + static_magic_count++; + if (uid_args->rank == 0) { + SHMEM_CHECK_RET(bootstrap_create_root(uid_args), "rank 0: failed to create root thread"); + } + return SHMEM_SUCCESS; +} + // Plugin pre-initialization entry function. int shmemi_bootstrap_plugin_pre_init(shmemi_bootstrap_handle_t* handle) { if (handle->pre_init_ops == nullptr) { @@ -1076,21 +1090,23 @@ int shmemi_bootstrap_plugin_pre_init(shmemi_bootstrap_handle_t* handle) { int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) { - if (comm == nullptr || handle == nullptr) { SHM_LOG_ERROR(" shmemi_bootstrap_plugin_init: invalid arguments (nullptr)"); return SHMEM_BOOTSTRAP_ERROR; } - socket_t sock, listen_sock_root; uid_bootstrap_state* state = nullptr; SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&state, 1)); shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)comm; sockaddr_t next_addr; bootstrap_ext_info info = {}; - int rank = uid_args->rank; int nranks = uid_args->nranks; + + if (handle->use_attr_ipport && handle->ipport != nullptr) { + SHM_LOG_DEBUG("shmemi_bootstrap_get_unique_id_by_ipport start. ipport: " << handle->ipport); + shmemi_bootstrap_get_unique_id_by_ipport(comm, handle->ipport); + } uint64_t magic = uid_args->magic; SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args), " rank: " << rank << ": network interface init failed."); diff --git a/tests/unittest/host/init/init_host_test.cpp b/tests/unittest/host/init/init_host_test.cpp index cdb05fde..29ae01ad 100644 --- a/tests/unittest/host/init/init_host_test.cpp +++ b/tests/unittest/host/init/init_host_test.cpp @@ -30,24 +30,10 @@ void test_shmem_init(int rank_id, int n_ranks, uint64_t local_mem_size) EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); shmem_set_conf_store_tls(false, nullptr, 0); -#ifdef BACKEND_MF shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); -#else - shmemx_uniqueid_t uid; - if (rank_id == 0) { - status = shmemi_get_uniqueid_static_magic(&uid, true); - } else { - status = shmemi_get_uniqueid_static_magic(&uid, false); - } - EXPECT_EQ(status, SHMEM_SUCCESS); - shmem_init_attr_t* attributes; - status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, - &uid, - &attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); -#endif - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_SUCCESS); EXPECT_EQ(g_state.mype, rank_id); EXPECT_EQ(g_state.npes, n_ranks); @@ -74,22 +60,11 @@ void test_shmem_init_invalid_rank_id(int rank_id, int n_ranks, uint64_t local_me EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); shmem_set_conf_store_tls(false, nullptr, 0); -#ifdef BACKEND_MF + shmem_init_attr_t *attributes; shmem_set_attr(erank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); -#else - shmemx_uniqueid_t uid; - if (rank_id == 0) { - shmemi_get_uniqueid_static_magic(&uid, true); - } else { - shmemi_get_uniqueid_static_magic(&uid, false); - } - shmem_init_attr_t* attributes; - shmemx_set_attr_uniqueid_args(erank_id, n_ranks, local_mem_size, - &uid, - &attributes); -#endif - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); @@ -110,21 +85,11 @@ void test_shmem_init_invalid_n_ranks(int rank_id, int n_ranks, uint64_t local_me EXPECT_EQ(status = aclrtSetDevice(device_id), 0); shmem_set_conf_store_tls(false, nullptr, 0); shmemx_uniqueid_t uid; -#ifdef BACKEND_MF + shmem_init_attr_t *attributes; shmem_set_attr(rank_id, en_ranks, local_mem_size, test_global_ipport, &attributes); -#else - if (rank_id == 0) { - shmemi_get_uniqueid_static_magic(&uid, true); - } else { - shmemi_get_uniqueid_static_magic(&uid, false); - } - shmem_init_attr_t* attributes; - shmemx_set_attr_uniqueid_args(rank_id, en_ranks, local_mem_size, - &uid, - &attributes); -#endif - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); @@ -143,22 +108,11 @@ void test_shmem_init_rank_id_over_size(int rank_id, int n_ranks, uint64_t local_ EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); shmem_set_conf_store_tls(false, nullptr, 0); -#ifdef BACKEND_MF + shmem_init_attr_t *attributes; shmem_set_attr(rank_id + n_ranks, n_ranks, local_mem_size, test_global_ipport, &attributes); -#else - shmemx_uniqueid_t uid; - if (rank_id == 0) { - shmemi_get_uniqueid_static_magic(&uid, true); - } else { - shmemi_get_uniqueid_static_magic(&uid, false); - } - shmem_init_attr_t* attributes; - shmemx_set_attr_uniqueid_args(rank_id + n_ranks, n_ranks, local_mem_size, - &uid, - &attributes); -#endif - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_INVALID_PARAM); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -177,22 +131,10 @@ void test_shmem_init_zero_mem(int rank_id, int n_ranks, uint64_t local_mem_size) EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); shmem_set_conf_store_tls(false, nullptr, 0); -#ifdef BACKEND_MF shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); -#else - shmemx_uniqueid_t uid; - if (rank_id == 0) { - shmemi_get_uniqueid_static_magic(&uid, true); - } else { - shmemi_get_uniqueid_static_magic(&uid, false); - } - shmem_init_attr_t* attributes; - shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, - &uid, - &attributes); -#endif - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index e278ef63..a3eb98d9 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -36,24 +36,11 @@ void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *s aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); -#ifdef BACKEND_MF + shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); -#else - shmemx_uniqueid_t uid; - if (rank_id == 0) { - status = shmemi_get_uniqueid_static_magic(&uid, true); - } else { - status = shmemi_get_uniqueid_static_magic(&uid, false); - } - EXPECT_EQ(status, SHMEM_SUCCESS); - shmem_init_attr_t* attributes; - status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, - &uid, - &attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); -#endif - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, 0); *st = stream; @@ -74,25 +61,12 @@ int32_t test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtS aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); -#ifdef BACKEND_MF + shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); -#else - shmemx_uniqueid_t uid; - if (rank_id == 0) { - status = shmemi_get_uniqueid_static_magic(&uid, true); - } else { - status = shmemi_get_uniqueid_static_magic(&uid, false); - } - EXPECT_EQ(status, SHMEM_SUCCESS); - shmem_init_attr_t* attributes; - status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, - &uid, - &attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); -#endif + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, 0); *st = stream; diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index 2fb5d7b4..2a9a85f2 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -27,24 +27,11 @@ protected: int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - #ifdef BACKEND_MF + shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - #else - shmemx_uniqueid_t uid; - if (rank_id == 0) { - status = shmemi_get_uniqueid_static_magic(&uid, true); - } else { - status = shmemi_get_uniqueid_static_magic(&uid, false); - } - EXPECT_EQ(status, SHMEM_SUCCESS); - shmem_init_attr_t* attributes; - status = shmemx_set_attr_uniqueid_args(rank_id, n_ranks, local_mem_size, - &uid, - &attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); - #endif - status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_SUCCESS); EXPECT_EQ(g_state.mype, rank_id); EXPECT_EQ(g_state.npes, n_ranks); -- Gitee From f5bb0d540235ff50c9165b9f2183b82964c427b3 Mon Sep 17 00:00:00 2001 From: zhangyunqi05 Date: Mon, 8 Dec 2025 17:03:10 +0800 Subject: [PATCH 74/74] barrierv2 && ut link shmem.so && add log and close sock when failed --- include/host/shmem_host_def.h | 3 +- .../bootstrap/shmemi_bootstrap_uid.cpp | 169 ++++++++++++++++-- src/modules/bootstrap/socket/uid_socket.cpp | 28 ++- src/modules/bootstrap/socket/uid_socket.h | 11 ++ src/modules/bootstrap/socket/uid_utils.h | 22 ++- tests/unittest/CMakeLists.txt | 9 +- tests/unittest/host/init/init_host_test.cpp | 11 ++ .../host/mem/shmem_host_heap_test.cpp | 2 +- 8 files changed, 233 insertions(+), 22 deletions(-) diff --git a/include/host/shmem_host_def.h b/include/host/shmem_host_def.h index eae88743..a3a227f1 100644 --- a/include/host/shmem_host_def.h +++ b/include/host/shmem_host_def.h @@ -80,7 +80,8 @@ enum shmem_error_code_t : int { SHMEM_SMEM_ERROR = -3, ///< There is a problem with SMEM. SHMEM_INNER_ERROR = -4, ///< This is a problem caused by an internal error. SHMEM_NOT_INITED = -5, ///< This is a problem caused by an uninitialization. - SHMEM_BOOTSTRAP_ERROR = -6, ///< This is a problem caused by an uninitialization. + SHMEM_BOOTSTRAP_ERROR = -6,///< This is a problem with BOOTSTRAP. + SHMEM_TIMEOUT_ERROR = -7, ///< This is a problem caused by TIMEOUT. }; /** diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp index 37895cb1..bf496a12 100644 --- a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -437,7 +437,14 @@ static int shmemi_bootstrap_uid_finalize(shmemi_bootstrap_handle_t *handle) { if (handle->bootstrap_state) { uid_bootstrap_state* state = (uid_bootstrap_state*) handle->bootstrap_state; - + unexpected_conn_t* elem = state->unexpected_conns; + while (elem != NULL) { + unexpected_conn_t* next = elem->next; + socket_close(&elem->sock); // 关闭socket句柄 + SHMEM_BOOTSTRAP_PTR_FREE(elem); + elem = next; + } + state->unexpected_conns = NULL; socket_close(&state->listen_sock); socket_close(&state->ring_send_sock); socket_close(&state->ring_recv_sock); @@ -520,6 +527,142 @@ static int shmemi_bootstrap_uid_barrier(shmemi_bootstrap_handle_t *handle) { } return SHMEM_SUCCESS; } +static int unexpected_dequeue(uid_bootstrap_state* state, int peer, int tag, socket_t* sock, int* found) { + SHM_LOG_INFO("unexpected_dequeue start."); + if (state == NULL || sock == NULL || found == NULL) { + return SHMEM_BOOTSTRAP_ERROR; + } + + unexpected_conn_t* elem = state->unexpected_conns; + unexpected_conn_t* prev = NULL; + *found = 0; + while (elem != NULL) { + if (elem->peer == peer && elem->tag == tag) { + if (prev == NULL) { + state->unexpected_conns = elem->next; + } else { + prev->next = elem->next; + } + + memcpy(sock, &elem->sock, sizeof(socket_t)); + SHMEM_BOOTSTRAP_PTR_FREE(elem); + *found = 1; + return SHMEM_SUCCESS; + } + + prev = elem; + elem = elem->next; + } + return SHMEM_SUCCESS; +} + +static int unexpected_enqueue(uid_bootstrap_state* state, int peer, int tag, socket_t* sock) { + SHM_LOG_INFO("unexpected_enqueue start."); + if (state == NULL || sock == NULL) { + return SHMEM_BOOTSTRAP_ERROR; + } + + unexpected_conn_t* new_conn = NULL; + SHMEM_BOOTSTRAP_CALLOC(&new_conn, 1); + if (new_conn == NULL) { + return SHMEM_BOOTSTRAP_ERROR; + } + + new_conn->peer = peer; + new_conn->tag = tag; + memcpy(&new_conn->sock, sock, sizeof(socket_t)); + new_conn->next = NULL; + if (state->unexpected_conns == NULL) { + state->unexpected_conns = new_conn; + } else { + new_conn->next = state->unexpected_conns; + state->unexpected_conns = new_conn; + } + + return SHMEM_SUCCESS; +} +static int bootstrap_send(void* comm_state, int peer, int tag, void* data, int size) { + if (comm_state == nullptr || data == nullptr || size < 0 || peer < 0) { + SHM_LOG_ERROR("bootstrap_send: invalid arguments"); + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*)comm_state; + socket_t sock; + SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, state->magic, &state->peer_addrs[peer]), "bootstrap_send: socket_init failed for peer " << peer); + + SHMEM_CHECK_RET_CLOSE_SOCK(socket_connect(&sock), "bootstrap_send: socket_connect failed for peer " << peer, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &state->rank, sizeof(int)), "bootstrap_send: send rank failed to peer " << peer, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &tag, sizeof(int)), "bootstrap_send: send tag " << tag << " failed to peer " << peer, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, data, size), "bootstrap_send: send data (size=" << size << ") failed to peer " << peer, sock); + if (sock.fd >= 0) { + socket_close(&sock); + } + return SHMEM_SUCCESS; +} + + +static int bootstrap_recv(void* comm_state, int peer, int tag, void* data, int size) { + if (comm_state == NULL || data == NULL || size < 0 || peer < 0) { + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*)comm_state; + socket_t sock; + int found = 0; + int retry_count = 0; + int ret = SHMEM_SUCCESS; + SHMEM_CHECK_RET(unexpected_dequeue(state, peer, tag, &sock, &found)); + + if (found == 1) { + ret = socket_recv(&sock, data, size); + socket_close(&sock); + return (ret == SHMEM_SUCCESS) ? SHMEM_SUCCESS : SHMEM_BOOTSTRAP_ERROR; + } + while (1) { + socket_t new_sock; + int new_peer = -1; + int new_tag = -1; + SHMEM_CHECK_RET(socket_init(&new_sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, NULL), "socket_init new_sock failed"); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_accept(&new_sock, &state->listen_sock), "socket_accept new_sock failed", new_sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&new_sock, &new_peer, sizeof(int)), "socket_recv new_peer failed", new_sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&new_sock, &new_tag, sizeof(int)), "socket_recv new_tag failed", new_sock); + if (new_peer == peer && new_tag == tag) { + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&new_sock, data, size), "socket_recv failed", new_sock); + return SHMEM_SUCCESS; + } else { + SHMEM_CHECK_RET_CLOSE_SOCK(unexpected_enqueue(state, new_peer, new_tag, &new_sock), "unexpected_enqueue failed", new_sock); + } + } +} + +static int shmemi_bootstrap_uid_barrier_v2(shmemi_bootstrap_handle_t *handle) { + SHM_LOG_INFO("shmemi_bootstrap_uid_barrier_v2"); + uid_bootstrap_state* state = (uid_bootstrap_state*)(handle->bootstrap_state); + int rank = state->rank; + int tag = 0; + int nranks = state->nranks; + + if (nranks == 1) { + SHM_LOG_DEBUG("Single rank, skip barrier"); + return SHMEM_SUCCESS; + } + + SHM_LOG_DEBUG("Barrier start. rank: " << rank << " nranks: " << nranks <<" tag: "<< tag); + + int data[1]; + for (int mask = 1; mask < nranks; mask <<= 1) { + int src = (rank - mask + nranks) % nranks; + int dst = (rank + mask) % nranks; + tag++; + + SHMEM_CHECK_RET(bootstrap_send(state, dst, tag, data, sizeof(data)), "rank " << rank << ": barrier send failed, dst: " << dst << "tag: " << tag); + SHMEM_CHECK_RET(bootstrap_recv(state, src, tag, data, sizeof(data)), "rank " << rank << ": barrier recv failed, src: " << src << "tag: " << tag); + } + + SHM_LOG_DEBUG("Barrier end. rank: " << rank << " nranks: " << nranks <<" tag: "<< tag); + return SHMEM_SUCCESS; +} static int shmemi_bootstrap_uid_alltoall(const void *sendbuf, void *recvbuf, int length, shmemi_bootstrap_handle_t *handle) { @@ -957,7 +1100,7 @@ static int bootstrap_create_root(shmemx_bootstrap_uid_state_t* uid_args) { // 2. Initialize the listening socket (using the global network interface address) SHMEM_CHECK_RET(socket_init(listen_sock_root, SOCKET_TYPE_BOOTSTRAP, uid_args->magic, &uid_args->addr), "bootstrap_create_root: socket_init failed"); - SHMEM_CHECK_RET(socket_listen(listen_sock_root), "Listen_sock_root failed while executing listen. fd=" << listen_sock_root->fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_listen(listen_sock_root), "Listen_sock_root failed while executing listen. fd=" << listen_sock_root->fd, *listen_sock_root); // 3. Write the root node's listening address into uid_args (for slave nodes to connect to). memcpy(&uid_args->addr, &listen_sock_root->addr, sizeof(sockaddr_t)); @@ -1133,20 +1276,20 @@ int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) handle->npes = nranks; SHMEM_CHECK_RET(socket_init(&state->listen_sock, SOCKET_TYPE_BOOTSTRAP, state->magic, &priv_info.bootstrap_netifaddr), "State's listen_sock failed while executing init. fd=" << state->listen_sock.fd); - SHMEM_CHECK_RET(socket_listen(&state->listen_sock), "State's listen_sock failed while executing listen. fd=" << state->listen_sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_listen(&state->listen_sock), "State's listen_sock failed while executing listen. fd=" << state->listen_sock.fd, state->listen_sock); SHMEM_CHECK_RET(bootstrap_get_sock_addr(&state->listen_sock, &info.ext_addr_listen), "Get addr failed, the listen_sock in state maybe null. fd=" << state->listen_sock.fd); SHMEM_CHECK_RET(socket_init(&listen_sock_root, SOCKET_TYPE_BOOTSTRAP, state->magic, &priv_info.bootstrap_netifaddr), "Listen_sock_root failed while executing init. fd=" << listen_sock_root.fd); - SHMEM_CHECK_RET(socket_listen(&listen_sock_root), "listen_sock_root failed while executing listen. fd=" << listen_sock_root.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_listen(&listen_sock_root), "listen_sock_root failed while executing listen. fd=" << listen_sock_root.fd, listen_sock_root); SHMEM_CHECK_RET(bootstrap_get_sock_addr(&listen_sock_root, &info.ext_address_listen_root), "Get addr failed, the listen_sock_root maybe null. fd=" << listen_sock_root.fd); SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, magic, &uid_args->addr), "Sock failed while executing init. fd=" << sock.fd); - SHMEM_CHECK_RET(socket_connect(&sock), "Sock failed while executing connect. fd=" << sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_connect(&sock), "Sock failed while executing connect. fd=" << sock.fd, sock); int peer_version = uid_args->version; int root_version; - SHMEM_CHECK_RET(socket_send(&sock, &peer_version, sizeof(peer_version)), "Sock failed while executing send peer_version. fd=" << sock.fd); - SHMEM_CHECK_RET(socket_recv(&sock, &root_version, sizeof(root_version)), "Sock failed while executing recv root_version. fd=" << sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &peer_version, sizeof(peer_version)), "Sock failed while executing send peer_version. fd=" << sock.fd, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&sock, &root_version, sizeof(root_version)), "Sock failed while executing recv root_version. fd=" << sock.fd, sock); SHMEM_CHECK_RET(peer_version != root_version, " rank: " << rank << " . version mismatch with root", SHMEM_SMEM_ERROR); info.rank = rank; @@ -1190,13 +1333,13 @@ int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) } - SHMEM_CHECK_RET(socket_send(&sock, &info, sizeof(info)), "Sock failed while executing send info. fd=" << sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &info, sizeof(info)), "Sock failed while executing send info. fd=" << sock.fd, sock); SHMEM_CHECK_RET(socket_close(&sock), "Sock failed while executing close. fd=" << sock.fd); SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, nullptr), "Sock failed while executing init. fd=" << sock.fd); - SHMEM_CHECK_RET(socket_accept(&sock, &listen_sock_root), "Sock failed while executing accept listen_sock_root. fd=" << sock.fd); - SHMEM_CHECK_RET(socket_recv(&sock, &next_addr, sizeof(next_addr)), "Sock failed while executing recv next_addr. fd=" << sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_accept(&sock, &listen_sock_root), "Sock failed while executing accept listen_sock_root. fd=" << sock.fd, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&sock, &next_addr, sizeof(next_addr)), "Sock failed while executing recv next_addr. fd=" << sock.fd, sock); SHMEM_CHECK_RET(socket_close(&sock), "Sock failed while executing close. fd=" << sock.fd); SHMEM_CHECK_RET(socket_close(&listen_sock_root), "Listen_sock_root failed while executing close. fd=" << listen_sock_root.fd); @@ -1218,15 +1361,15 @@ int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) // Initialize ring send socket SHMEM_CHECK_RET(socket_init(&state->ring_send_sock, SOCKET_TYPE_BOOTSTRAP, magic, &next_addr), "State's ring_send_sock failed while executing init. fd=" << state->ring_send_sock.fd); - SHMEM_CHECK_RET(socket_connect(&state->ring_send_sock), "State's ring_send_sock failed while executing connect. fd=" << state->ring_send_sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_connect(&state->ring_send_sock), "State's ring_send_sock failed while executing connect. fd=" << state->ring_send_sock.fd, state->ring_send_sock); SHMEM_CHECK_RET(socket_init(&state->ring_recv_sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, nullptr), "State's ring_recv_sock failed while executing init. fd=" << state->ring_recv_sock.fd); - SHMEM_CHECK_RET(socket_accept(&state->ring_recv_sock, &state->listen_sock),"State's ring_recv_sock failed while executing accept State's listen_sock. fd=" << state->ring_recv_sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_accept(&state->ring_recv_sock, &state->listen_sock),"State's ring_recv_sock failed while executing accept State's listen_sock. fd=" << state->ring_recv_sock.fd, state->ring_recv_sock); SHMEM_CHECK_RET(bootstrap_get_sock_addr(&state->listen_sock, state->peer_addrs + handle->mype), "Get addr failed, the listen_sock in state maybe null. fd=" << state->listen_sock.fd); SHMEM_CHECK_RET(shmemi_bootstrap_uid_allgather(BOOTSTRAP_IN_PLACE, state->peer_addrs, sizeof(sockaddr_t), handle), "Bootstrap_uid_allgather failed"); handle->allgather = shmemi_bootstrap_uid_allgather; - handle->barrier = shmemi_bootstrap_uid_barrier; + handle->barrier = shmemi_bootstrap_uid_barrier_v2; handle->finalize = shmemi_bootstrap_uid_finalize; handle->alltoall = nullptr; handle->global_exit = nullptr; diff --git a/src/modules/bootstrap/socket/uid_socket.cpp b/src/modules/bootstrap/socket/uid_socket.cpp index 3103ee86..7d87bc10 100644 --- a/src/modules/bootstrap/socket/uid_socket.cpp +++ b/src/modules/bootstrap/socket/uid_socket.cpp @@ -15,6 +15,29 @@ #include #include "uid_socket.h" +static int socket_poll_fd(int fd, int events, int timeout_ms) { + struct pollfd pfd = {0}; + pfd.fd = fd; + pfd.events = events; + + int ret = poll(&pfd, 1, timeout_ms); + if (ret == -1) { + SHM_LOG_ERROR("poll failed: " << strerror(errno) << " (fd: " << fd << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } else if (ret == 0) { + SHM_LOG_ERROR("poll timeout (" << timeout_ms << "ms) - fd: " << fd); + return SHMEM_TIMEOUT_ERROR; + } + + // 检查fd错误 + if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) { + SHM_LOG_ERROR("fd error (revents: " << pfd.revents << ") - fd: " << fd); + return SHMEM_BOOTSTRAP_ERROR; + } + + return SHMEM_SUCCESS; +} + static int socket_progress(int op, socket_t* sock, void* ptr, int size, int* offset, bool block = false, bool state_check = true) { if (sock == nullptr || ptr == nullptr || offset == nullptr || size < 0 || *offset < 0 || *offset > size) { SHM_LOG_ERROR("Invalid arguments: sock=" << sock << ", ptr=" << ptr @@ -26,6 +49,9 @@ static int socket_progress(int op, socket_t* sock, void* ptr, int size, int* off sock->state = SOCKET_STATE_ERROR; return SHMEM_BOOTSTRAP_ERROR; } + int poll_events = (op == SOCKET_TYPE_RECV) ? POLLIN : POLLOUT; + SHMEM_CHECK_RET(socket_poll_fd(sock->fd, poll_events, SOCKET_RECV_TIMEOUT_MS), "socket_poll_fd failed."); + int bytes = 0; int closed = 0; char* data = (char*)(ptr); @@ -257,7 +283,7 @@ static int socket_try_accept(socket_t* sock) { SHM_LOG_ERROR("socket_try_accept: invalid state " << sock->state); return SHMEM_BOOTSTRAP_ERROR; } - + SHMEM_CHECK_RET(socket_poll_fd(sock->accept_fd, POLLIN, SOCKET_ACCEPT_TIMEOUT_MS), "socket_poll_fd failed."); struct sockaddr sa; socklen_t socklen = sizeof(sa); diff --git a/src/modules/bootstrap/socket/uid_socket.h b/src/modules/bootstrap/socket/uid_socket.h index ed0e4780..0402a888 100644 --- a/src/modules/bootstrap/socket/uid_socket.h +++ b/src/modules/bootstrap/socket/uid_socket.h @@ -28,6 +28,9 @@ extern "C" { #define RETRY_TIMEDOUT_TIMES 50 #define SLEEP_INT 1000 // 重试间隔(微秒) +#define SOCKET_ACCEPT_TIMEOUT_MS 50000 // accept超时50秒 +#define SOCKET_RECV_TIMEOUT_MS 30000 // recv超时30秒 + #define SOCKET_BACKLOG 16384 typedef enum { @@ -88,6 +91,13 @@ struct bootstrap_netstate { pthread_t bootstrap_root; /* Socket Root Thread for phoning root to non-root peers */ }; +typedef struct unexpected_conn { + int peer; // 发送方rank + int tag; // 消息tag + socket_t sock; // 对应的socket连接 + struct unexpected_conn* next; // 链表下一个节点 +} unexpected_conn_t; + typedef struct { int rank; int nranks; @@ -96,6 +106,7 @@ typedef struct { socket_t ring_send_sock; socket_t ring_recv_sock; sockaddr_t* peer_addrs; + unexpected_conn_t* unexpected_conns; // 意外连接队列 } uid_bootstrap_state; int socket_init(socket_t* sock, socket_type_t type, uint64_t magic, const sockaddr_t* init_addr); diff --git a/src/modules/bootstrap/socket/uid_utils.h b/src/modules/bootstrap/socket/uid_utils.h index 5159b9bc..65ad8ca7 100644 --- a/src/modules/bootstrap/socket/uid_utils.h +++ b/src/modules/bootstrap/socket/uid_utils.h @@ -44,11 +44,23 @@ inline int bootstrap_calloc(T** ptr, size_t nelem, const char* file, int line) { bootstrap_calloc((ptr), (nelem), __FILE__, __LINE__) -#define SHMEM_BOOTSTRAP_PTR_FREE(ptr) \ - do { \ - if ((ptr) != NULL) { \ - free(ptr); \ - } \ +#define SHMEM_BOOTSTRAP_PTR_FREE(ptr) \ + do { \ + if ((ptr) != NULL) { \ + free(ptr); \ + } \ + } while (0) + +#define SHMEM_CHECK_RET_CLOSE_SOCK(x, LOG_STR, SOCK) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << LOG_STR << " close sock " << #SOCK << " and return shmem error: " << check_ret); \ + if ((&(SOCK)) != nullptr) { \ + socket_close(&(SOCK)); \ + } \ + return check_ret; \ + } \ } while (0) #endif //SHMEM_UID_UTILS_H \ No newline at end of file diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index 60bbec87..e5b00e9e 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -19,4 +19,11 @@ target_link_directories(shmem_unittest PRIVATE ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ) -target_link_libraries(shmem_unittest PRIVATE shmem_unittest_device gtest gcov mf_smem shmem_unittest_include shmem) +target_link_libraries( + shmem_unittest PRIVATE + -Wl,--no-as-needed + shmem + -Wl,--as-needed +) + +target_link_libraries(shmem_unittest PRIVATE shmem_unittest_device gtest gcov mf_smem shmem_unittest_include) diff --git a/tests/unittest/host/init/init_host_test.cpp b/tests/unittest/host/init/init_host_test.cpp index 29ae01ad..78dc9c21 100644 --- a/tests/unittest/host/init/init_host_test.cpp +++ b/tests/unittest/host/init/init_host_test.cpp @@ -221,6 +221,12 @@ TEST(TestInitAPI, TestShmemSetLogLevel) auto ret = shmem_set_log_level(shm::DEBUG_LEVEL); EXPECT_EQ(ret, 0); + char* original_log_level = NULL; + const char* env_val = getenv("SHMEM_LOG_LEVEL"); + if (env_val != NULL) { + original_log_level = strdup(env_val); + } + setenv("SHMEM_LOG_LEVEL", "DEBUG", 1); EXPECT_EQ(shmem_set_log_level(-1), 0); @@ -237,4 +243,9 @@ TEST(TestInitAPI, TestShmemSetLogLevel) EXPECT_EQ(shmem_set_log_level(-1), 0); unsetenv("SHMEM_LOG_LEVEL"); + if (original_log_level != NULL) { + setenv("SHMEM_LOG_LEVEL", original_log_level, 1); + free(original_log_level); + original_log_level = NULL; + } } \ No newline at end of file diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index 2a9a85f2..7ac741e0 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -320,7 +320,7 @@ TEST_F(ShareMemoryManagerTest, stress_malloc_calloc_align_no_leak) aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); - constexpr int rounds = 500; + constexpr int rounds = 100; std::vector ptrs; ptrs.reserve(rounds * 3); -- Gitee