diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp
index 5c9694640e3ebc9d8c24bb195f4e348716e52ec5..a8be82173fba21a9fd4694e6a81fe54acc254f50 100644
--- a/examples/rdma_perftest/main.cpp
+++ b/examples/rdma_perftest/main.cpp
@@ -26,11 +26,13 @@ const char *ipport;
 int f_rank = 0;
 int f_npu = 0;
 const char *test_type;
+int num_qps;
 
 extern void rdma_highlevel_put_pingpong_latency_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length);
 extern void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length);
 extern void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length);
 extern void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length, int64_t iter);
+extern void rdma_lowlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t cfg, uint8_t* gva, int len, int64_t iter);
 
 int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length)
 {
@@ -123,7 +125,7 @@ int test_shmem_rdma_postsend_cost(int rank_id, int n_ranks, uint64_t local_mem_s
     return 0;
 }
 
-int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length)
+int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length, int num_qps, int mqp_sttgy)
 {
     int32_t device_id = rank_id % g_npus + f_npu;
     int status = 0;
@@ -135,6 +137,8 @@ int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_me
 
     shmem_init_attr_t *attributes;
     status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes);
+    attributes->option_attr.num_of_qps_per_conn = num_qps;
+    attributes->option_attr.mqp_sttgy = mqp_sttgy;
     attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE;
     status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes);
 
@@ -150,11 +154,17 @@ int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_me
     }
     aclrtMemcpy(gva + rank_id * message_length, message_length, xHost, message_length, ACL_MEMCPY_HOST_TO_DEVICE);
 
-    rdma_highlevel_put_bw_do(1, stream, fftsConfig, gva, message_length);
+    if (mqp_sttgy == SHMEM_MQP_STRATEGY_RR) {
+        rdma_highlevel_put_bw_do(1, stream, fftsConfig, gva, message_length);
+    } else if (mqp_sttgy == SHMEM_MQP_STRATEGY_MC) {
+        rdma_highlevel_put_bw_do(num_qps, stream, fftsConfig, gva, message_length);
+    } else {
+        return -1;
+    }
     aclrtSynchronizeStream(stream);
     if (rank_id == 0) {
         aclrtMemcpy(xHost, sizeof(int64_t), gva + message_length * n_ranks, sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST);
-        std::cout << "RDMA high level put bandwidth test. Message length = " << message_length << " Byte; time = " << xHost[0] / (50.0) << " us." << std::endl;
+        std::cout << "RDMA high level put bandwidth test. Message length = " << message_length << " Byte; BW = " << message_length * 10 / (xHost[0] / (50.0)) << " GB/s." << std::endl;
     }
 
     aclrtFreeHost(xHost);
@@ -223,9 +233,65 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size
     return 0;
 }
 
+int test_shmem_rdma_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length, int num_qps)
+{
+    int32_t device_id = rank_id % g_npus + f_npu;
+    int status = 0;
+    aclrtStream stream = nullptr;
+
+    status = aclInit(nullptr);
+    status = aclrtSetDevice(device_id);
+    status = aclrtCreateStream(&stream);
+
+    shmem_init_attr_t *attributes;
+    status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes);
+    attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE;
+    attributes->option_attr.num_of_qps_per_conn = num_qps;
+    status = shmem_init_attr(SHMEMX_INIT_WITH_MPI, attributes);
+    shmem_mte_set_ub_params(0, 128 * 1024, 0);
+
+    uint64_t fftsConfig = shmemx_get_ffts_config();
+    uint8_t* gva = (uint8_t*)shmem_malloc(1024 * 1024 * 32);
+    int64_t *inHost;
+    int64_t *outHost;
+    size_t totalSize = message_length * n_ranks * 3;
+    aclrtMallocHost((void **)(&inHost), totalSize);
+    aclrtMallocHost((void **)(&outHost), totalSize);
+    memset(inHost, 0, totalSize);
+    double rdmaTotalTime = 0.0;
+
+    for (int iter = 0; iter < 20; iter++) {
+        for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) {
+            inHost[i + rank_id * message_length / sizeof(int64_t)] = rank_id + 10 + iter;
+        }
+        for (uint32_t i = 0; i < message_length / sizeof(int64_t); i++) {
+            inHost[i + (rank_id + n_ranks) * message_length / sizeof(int64_t)] = rank_id + 10 + iter;
+        }
+        aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE);
+        shmemi_control_barrier_all();
+        rdma_lowlevel_put_bw_do(num_qps, stream, fftsConfig, gva, message_length, iter);
+        aclrtSynchronizeStream(stream);
+        if (rank_id == 0 && iter >= 10) {
+            aclrtMemcpy(outHost, 64, gva + message_length * n_ranks * 2, 64, ACL_MEMCPY_DEVICE_TO_HOST);
+            rdmaTotalTime += outHost[0] / 50.0;
+        }
+    }
+    if (rank_id == 0) {
+        std::cout << "RDMA test. Message length = " << message_length << " Byte; average RDMA BW = " << (message_length / rdmaTotalTime) * 100 << " GB/s." << std::endl;
+    }
+
+    aclrtFreeHost(inHost);
+    aclrtFreeHost(outHost);
+    shmem_finalize();
+    aclrtDestroyStream(stream);
+    aclrtResetDevice(device_id);
+    aclFinalize();
+    return 0;
+}
+
 int main(int argc, char *argv[])
 {
-    if (argc != 7) {
+    if (argc < 7) {
         std::cout << "[ERROR] Parameter number mismatch." << std::endl;
         std::cout << "[USAGE] ./rdma_perftest <args...>. See README for more details." << std::endl;
     }
@@ -251,15 +317,25 @@ int main(int argc, char *argv[])
     f_npu = atoi(argv[4]);
     test_type = argv[5];
     int msg_len = atoi(argv[6]);
+    int num_qps = 1;
+    if (argc >= 8) {
+        num_qps = atoi(argv[7]);
+    }
 
     uint64_t local_mem_size = 1024UL * 1024UL * 64;
     if (std::string(test_type) == "highlevel_put_pingpong_latency") {
         test_shmem_rdma_highlevel_put_pingpong_latency(rank_id, n_ranks, local_mem_size, msg_len);
     } else if (std::string(test_type) == "postsend_cost") {
         test_shmem_rdma_postsend_cost(rank_id, n_ranks, local_mem_size, msg_len);
     } else if (std::string(test_type) == "highlevel_put_bw") {
-        test_shmem_rdma_highlevel_put_bw(rank_id, n_ranks, local_mem_size, msg_len);
+        int mqp_sttgy = SHMEM_MQP_STRATEGY_RR;
+        if (argc == 9) {
+            mqp_sttgy = atoi(argv[8]);
+        }
+        test_shmem_rdma_highlevel_put_bw(rank_id, n_ranks, local_mem_size, msg_len, num_qps, mqp_sttgy);
     } else if (std::string(test_type) == "rdma_mte_bw") {
         test_shmem_rdma_mte_put_bw(rank_id, n_ranks, local_mem_size, msg_len);
+    } else if (std::string(test_type) == "rdma_lowlevel_bw") {
+        test_shmem_rdma_put_bw(rank_id, n_ranks, local_mem_size, msg_len, num_qps);
     }
 
     std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl;
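For host code that wants the same knobs outside these tests, the two new optional attributes are plain fields on shmem_init_optional_attr_t and must be set between shmem_set_attr() and shmem_init_attr(), exactly as the tests above do. A minimal sketch (error handling elided; the field names and strategy constants are the ones introduced by this patch):

    // Sketch: request 4 QPs per RoCE connection, spread across cores (MC strategy).
    shmem_init_attr_t *attr = nullptr;
    shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attr);
    attr->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE;
    attr->option_attr.num_of_qps_per_conn = 4;  // clamped and rounded to a power of 2 by DeviceQpManager
    attr->option_attr.mqp_sttgy = SHMEM_MQP_STRATEGY_MC;
    shmem_init_attr(SHMEMX_INIT_WITH_MPI, attr);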
diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp
index c3eaab964386c1b998f213c4f949c1822bd5109d..a0cde87549245715df144ecc893b5081a4090a8a 100644
--- a/examples/rdma_perftest/rdma_perftest_kernel.cpp
+++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp
@@ -111,11 +111,8 @@ void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fftsConfig
     rdma_postsend_cost<<<1, nullptr, stream>>>(fftsConfig, gva, message_length);
 }
 
-extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length) {
+extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length, int num_qps) {
     shmemx_set_ffts_config(fftsConfig);
-    if (AscendC::GetSubBlockIdx() != 0) {
-        return;
-    }
     AscendC::TPipe pipe;
     AscendC::TBuf<AscendC::TPosition::VECCALC> buf;
     pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2);
@@ -125,16 +122,21 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig,
     int64_t rank = shmem_my_pe();
     int64_t rank_size = shmem_n_pes();
     uint32_t peer;
+    int core_id = AscendC::GetBlockIdx();
 
     // Actual test
     GM_ADDR src_addr = gva + rank * message_length;
     if (rank == 0) {
         peer = 1;
         int64_t start = AscendC::GetSystemCycle();
-        for (int i = 0; i < 10000; i++) {
+        for (int i = 0; i < 10000 / num_qps; i++) {
             shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer);
         }
-        shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32);
+        shmemi_roce_quiet(peer, core_id, ubLocal64, ubLocal32);
+        shmemi_barrier_core();
+        if (core_id > 0) {
+            return;
+        }
         shmem_put_uint8_mem_nbi(gva + rank_size * message_length + 8, src_addr, sizeof(uint32_t), peer);
         while (*(__gm__ uint32_t*)(gva + message_length * rank_size + 16) != peer + MAGIC_VAL) {
             dcci_cachelines(gva + message_length * rank_size + 16, 8);
@@ -144,6 +146,9 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig,
         int64_t end = AscendC::GetSystemCycle();
         *(__gm__ int64_t*)(gva + message_length * rank_size) = end - start;
     } else {
+        if (core_id > 0) {
+            return;
+        }
         peer = 0;
         while (*(__gm__ uint32_t*)(gva + rank_size * message_length + 8) != peer + MAGIC_VAL) {
             dcci_cachelines(gva + rank_size * message_length + 8, 8);
@@ -155,7 +160,8 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig,
 }
 
 void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) {
-    rdma_highlevel_put_bw<<<1, nullptr, stream>>>(fftsConfig, gva, message_length);
+    // block_dim maps 1:1 to num_qps
+    rdma_highlevel_put_bw<<<block_dim, nullptr, stream>>>(fftsConfig, gva, message_length, block_dim);
 }
 
 extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length, int64_t iter) {
@@ -234,6 +240,63 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_AD
     }
 }
 
+extern "C" __global__ __aicore__ void rdma_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length, int64_t iter, int num_qps) {
+    shmemx_set_ffts_config(fftsConfig);
+    AscendC::LocalTensor<uint32_t> ubLocal32;
+    ubLocal32.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECOUT);
+    ubLocal32.address_.bufferAddr = reinterpret_cast<uint64_t>(SHMEM_INTERNAL_UB_BUF_START_ADDR);
+    ubLocal32.address_.dataLen = UB_ALIGN_SIZE;
+    AscendC::LocalTensor<uint64_t> ubLocal64;
+    ubLocal64.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECOUT);
+    ubLocal64.address_.bufferAddr = reinterpret_cast<uint64_t>(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE);
+    ubLocal64.address_.dataLen = UB_ALIGN_SIZE;
+
+    __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state();
+    int64_t rank = device_state->mype;
+    int64_t rank_size = device_state->npes;
+    uint32_t peer;
+    int core_id = AscendC::GetBlockIdx();
+
+    GM_ADDR src_addr = gva + rank * message_length;
+    if (rank == 0) {
+        peer = 1;
+        int64_t start = AscendC::GetSystemCycle();
+        for (int i = 0; i < 10000 / num_qps; i++) {
+            shmemi_roce_write((GM_ADDR)shmem_roce_ptr(src_addr, peer), src_addr, peer, core_id, message_length, ubLocal64, ubLocal32);
+        }
+        shmemi_roce_quiet(peer, core_id, ubLocal64, ubLocal32);
+        shmemi_barrier_core();
+        if (core_id > 0) {
+            return;
+        }
+        shmemi_roce_write((GM_ADDR)shmem_roce_ptr(gva + rank_size * message_length * 2 + 8, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32);
+        while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL + iter) {
+            dcci_cachelines(gva + message_length * rank_size * 2 + 16, 8);
+            AscendC::GetSystemCycle();
+        }
+        AscendC::PipeBarrier<PIPE_ALL>();
+        int64_t end = AscendC::GetSystemCycle();
+        *(__gm__ int64_t*)(gva + message_length * rank_size * 2) = end - start;
+    } else {
+        if (core_id > 0) {
+            return;
+        }
+        peer = 0;
+        while (*(__gm__ int64_t*)(gva + rank_size * message_length * 2 + 8) != peer + MAGIC_VAL + iter) {
+            dcci_cachelines(gva + rank_size * message_length * 2 + 8, 8);
+            AscendC::GetSystemCycle();
+        }
+        AscendC::PipeBarrier<PIPE_ALL>();
+        shmemi_roce_write((GM_ADDR)shmem_roce_ptr(gva + rank_size * message_length * 2 + 16, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32);
+    }
+}
+
 void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length, int64_t iter) {
     rdma_mte_put_bw<<<2, nullptr, stream>>>(fftsConfig, gva, message_length, iter);
+}
+
+void rdma_lowlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int len, int64_t iter)
+{
+    // block_dim maps 1:1 to num_qps
+    rdma_put_bw<<<block_dim, nullptr, stream>>>(fftsConfig, gva, len, iter, block_dim);
 }
\ No newline at end of file
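Both bandwidth kernels time the transfer on-device with AscendC::GetSystemCycle() and leave the cycle delta in global memory for the host to convert. A sketch of that conversion, assuming (as the replaced "time = xHost[0] / 50.0 us" print implies) that the cycle counter ticks 50 times per microsecond:

    #include <cstdint>

    // Sketch: cycles -> GB/s for the high-level put test above.
    // Assumption: GetSystemCycle() advances 50 cycles per microsecond.
    double put_bw_gbps(int message_length, int64_t cycles) {
        double us = cycles / 50.0;                        // elapsed time in microseconds
        // 10000 messages in total: 10000 / num_qps per core, across num_qps cores.
        return message_length * 10000.0 / (us * 1000.0);  // bytes per ns == GB/s
        // ...which simplifies to message_length * 10 / us, the expression printed.
    }

The low-level test's (message_length / rdmaTotalTime) * 100 print is the same arithmetic applied to 10 timed iterations of 10000 messages each.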
diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h
index 60c6fab4c5c7aaecfb3030f29d059fd5eb0b482d..9be89e41fe6d035cdba67526bfde251033ca9fe8 100644
--- a/include/device/low_level/shmem_device_low_level_roce.h
+++ b/include/device/low_level/shmem_device_low_level_roce.h
@@ -25,7 +25,7 @@ enum class SHMEMAIVOPCODE : uint32_t {
 };
 
 struct SHMEMAIVRDMAInfo {
-    uint32_t qpNum;         // number of QP per connection
+    uint64_t nextQpIdxPtr;  // pointer to next QP index array of size [MAX_PE_NUM]
     uint64_t sqPtr;         // pointer to send queue address array of size [PE_NUM][qpNum]
     uint64_t rqPtr;         // pointer to receive queue address array of size [PE_NUM][qpNum]
     uint64_t scqPtr;        // pointer to send completion queue address array of size [PE_NUM][qpNum]
@@ -102,6 +102,30 @@ struct SHMEMHybmDeviceMeta {
     uint64_t reserved[12];  // total 128B, equal HYBM_DEVICE_PRE_META_SIZE
 };
 
+/**
+ * @brief Get the QP index to use for the next transmission to the given PE.
+ *
+ * @param pe           [in] Remote PE number
+ * @param device_state [in] Global device state
+ * @return the next QP index, or -1 if pe is out of range
+ */
+SHMEM_DEVICE int shmem_next_qp_idx(int pe, __gm__ shmemi_device_host_state_t *device_state)
+{
+    __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info);
+    __gm__ uint8_t* idx_addr = (__gm__ uint8_t *)RDMAInfo->nextQpIdxPtr;
+    uint8_t next_qp_idx;
+
+    if (pe >= SHMEM_MAX_RANKS)
+        return -1;
+
+    next_qp_idx = idx_addr[pe];
+
+    idx_addr[pe] = (next_qp_idx + 1) & (device_state->num_qps_per_conn - 1);
+
+    return next_qp_idx;
+}
+
 /**
  * @brief RDMA Poll Completion Queue (CQ) function. Return status: 0 means success, non-zero means error.
  *
@@ -118,7 +142,7 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx,
 {
     __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state();
     __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info);
-    uint32_t qpNum = RDMAInfo->qpNum;
+    uint32_t qpNum = device_state->num_qps_per_conn;
     __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx));
     auto cqBaseAddr = cqCtxEntry->bufAddr;
     auto cqeSize = cqCtxEntry->cqeSize;
@@ -213,7 +237,7 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8
 {
     __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state();
     __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info);
-    uint32_t qpNum = RDMAInfo->qpNum;
+    uint32_t qpNum = device_state->num_qps_per_conn;
     __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx));
     auto SHMEMmemInfoTable = RDMAInfo->memPtr;
     auto sqBaseAddr = qpCtxEntry->bufAddr;
@@ -345,7 +369,7 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx,
 {
     __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state();
     __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info);
-    uint32_t qpNum = RDMAInfo->qpNum;
+    uint32_t qpNum = device_state->num_qps_per_conn;
     __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx));
     auto curHardwareHeadAddr = qpCtxEntry->headAddr;
     dcci_cachelines((__gm__ uint8_t*)curHardwareHeadAddr, 8);
@@ -358,7 +382,7 @@ SHMEM_DEVICE void shmemi_roce_qpinfo_test(__gm__ uint8_t* gva, uint32_t destRank
     __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state();
     __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info);
     *(__gm__ uint64_t*)(gva) = (uint64_t)RDMAInfo;
-    uint32_t qpNum = RDMAInfo->qpNum;
+    uint32_t qpNum = device_state->num_qps_per_conn;
     *(__gm__ uint64_t*)(gva + 8) = (uint64_t)qpNum;
     __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx));
     *(__gm__ uint64_t*)(gva + 16) = (uint64_t)qpCtxEntry;
@@ -407,7 +431,7 @@ SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDm
     uint32_t idx = 1;
     __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state();
     __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info);
-    uint32_t qpNum = RDMAInfo->qpNum;
+    uint32_t qpNum = device_state->num_qps_per_conn;
     __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx));
     *(__gm__ uint64_t*)(gva) = (uint64_t)cqCtxEntry;
     auto cqBaseAddr = cqCtxEntry->bufAddr;
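shmem_next_qp_idx() advances a one-byte per-PE cursor with a masked increment, which only wraps correctly when num_qps_per_conn is a power of two; that is exactly what DeviceQpManager enforces further down. Note also that the cursor update is a plain read-modify-write with no atomics, which is why the round-robin strategy forbids driving the same PE from several cores at once. A standalone sketch of the wraparound:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint8_t num_qps = 4;    // must be a power of two for the mask to wrap
        uint8_t next_idx = 0;         // plays the role of idx_addr[pe]
        for (int i = 0; i < 6; i++) {
            printf("%u ", next_idx);  // prints: 0 1 2 3 0 1
            next_idx = (next_idx + 1) & (num_qps - 1);
        }
        printf("\n");
        return 0;
    }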
diff --git a/include/device/shmem_device_rma.h b/include/device/shmem_device_rma.h
index cb85f12b3f7a1e0f03eb75ca96aa12c68e2dfff9..b3452d20f9136ceef97aec766e9d9a0c45663dcf 100644
--- a/include/device/shmem_device_rma.h
+++ b/include/device/shmem_device_rma.h
@@ -256,7 +256,14 @@ SHMEM_DEVICE void shmem_getmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t
     } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \
         /* RoCE */ \
         auto ptr = shmem_roce_ptr(src, pe); \
         if (ptr == nullptr) return; \
+        int qp_idx = -1; \
+        if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_RR) { \
+            qp_idx = shmem_next_qp_idx(pe, device_state); \
+        } else if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_MC) { \
+            qp_idx = AscendC::GetBlockIdx() & (device_state->num_qps_per_conn - 1); \
+        } \
+        if (qp_idx < 0) return; \
         /* Create LocalTensor */ \
         AscendC::LocalTensor<uint32_t> ub_tensor_32; \
         ub_tensor_32.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECOUT); \
@@ -267,7 +273,7 @@ SHMEM_DEVICE void shmem_getmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t
         ub_tensor_64.address_.bufferAddr = reinterpret_cast<uint64_t>(SHMEM_INTERNAL_UB_BUF_START_ADDR \
             + UB_ALIGN_SIZE); \
         ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \
-        shmemi_roce_read((__gm__ uint8_t*)dst, (__gm__ uint8_t*)ptr, pe, 0, elem_size * sizeof(TYPE), \
+        shmemi_roce_read((__gm__ uint8_t*)dst, (__gm__ uint8_t*)ptr, pe, qp_idx, elem_size * sizeof(TYPE), \
             ub_tensor_64, ub_tensor_32); \
     } \
 }
@@ -332,6 +338,13 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_DETAILED_NBI);
         /* RoCE */ \
         auto ptr = shmem_roce_ptr((__gm__ void *)src.GetPhyAddr(), pe); \
         if (ptr == nullptr) return; \
+        int qp_idx = -1; \
+        if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_RR) { \
+            qp_idx = shmem_next_qp_idx(pe, device_state); \
+        } else if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_MC) { \
+            qp_idx = AscendC::GetBlockIdx() & (device_state->num_qps_per_conn - 1); \
+        } \
+        if (qp_idx < 0) return; \
         /* Create LocalTensor */ \
         AscendC::LocalTensor<uint32_t> ub_tensor_32; \
         ub_tensor_32.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECOUT); \
@@ -342,7 +355,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_DETAILED_NBI);
         ub_tensor_64.address_.bufferAddr = reinterpret_cast<uint64_t>(SHMEM_INTERNAL_UB_BUF_START_ADDR \
             + UB_ALIGN_SIZE); \
         ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \
-        shmemi_roce_read((__gm__ uint8_t*)(dst.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, 0, \
+        shmemi_roce_read((__gm__ uint8_t*)(dst.GetPhyAddr()), (__gm__ uint8_t*)ptr, pe, qp_idx, \
             elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \
     } \
 }
@@ -403,8 +416,15 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_DETAILED_NBI);
             copy_event_id); \
     } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \
         /* RoCE */ \
-        auto ptr = shmem_roce_ptr(dst, pe); \
+        auto ptr = shmem_roce_ptr(dst, pe); \
         if (ptr == nullptr) return; \
+        int qp_idx = -1; \
+        if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_RR) { \
+            qp_idx = shmem_next_qp_idx(pe, device_state); \
+        } else if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_MC) { \
+            qp_idx = AscendC::GetBlockIdx() & (device_state->num_qps_per_conn - 1); \
+        } \
+        if (qp_idx < 0) return; \
         /* Create LocalTensor */ \
         AscendC::LocalTensor<uint32_t> ub_tensor_32; \
         ub_tensor_32.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECOUT); \
@@ -415,7 +435,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_DETAILED_NBI);
         ub_tensor_64.address_.bufferAddr = reinterpret_cast<uint64_t>(SHMEM_INTERNAL_UB_BUF_START_ADDR \
             + UB_ALIGN_SIZE); \
         ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \
-        shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)src, pe, 0, elem_size * sizeof(TYPE), \
+        shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)src, pe, qp_idx, elem_size * sizeof(TYPE), \
             ub_tensor_64, ub_tensor_32); \
     } \
 }
@@ -477,8 +497,15 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_DETAILED_NBI);
         shmem_mte_put_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \
     } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \
         /* RoCE */ \
-        auto ptr = shmem_roce_ptr((__gm__ void *)dst.GetPhyAddr(), pe); \
+        auto ptr = shmem_roce_ptr((__gm__ void *)dst.GetPhyAddr(), pe); \
         if (ptr == nullptr) return; \
+        int qp_idx = -1; \
+        if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_RR) { \
+            qp_idx = shmem_next_qp_idx(pe, device_state); \
+        } else if (device_state->mqp_sttgy == SHMEM_MQP_STRATEGY_MC) { \
+            qp_idx = AscendC::GetBlockIdx() & (device_state->num_qps_per_conn - 1); \
+        } \
+        if (qp_idx < 0) return; \
         /* Create LocalTensor */ \
         AscendC::LocalTensor<uint32_t> ub_tensor_32; \
         ub_tensor_32.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECOUT); \
@@ -489,7 +516,7 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_DETAILED_NBI);
         ub_tensor_64.address_.bufferAddr = reinterpret_cast<uint64_t>(SHMEM_INTERNAL_UB_BUF_START_ADDR \
             + UB_ALIGN_SIZE); \
         ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; \
-        shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(src.GetPhyAddr()), pe, 0, \
+        shmemi_roce_write((__gm__ uint8_t*)ptr, (__gm__ uint8_t*)(src.GetPhyAddr()), pe, qp_idx, \
             elem_size * sizeof(TYPE), ub_tensor_64, ub_tensor_32); \
     } \
 }
diff --git a/include/host/shmem_host_def.h b/include/host/shmem_host_def.h
index 5fa1672670ff00061b290abe36e91b9750bee4d2..17536df6043da7b2a8395d91ed1356abd446266e 100644
--- a/include/host/shmem_host_def.h
+++ b/include/host/shmem_host_def.h
@@ -134,6 +134,8 @@ typedef struct {
     uint32_t shm_init_timeout;
     uint32_t shm_create_timeout;
     uint32_t control_operation_timeout;
+    uint8_t num_of_qps_per_conn;
+    uint8_t mqp_sttgy;
 } shmem_init_optional_attr_t;
 
 /**
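Under SHMEM_MQP_STRATEGY_MC the macros above derive the QP from the calling core's index rather than the shared per-PE cursor, so concurrent cores never race on it. A quick sketch of the mapping with hypothetical values:

    #include <cstdio>

    int main() {
        const int num_qps = 4;  // power of two, as enforced by DeviceQpManager
        for (int block_idx = 0; block_idx < 8; block_idx++) {
            // Mirrors: qp_idx = AscendC::GetBlockIdx() & (num_qps_per_conn - 1)
            printf("core %d -> QP %d\n", block_idx, block_idx & (num_qps - 1));
        }
        return 0;  // cores 0..3 -> QPs 0..3, cores 4..7 wrap back to QPs 0..3
    }

Any window of at most num_qps consecutive core indices therefore lands on pairwise-distinct QPs, which is the "cores 0-3/1-4/.../36-39" rule documented in shmemi_types.h below.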
diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h
index cc0389b44c63f15468339a46861a74942129b672..32975669e9707cc7542dba07323e012440c47456 100644
--- a/include/internal/host_device/shmemi_types.h
+++ b/include/internal/host_device/shmemi_types.h
@@ -46,6 +46,19 @@ extern "C" {
 #define SHMEM_EXTRA_SIZE_UNALIGHED SYNC_POOL_SIZE
 #define SHMEM_EXTRA_SIZE ALIGH_TO(SHMEM_EXTRA_SIZE_UNALIGHED, SHMEM_PAGE_SIZE)
 
+// Multi QP related
+#define SHMEM_MAX_QPS_PER_CONN 32
+#define SHMEM_DEF_QPS_PER_CONN 1
+// Multi QP strategies for high-level APIs
+#define SHMEM_MQP_STRATEGY_RR 0  // Uses round-robin to serially access QPs.
+                                 // A given PE can't be concurrently accessed from multiple cores.
+
+#define SHMEM_MQP_STRATEGY_MC 1  // Uses multiple cores to concurrently access QPs.
+                                 // A given PE can be accessed by a single set of contiguous AIV
+                                 // cores of size no larger than the number of QPs.
+                                 // For example, with 4 QPs, cores 0-3/1-4/.../36-39 can be used
+                                 // concurrently.
+#define SHMEM_DEF_MQP_STRATEGY SHMEM_MQP_STRATEGY_RR
+
 // global_state
 constexpr uint64_t DEVMM_SVM_MEM_START = 0x100000000000ULL;
 constexpr uint64_t SVM_END_ADDR = 0x100000000000ULL + 0x80000000000ULL - (1UL << 30UL);  // svm end
@@ -98,6 +111,8 @@ typedef struct {
     shmemi_mte_config_t mte_config;
 
     uint64_t qp_info;
+    int num_qps_per_conn;
+    uint8_t mqp_sttgy;
 } shmemi_device_host_state_t;
 
 #ifdef __cplusplus
diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp
index 54dba01c550082cc7e057808c8d31d11fc3df4e6..51974c8ea903c9328fcd914ba5dab4826db4f442 100644
--- a/src/host/init/shmem_init.cpp
+++ b/src/host/init/shmem_init.cpp
@@ -54,6 +54,8 @@ constexpr int DEFAULT_BLOCK_NUM = 1;
     false,              /* shmem_is_shmem_created */ \
     {0, 16 * 1024, 0},  /* shmem_mte_config */ \
     0,                  /* qp_info */ \
+    SHMEM_DEF_QPS_PER_CONN, /* num_qps_per_conn */ \
+    SHMEM_DEF_MQP_STRATEGY, /* mqp_sttgy */ \
 }
 
 shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER;
@@ -105,6 +107,8 @@ int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes)
     g_state.npes = attributes->n_ranks;
     g_state.heap_size = attributes->local_mem_size + SHMEM_EXTRA_SIZE;
     g_state.host_hash = shmemi_get_host_hash();
+    g_state.num_qps_per_conn = attributes->option_attr.num_of_qps_per_conn;
+    g_state.mqp_sttgy = attributes->option_attr.mqp_sttgy;
 
     aclrtStream stream = nullptr;
     SHMEM_CHECK_RET(aclrtCreateStream(&stream));
@@ -188,7 +192,7 @@ int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size
     g_attr.ip_port = g_ipport;
     g_attr.local_mem_size = local_mem_size;
     g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, DEFAULT_TIMEOUT,
-                          DEFAULT_TIMEOUT, DEFAULT_TIMEOUT};
+                          DEFAULT_TIMEOUT, DEFAULT_TIMEOUT, SHMEM_DEF_QPS_PER_CONN, SHMEM_DEF_MQP_STRATEGY};
     // g_attr_init = true;
     return SHMEM_SUCCESS;
 }
diff --git a/src/modules/transport/rdma/device_qp_manager.cpp b/src/modules/transport/rdma/device_qp_manager.cpp
index dc37f32a4f50c1b331b58381065774fb17a4d155..287ae25958c5aab6a4c1bac3a19c7b8edd56cdbf 100644
--- a/src/modules/transport/rdma/device_qp_manager.cpp
+++ b/src/modules/transport/rdma/device_qp_manager.cpp
@@ -8,18 +8,36 @@
  * See LICENSE in the root of the software repository for the full text of the License.
  */
 
+#include <ctime>
+#include <chrono>
 #include "dl_hccp_api.h"
 #include "device_qp_manager.h"
 
 using namespace shm;
 
 DeviceQpManager::DeviceQpManager(uint32_t deviceId, uint32_t rankId, uint32_t rankCount, sockaddr_in devNet,
-                                 hybm_role_type role) noexcept
+                                 hybm_role_type role, uint32_t numOfQpsPerConn) noexcept
    : deviceId_{deviceId}, rankId_{rankId}, rankCount_{rankCount}, deviceAddress_{devNet}, rankRole_{role}
 {
+    numOfQpsPerConn_ = numOfQpsPerConn;
+
+    /* Adjust if needed */
+    if (numOfQpsPerConn_ > SHMEM_MAX_QPS_PER_CONN)
+        numOfQpsPerConn_ = SHMEM_MAX_QPS_PER_CONN;
+    if (numOfQpsPerConn_ == 0)
+        numOfQpsPerConn_ = SHMEM_DEF_QPS_PER_CONN;
+
+    /* Round down to the nearest power of 2 */
+    numOfQpsPerConn_ |= (numOfQpsPerConn_ >> 1);
+    numOfQpsPerConn_ |= (numOfQpsPerConn_ >> 2);
+    numOfQpsPerConn_ |= (numOfQpsPerConn_ >> 4);
+    numOfQpsPerConn_ |= (numOfQpsPerConn_ >> 8);
+    numOfQpsPerConn_ |= (numOfQpsPerConn_ >> 16);
+
+    numOfQpsPerConn_ = (numOfQpsPerConn_ + 1) >> 1;
 }
 
 void *DeviceQpManager::CreateLocalSocket() noexcept
@@ -206,7 +224,7 @@ void *DeviceQpManager::GetQpHandleWithRankId(uint32_t rankId) const noexcept
         return nullptr;
     }
 
-    return pos->second.qpHandles[CONN_QP_STARS];
+    return pos->second.qpHandles[CONN_QP_STARS][0];
 }
 
 bool DeviceQpManager::ReserveQpInfoSpace() noexcept
@@ -217,7 +235,7 @@ bool DeviceQpManager::ReserveQpInfoSpace() noexcept
 
     void *ptr = nullptr;
     auto oneQpSize = 2U * (sizeof(AiQpRMAWQ) + sizeof(AiQpRMACQ)) + sizeof(RdmaMemRegionInfo);
-    qpInfoSize_ = sizeof(AiQpRMAQueueInfo) + oneQpSize * rankCount_;
+    qpInfoSize_ = sizeof(AiQpRMAQueueInfo) + rankCount_ * (oneQpSize * numOfQpsPerConn_ + sizeof(uint8_t));
     auto ret = aclrtMalloc(&ptr, qpInfoSize_, ACL_MEM_MALLOC_HUGE_FIRST);
     if (ret != 0) {
         SHM_LOG_ERROR("allocate device size: " << qpInfoSize_ << ", failed: " << ret);
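The OR cascade in the constructor is a branch-free round-down-to-power-of-two: smearing the highest set bit into every lower position turns the value into 2^k - 1, and (x + 1) >> 1 then yields 2^(k-1). A standalone sketch of the same trick:

    #include <cstdint>
    #include <cstdio>

    // Round v down to the nearest power of two (v > 0, v < 2^31),
    // mirroring the DeviceQpManager constructor above.
    static uint32_t round_down_pow2(uint32_t v) {
        v |= v >> 1;
        v |= v >> 2;
        v |= v >> 4;
        v |= v >> 8;
        v |= v >> 16;         // v is now 2^k - 1 for the smallest 2^k exceeding the input
        return (v + 1) >> 1;  // 2^(k-1), the largest power of two <= input
    }

    int main() {
        printf("%u %u %u\n", round_down_pow2(1), round_down_pow2(5), round_down_pow2(32));
        // prints: 1 4 32
        return 0;
    }

Since the member has already been clamped to SHMEM_MAX_QPS_PER_CONN (32) and floored at 1, the >> 8 and >> 16 steps are no-ops there; they are kept for the general 32-bit form of the idiom.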
@@ -437,54 +455,67 @@ int DeviceQpManager::CreateQpWaitingReady(std::unordered_map<uint32_t, ConnectionChannel> &connections, ConnQpType qpType) noexcept
 {
     for (auto it = connections.begin(); it != connections.end(); ++it) {
-        auto ret = CreateOneQp(qpType, it->second);
-        if (ret != 0) {
-            SHM_LOG_ERROR("create QP type:" << qpType << " to " << it->first << " failed: " << ret);
-            return SHMEM_INNER_ERROR;
-        }
-
-        for (auto pos = currentLocalMrs_.begin(); pos != currentLocalMrs_.end(); ++pos) {
-            HccpMrInfo info{};
-            info.addr = (void *)(ptrdiff_t)pos->second.address;
-            info.size = pos->second.size;
-            info.access = 7;
-            ret = DlHccpApi::RaMrReg(it->second.qpHandles[qpType], info);
-            if (ret != 0) {
-                SHM_LOG_ERROR("register MR failed: " << ret);
-                return SHMEM_INNER_ERROR;
-            }
-        }
-        ret = DlHccpApi::RaQpConnectAsync(it->second.qpHandles[qpType], it->second.socketFd);
-        if (ret != 0) {
-            SHM_LOG_ERROR("connect AI QP to " << it->first << " failed: " << ret);
-            return SHMEM_INNER_ERROR;
-        }
-    }
-
-    auto start = std::chrono::steady_clock::now();
-    auto timeout = start + std::chrono::minutes(1);
-    while (std::chrono::steady_clock::now() < timeout) {
-        int connectingCount = 0;
-        for (auto it = connections.begin(); it != connections.end(); ++it) {
-            int status = 0;
-            auto ret = DlHccpApi::RaGetQpStatus(it->second.qpHandles[qpType], status);
-            if (ret != 0) {
-                SHM_LOG_ERROR("get AI QP status to " << it->first << " failed: " << ret);
-                return SHMEM_INNER_ERROR;
-            }
-            if (status != 1) {
-                connectingCount++;
-            }
-        }
-        if (connectingCount == 0) {
-            return FillQpInfo(qpType);
-        }
-    }
-    return SHMEM_INNER_ERROR;
+        for (int i = 0; i < (qpType == CONN_QP_AI_CORE ? numOfQpsPerConn_ : 1); i++) {
+            auto ret = CreateOneQp(qpType, it->second, i);
+            if (ret != 0) {
+                SHM_LOG_ERROR("create QP #" << i << " type:" << qpType << " to " << it->first << " failed: " << ret);
+                return SHMEM_INNER_ERROR;
+            }
+
+            if (i == 0) {
+                for (auto pos = currentLocalMrs_.begin(); pos != currentLocalMrs_.end(); ++pos) {
+                    HccpMrInfo info{};
+                    info.addr = (void *)(ptrdiff_t)pos->second.address;
+                    info.size = pos->second.size;
+                    info.access = 7;
+                    ret = DlHccpApi::RaMrReg(it->second.qpHandles[qpType][i], info);
+                    if (ret != 0) {
+                        SHM_LOG_ERROR("register MR failed: " << ret);
+                        return SHMEM_INNER_ERROR;
+                    }
+                }
+            }
+
+            ret = DlHccpApi::RaQpConnectAsync(it->second.qpHandles[qpType][i], it->second.socketFd);
+            if (ret != 0) {
+                SHM_LOG_ERROR("connect QP #" << i << " type:" << qpType << " to " << it->first << " failed: " << ret);
+                return SHMEM_INNER_ERROR;
+            }
+
+            /* The DlHccpApi functions RaQpConnectAsync & RaGetQpStatus have some
+             * limitations when creating more than one QP in parallel, hence the
+             * next block bypasses them by enforcing a synchronous connect and a
+             * synchronous connection-establishment state machine.
+             */
+            auto start = std::chrono::steady_clock::now();
+            auto timeout = start + std::chrono::minutes(1);
+            int status;
+            struct timespec req;
+            req.tv_sec = 0;
+            req.tv_nsec = 1000000;
+            do {
+                status = 0;
+                auto ret = DlHccpApi::RaGetQpStatus(it->second.qpHandles[qpType][i], status);
+                if (ret != 0) {
+                    SHM_LOG_ERROR("get AI QP #" << i << " status to " << it->first << " failed: " << ret);
+                    return SHMEM_INNER_ERROR;
+                }
+                if (status != 1) {
+                    if (std::chrono::steady_clock::now() > timeout) {
+                        SHM_LOG_ERROR("connect AI QP #" << i << " status to " << it->first << " timeout!");
+                        return SHMEM_INNER_ERROR;
+                    } else {
+                        nanosleep(&req, nullptr);
+                    }
+                }
+            } while (status != 1);
+        }
     }
+
+    return FillQpInfo(qpType);
 }
 
-int DeviceQpManager::CreateOneQp(ConnQpType qpType, ConnectionChannel &channel) noexcept
+int DeviceQpManager::CreateOneQp(ConnQpType qpType, ConnectionChannel &channel, uint8_t qp_idx) noexcept
 {
     int ret;
     if (qpType == CONN_QP_AI_CORE) {
@@ -499,9 +530,9 @@ int DeviceQpManager::CreateOneQp(ConnQpType qpType, ConnectionChannel &channel)
         attr.qp_attr.cap.max_recv_sge = MAX_RECV_SGE;
         attr.qp_attr.qp_type = IBV_QPT_RC;
         attr.data_plane_flag.bs.cq_cstm = CALLER_POLL_CQ_CSTM;
-        ret = DlHccpApi::RaQpAiCreate(rdmaHandle_, attr, channel.aiQpInfo, channel.qpHandles[qpType]);
+        ret = DlHccpApi::RaQpAiCreate(rdmaHandle_, attr, channel.aiQpInfo[qp_idx], channel.qpHandles[qpType][qp_idx]);
     } else {
-        ret = DlHccpApi::RaQpCreate(rdmaHandle_, 0, QP_MODE, channel.qpHandles[qpType]);
+        ret = DlHccpApi::RaQpCreate(rdmaHandle_, 0, QP_MODE, channel.qpHandles[qpType][0]);
     }
     return ret;
 }
@@ -515,12 +546,16 @@ int DeviceQpManager::FillQpInfo(ConnQpType qpType) noexcept
     const uint32_t slevel = 4;
     std::vector<uint8_t> qpInfoBuffer(qpInfoSize_);
     auto copyInfo = (AiQpRMAQueueInfo *)(void *)qpInfoBuffer.data();
-    copyInfo->count = 1;
     copyInfo->sq = (AiQpRMAWQ *)(void *)(copyInfo + 1);
-    copyInfo->rq = (AiQpRMAWQ *)(void *)(copyInfo->sq + rankCount_);
-    copyInfo->scq = (AiQpRMACQ *)(void *)(copyInfo->rq + rankCount_);
-    copyInfo->rcq = (AiQpRMACQ *)(void *)(copyInfo->scq + rankCount_);
-    copyInfo->mr = (RdmaMemRegionInfo *)(void *)(copyInfo->rcq + rankCount_);
+    copyInfo->rq = (AiQpRMAWQ *)(void *)(copyInfo->sq + rankCount_ * numOfQpsPerConn_);
+    copyInfo->scq = (AiQpRMACQ *)(void *)(copyInfo->rq + rankCount_ * numOfQpsPerConn_);
+    copyInfo->rcq = (AiQpRMACQ *)(void *)(copyInfo->scq + rankCount_ * numOfQpsPerConn_);
+    copyInfo->mr = (RdmaMemRegionInfo *)(void *)(copyInfo->rcq + rankCount_ * numOfQpsPerConn_);
+    copyInfo->next_qp_idx = (uint8_t *)(void *)(copyInfo->mr + rankCount_);
     for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) {
         copyInfo->mr[it->first].size = it->second.mr.size;
         copyInfo->mr[it->first].addr = it->second.mr.address;
@@ -543,28 +578,41 @@ int DeviceQpManager::FillQpInfo(ConnQpType qpType) noexcept
             return SHMEM_INNER_ERROR;
         }
 
-        CopyAiWQInfo(copyInfo->sq[it->first], pos->second.aiQpInfo.data_plane_info.sq, DBMode::HW_DB, slevel);
-        CopyAiWQInfo(copyInfo->rq[it->first], pos->second.aiQpInfo.data_plane_info.rq, DBMode::SW_DB, slevel);
-        CopyAiCQInfo(copyInfo->scq[it->first], pos->second.aiQpInfo.data_plane_info.scq, DBMode::HW_DB);
-        CopyAiCQInfo(copyInfo->rcq[it->first], pos->second.aiQpInfo.data_plane_info.rcq, DBMode::SW_DB);
+        struct AiQpRMAWQ *sq = &copyInfo->sq[it->first * numOfQpsPerConn_];
+        struct AiQpRMAWQ *rq = &copyInfo->rq[it->first * numOfQpsPerConn_];
+        struct AiQpRMACQ *scq = &copyInfo->scq[it->first * numOfQpsPerConn_];
+        struct AiQpRMACQ *rcq = &copyInfo->rcq[it->first * numOfQpsPerConn_];
+        for (int i = 0; i < numOfQpsPerConn_; i++) {
+            CopyAiWQInfo(*sq, pos->second.aiQpInfo[i].data_plane_info.sq, DBMode::HW_DB, slevel);
+            CopyAiWQInfo(*rq, pos->second.aiQpInfo[i].data_plane_info.rq, DBMode::SW_DB, slevel);
+            CopyAiCQInfo(*scq, pos->second.aiQpInfo[i].data_plane_info.scq, DBMode::HW_DB);
+            CopyAiCQInfo(*rcq, pos->second.aiQpInfo[i].data_plane_info.rcq, DBMode::SW_DB);
+            sq++;
+            rq++;
+            scq++;
+            rcq++;
+        }
     }
 
     auto pointer = (ptrdiff_t)(void *)(qpInfo_);
     pointer += sizeof(AiQpRMAQueueInfo);
     copyInfo->sq = (AiQpRMAWQ *)(void *)(pointer);
-    pointer += sizeof(AiQpRMAWQ) * rankCount_;
+    pointer += sizeof(AiQpRMAWQ) * rankCount_ * numOfQpsPerConn_;
     copyInfo->rq = (AiQpRMAWQ *)(void *)(pointer);
-    pointer += sizeof(AiQpRMAWQ) * rankCount_;
+    pointer += sizeof(AiQpRMAWQ) * rankCount_ * numOfQpsPerConn_;
     copyInfo->scq = (AiQpRMACQ *)(void *)(pointer);
-    pointer += sizeof(AiQpRMACQ) * rankCount_;
+    pointer += sizeof(AiQpRMACQ) * rankCount_ * numOfQpsPerConn_;
     copyInfo->rcq = (AiQpRMACQ *)(void *)(pointer);
-    pointer += sizeof(AiQpRMACQ) * rankCount_;
+    pointer += sizeof(AiQpRMACQ) * rankCount_ * numOfQpsPerConn_;
     copyInfo->mr = (RdmaMemRegionInfo *)(void *)pointer;
+    pointer += sizeof(RdmaMemRegionInfo) * rankCount_;
+    copyInfo->next_qp_idx = (uint8_t *)(void *)pointer;
 
     auto ret = aclrtMemcpy(qpInfo_, qpInfoSize_, copyInfo, qpInfoSize_, ACL_MEMCPY_HOST_TO_DEVICE);
     if (ret != 0) {
         SHM_LOG_ERROR("copy qp info to device failed: " << ret);
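FillQpInfo() lays out one contiguous device blob: the AiQpRMAQueueInfo header, the sq/rq/scq/rcq context arrays flattened as [rankCount][numOfQpsPerConn], one RdmaMemRegionInfo per rank, and finally the per-PE next_qp_idx byte array. The device side indexes it with the same flattening, e.g. in shmemi_roce_quiet(). A sketch of the addressing, with hypothetical sizes:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical stand-ins; the real sizes come from the HCCP headers.
        const uint32_t numQps = 4;
        const uint64_t sqPtr = 0x1000;  // device address of the flattened SQ ctx array
        const uint64_t wqCtxSize = 64;  // stand-in for sizeof(SHMEMWQCtx)
        uint32_t remoteRankId = 3, qpIdx = 2;
        // Mirrors: RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)
        uint64_t entry = sqPtr + (remoteRankId * numQps + qpIdx) * wqCtxSize;
        printf("SQ ctx for (pe=%u, qp=%u) at 0x%llx\n",
               remoteRankId, qpIdx, (unsigned long long)entry);  // 0x1380
        return 0;
    }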
@@ -637,20 +685,22 @@ void DeviceQpManager::CloseConnections(std::unordered_map<uint32_t, ConnectionChannel> &connections) noexcept
 {
     std::vector<SocketCloseInfo> socketCloseInfos;
     for (auto it = connections.begin(); it != connections.end(); ++it) {
-        if (it->second.qpHandles[CONN_QP_AI_CORE] != nullptr) {
-            auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_AI_CORE]);
-            if (ret != 0) {
-                SHM_LOG_WARN("destroy AI QP to server: " << it->first << " failed: " << ret);
+        for (int i = 0; i < numOfQpsPerConn_; i++) {
+            if (it->second.qpHandles[CONN_QP_AI_CORE][i] != nullptr) {
+                auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_AI_CORE][i]);
+                if (ret != 0) {
+                    SHM_LOG_WARN("destroy AI QP #" << i << " to server: " << it->first << " failed: " << ret);
+                }
+                it->second.qpHandles[CONN_QP_AI_CORE][i] = nullptr;
             }
-            it->second.qpHandles[CONN_QP_AI_CORE] = nullptr;
         }
 
-        if (it->second.qpHandles[CONN_QP_STARS] != nullptr) {
-            auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_STARS]);
+        if (it->second.qpHandles[CONN_QP_STARS][0] != nullptr) {
+            auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_STARS][0]);
             if (ret != 0) {
                 SHM_LOG_WARN("destroy stars QP to server: " << it->first << " failed: " << ret);
             }
-            it->second.qpHandles[CONN_QP_STARS] = nullptr;
+            it->second.qpHandles[CONN_QP_STARS][0] = nullptr;
         }
 
         if (it->second.socketFd != nullptr) {
diff --git a/src/modules/transport/rdma/device_qp_manager.h b/src/modules/transport/rdma/device_qp_manager.h
index 313a9316bcfbb9fc5aeccdce3d343cbf72ec3149..ee81e895ba7383a54c778e635d90bc0e081b6293 100644
--- a/src/modules/transport/rdma/device_qp_manager.h
+++ b/src/modules/transport/rdma/device_qp_manager.h
@@ -19,7 +19,7 @@ class DeviceQpManager {
 public:
     DeviceQpManager(uint32_t deviceId, uint32_t rankId, uint32_t rankCount, sockaddr_in devNet,
-                    hybm_role_type role) noexcept;
+                    hybm_role_type role, uint32_t numOfQpsPerConn) noexcept;
     ~DeviceQpManager() noexcept;
 
     int SetRemoteRankInfo(const std::unordered_map<uint32_t, RankInfo> &ranks) noexcept;
@@ -42,6 +42,7 @@ protected:
     const hybm_role_type rankRole_;
     sockaddr_in deviceAddress_;
     void *serverSocketHandle_{nullptr};
+    uint8_t numOfQpsPerConn_;
 
 private:
     enum ConnQpType : uint32_t {
@@ -54,8 +55,8 @@ private:
         in_addr remoteIp;
         void *socketHandle;
         void *socketFd{nullptr};
-        void *qpHandles[CONN_QP_COUNT]{};
-        HccpAiQpInfo aiQpInfo{};
+        void *qpHandles[CONN_QP_COUNT][SHMEM_MAX_QPS_PER_CONN]{};
+        HccpAiQpInfo aiQpInfo[SHMEM_MAX_QPS_PER_CONN]{};
         int qpStatus{-1};
 
         explicit ConnectionChannel(const in_addr ip) : ConnectionChannel{ip, nullptr} {}
@@ -68,7 +69,7 @@ private:
     int GenerateWhiteList() noexcept;
     int WaitConnectionsReady(std::unordered_map<uint32_t, ConnectionChannel> &connections) noexcept;
    int CreateQpWaitingReady(std::unordered_map<uint32_t, ConnectionChannel> &connections, ConnQpType qpType) noexcept;
-    int CreateOneQp(ConnQpType qpType, ConnectionChannel &channel) noexcept;
+    int CreateOneQp(ConnQpType qpType, ConnectionChannel &channel, uint8_t qp_idx) noexcept;
     int FillQpInfo(ConnQpType qpType) noexcept;
     void CopyAiWQInfo(struct AiQpRMAWQ &dest, const struct ai_data_plane_wq &src, DBMode dbMode, uint32_t sl) noexcept;
     void CopyAiCQInfo(struct AiQpRMACQ &dest, const ai_data_plane_cq &source, DBMode dbMode) noexcept;
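One storage tradeoff in the ConnectionChannel change above: handles and AI QP info become fixed arrays dimensioned by SHMEM_MAX_QPS_PER_CONN, so every connection reserves all 32 slots even when a single QP is configured. A rough sizing sketch (the HccpAiQpInfo size is illustrative, and CONN_QP_COUNT is assumed to be 2, i.e. CONN_QP_AI_CORE plus CONN_QP_STARS):

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t conn_qp_count = 2;   // assumed value of CONN_QP_COUNT
        const size_t max_qps = 32;        // SHMEM_MAX_QPS_PER_CONN
        const size_t ai_qp_info = 128;    // hypothetical sizeof(HccpAiQpInfo)
        size_t per_conn = conn_qp_count * max_qps * sizeof(void *)  // qpHandles
                        + max_qps * ai_qp_info;                     // aiQpInfo
        printf("fixed per-connection overhead: %zu bytes\n", per_conn);  // 4608 on LP64
        return 0;
    }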
diff --git a/src/modules/transport/rdma/dl_hccp_def.h b/src/modules/transport/rdma/dl_hccp_def.h
index 4d905f8388f84e449235ca5b54fa2fbabeb61be1..750c9380e68381796ebaa6f385dc58f2b14465cc 100644
--- a/src/modules/transport/rdma/dl_hccp_def.h
+++ b/src/modules/transport/rdma/dl_hccp_def.h
@@ -393,7 +393,7 @@ struct RdmaMemRegionInfo {
 };
 
 struct AiQpRMAQueueInfo {
-    uint32_t count;
+    uint8_t *next_qp_idx;
     struct AiQpRMAWQ *sq;
     struct AiQpRMAWQ *rq;
     struct AiQpRMACQ *scq;
@@ -561,6 +561,7 @@ struct TransportOptions {
     int nic;
     int32_t dev_id;
     int32_t logic_dev_id;
+    int num_qps_per_conn;
 };
 
 struct TransportMemoryRegion {
diff --git a/src/modules/transport/rdma/rdma_manager.h b/src/modules/transport/rdma/rdma_manager.h
index 9f2fbbe32dcfb69bdfcd15ffe3da5e66d9166fc4..9f9526194d1d63f96a8b9dcd491c9f4d853e18ef 100644
--- a/src/modules/transport/rdma/rdma_manager.h
+++ b/src/modules/transport/rdma/rdma_manager.h
@@ -71,7 +71,7 @@ public:
         deviceAddr.sin_family = AF_INET;
         deviceAddr.sin_addr = deviceIp_;
         deviceAddr.sin_port = devicePort_;
-        qpManager_ = new DeviceQpManager(deviceId_, rankId_, rankCount_, deviceAddr, HYBM_ROLE_PEER);
+        qpManager_ = new DeviceQpManager(deviceId_, rankId_, rankCount_, deviceAddr, HYBM_ROLE_PEER, options.num_qps_per_conn);
         return 0;
     }
diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp
index 7a71114387f6b7ac2805bfe0fa3175dc10a3e178..bbadaf38bbde0c6cab3ac8e8cc25bd5785904a4b 100644
--- a/src/modules/transport/shmemi_rdma.cpp
+++ b/src/modules/transport/shmemi_rdma.cpp
@@ -75,6 +75,7 @@ int shmemi_rdma_init(shmemi_transport *t, shmemi_device_host_state_t *state) {
     options.nic = 10002;
     options.dev_id = t->dev_id;
     options.logic_dev_id = t->logical_dev_id;
+    options.num_qps_per_conn = state->num_qps_per_conn;
     manager->OpenDevice(options);
 
     TransportMemoryRegion mr;