From 4b2b85541cbecf6941b281206ee8a9f702c552b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=BA=E6=85=A7?= Date: Fri, 18 Jul 2025 09:00:25 +0800 Subject: [PATCH] change sample2 depend from llm_engine to llm_datadist --- .../11_llm_data_dist/CMakeLists.txt | 4 +- .../11_llm_data_dist/decoder_sample2.cpp | 14 +- .../11_llm_data_dist/prompt_sample2.cpp | 14 +- .../11_llm_data_dist/readme.md | 7 +- .../level1_single_api/12_adxl/CMakeLists.txt | 48 ++++ .../12_adxl/adxl_engine_sample.cpp | 271 ++++++++++++++++++ cplusplus/level1_single_api/12_adxl/readme.md | 77 +++++ .../10_llm_data_dist/README.md | 10 +- .../switch_role_sample.py | 32 ++- 9 files changed, 446 insertions(+), 31 deletions(-) create mode 100644 cplusplus/level1_single_api/12_adxl/CMakeLists.txt create mode 100644 cplusplus/level1_single_api/12_adxl/adxl_engine_sample.cpp create mode 100644 cplusplus/level1_single_api/12_adxl/readme.md diff --git a/cplusplus/level1_single_api/11_llm_data_dist/CMakeLists.txt b/cplusplus/level1_single_api/11_llm_data_dist/CMakeLists.txt index 25addfeab..5691c49c5 100644 --- a/cplusplus/level1_single_api/11_llm_data_dist/CMakeLists.txt +++ b/cplusplus/level1_single_api/11_llm_data_dist/CMakeLists.txt @@ -90,7 +90,7 @@ target_link_directories(prompt_sample2 PRIVATE ) target_link_libraries(prompt_sample2 PRIVATE - llm_engine + llm_datadist graph ascendcl ) @@ -115,7 +115,7 @@ target_link_directories(decoder_sample2 PRIVATE ) target_link_libraries(decoder_sample2 PRIVATE - llm_engine + llm_datadist graph ascendcl ) \ No newline at end of file diff --git a/cplusplus/level1_single_api/11_llm_data_dist/decoder_sample2.cpp b/cplusplus/level1_single_api/11_llm_data_dist/decoder_sample2.cpp index 909be6ddd..41d94f042 100644 --- a/cplusplus/level1_single_api/11_llm_data_dist/decoder_sample2.cpp +++ b/cplusplus/level1_single_api/11_llm_data_dist/decoder_sample2.cpp @@ -50,11 +50,6 @@ int Initialize(LlmDataDist &llmDataDist, const std::string &deviceId) { std::map options; options[OPTION_DEVICE_ID] = deviceId.c_str(); - if (std::getenv("LOCAL_COMM_RES") == nullptr) { - printf("[ERROR] env:LOCAL_COMM_RES not set\n"); - return -1; - } - options[OPTION_LOCAL_COMM_RES] = std::getenv("LOCAL_COMM_RES"); auto ret = llmDataDist.Initialize(options); if (ret != LLM_SUCCESS) { printf("[ERROR] Initialize failed, ret = %u\n", ret); @@ -77,11 +72,16 @@ int32_t SetRole(LlmDataDist &llmDataDist, LlmRole role, const char *localIp) return 0; } -int Link(LlmDataDist &llmDataDist, const char *remoteIp) +int Link(LlmDataDist &llmDataDist, const char *localIp, const char *remoteIp) { std::vector rets; std::vector clusters; ClusterInfo clusterInfo; + clusterInfo.remote_cluster_id = 0; + IpInfo localIpInfo; + localIpInfo.ip = localIp; + localIpInfo.port = PROMPT_LISTEN_PORT; + clusterInfo.local_ip_infos.emplace_back(std::move(localIpInfo)); IpInfo remoteIpInfo; remoteIpInfo.ip = remoteIp; remoteIpInfo.port = PROMPT_LISTEN_PORT; @@ -226,7 +226,7 @@ int32_t RunDecoderSample(const char *deviceId, const char *localIp, const char * std::this_thread::sleep_for(std::chrono::seconds(WAIT_PROMPT_TIME)); // 5. 与prompt建链 - if (Link(llmDataDist, remoteIp) != 0) { + if (Link(llmDataDist, localIp, remoteIp) != 0) { Finalize(llmDataDist, cacheId, linked, remoteIp, buffers); return -1; } diff --git a/cplusplus/level1_single_api/11_llm_data_dist/prompt_sample2.cpp b/cplusplus/level1_single_api/11_llm_data_dist/prompt_sample2.cpp index 52abdafc4..83a176d7a 100644 --- a/cplusplus/level1_single_api/11_llm_data_dist/prompt_sample2.cpp +++ b/cplusplus/level1_single_api/11_llm_data_dist/prompt_sample2.cpp @@ -49,11 +49,6 @@ int Initialize(LlmDataDist &llmDataDist, const std::string &deviceId, const std: std::map options; options[OPTION_DEVICE_ID] = deviceId.c_str(); options[OPTION_LISTEN_IP_INFO] = (localIp + ":" + std::to_string(PROMPT_LISTEN_PORT)).c_str(); - if (std::getenv("LOCAL_COMM_RES") == nullptr) { - printf("[ERROR] env:LOCAL_COMM_RES not set\n"); - return -1; - } - options[OPTION_LOCAL_COMM_RES] = std::getenv("LOCAL_COMM_RES"); auto ret = llmDataDist.Initialize(options); if (ret != LLM_SUCCESS) { printf("[ERROR] Initialize failed, ret = %u\n", ret); @@ -75,11 +70,16 @@ int32_t SetRole(LlmDataDist &llmDataDist, LlmRole role) return 0; } -int Link(LlmDataDist &llmDataDist, const char *remoteIp) +int Link(LlmDataDist &llmDataDist, const char *localIp, const char *remoteIp) { std::vector rets; std::vector clusters; ClusterInfo clusterInfo; + clusterInfo.remote_cluster_id = 1; + IpInfo localIpInfo; + localIpInfo.ip = localIp; + localIpInfo.port = DECODER_LISTEN_PORT; + clusterInfo.local_ip_infos.emplace_back(std::move(localIpInfo)); IpInfo remoteIpInfo; remoteIpInfo.ip = remoteIp; remoteIpInfo.port = DECODER_LISTEN_PORT; @@ -228,7 +228,7 @@ int32_t RunPromptSample(const char *deviceId, const char *localIp, const char *r } // 6. 与decoder建链 - if (Link(llmDataDist, remoteIp) != 0) { + if (Link(llmDataDist, localIp, remoteIp) != 0) { Finalize(llmDataDist, cacheId, linked, remoteIp, buffers); return -1; } diff --git a/cplusplus/level1_single_api/11_llm_data_dist/readme.md b/cplusplus/level1_single_api/11_llm_data_dist/readme.md index 9c5546e3a..c591fbe8e 100644 --- a/cplusplus/level1_single_api/11_llm_data_dist/readme.md +++ b/cplusplus/level1_single_api/11_llm_data_dist/readme.md @@ -89,11 +89,10 @@ - 执行prompt_sample2, 参数为device_id、local_host_ip和remote_host_ip, 其中device_id为prompt要使用的device_id, local_host_ip为prompt所在host的ip, remote_host_ip为decoder所在host的ip,如: ``` - LOCAL_COMM_RES='{"status": "completed", "version": "1.0", "server_list": [{"server_id": "node_1", "device": [{"device_id": "0", "device_ip": "10.10.10.1"}]}]}' ./prompt_sample2 0 10.10.170.1 10.170.10.2 + ./prompt_sample2 0 10.10.170.1 10.170.10.2 ``` - 执行decoder_sample2, 参数为device_id、local_host_ip和remote_host_ip, 其中device_id为decoder要使用的device_id, local_host_ip为decoder所在host的ip,remote_host_ip为prompt所在host的ip,如: ``` - LOCAL_COMM_RES='{"status": "completed", "version": "1.0", "server_list": [{"server_id": "node_1", "device": [{"device_id": "1", "device_ip": "10.10.10.2"}]}]}' ./decoder_sample2 1 10.170.10.2 10.170.10.1 - ``` - **注**:LOCAL_COMM_RES为sample2执行所需环境变量,配置了当前进程所需的通信资源,将传递给llm_datadist作为初始化option; 配置格式与HCCL的ranktable一致,只需要配置本进程第一个参数device_id对应的信息,其中ranktable中的rank_id和server_count字段不需要配置,当前用例配置为A2的ranktable格式,其他环境需参考对应环境的ranktable格式进行配置 \ No newline at end of file + ./decoder_sample2 1 10.170.10.2 10.170.10.1 + ``` \ No newline at end of file diff --git a/cplusplus/level1_single_api/12_adxl/CMakeLists.txt b/cplusplus/level1_single_api/12_adxl/CMakeLists.txt new file mode 100644 index 000000000..bfc67c317 --- /dev/null +++ b/cplusplus/level1_single_api/12_adxl/CMakeLists.txt @@ -0,0 +1,48 @@ +cmake_minimum_required(VERSION 3.5.1) +project(adxl_sample) + +set(CMAKE_VERBOSE_MAKEFILE ON) +set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY TRUE) + +if (DEFINED ENV{ASCEND_INSTALL_PATH}) + set(ASCEND_PATH $ENV{ASCEND_INSTALL_PATH}) +else () + set(ASCEND_PATH /usr/local/Ascend/latest) +endif () + +set(INCLUDE_DIR ${ASCEND_PATH}/include) + +set(common_compile_options + --std=c++11 + -g + -Wall +) + +set(common_compile_definitions + _GLIBCXX_USE_CXX11_ABI=0 +) + +add_executable(adxl_engine_sample "adxl_engine_sample.cpp") + +target_compile_options(adxl_engine_sample PRIVATE + ${common_compile_options} +) + +target_compile_definitions(adxl_engine_sample PRIVATE + ${common_compile_definitions} +) + +target_include_directories(adxl_engine_sample PRIVATE + ${INCLUDE_DIR} + ${INCLUDE_DIR}/external/ge_common +) + +target_link_directories(adxl_engine_sample PRIVATE + ${ASCEND_PATH}/lib64 +) + +target_link_libraries(adxl_engine_sample PRIVATE + adxl + graph + ascendcl +) \ No newline at end of file diff --git a/cplusplus/level1_single_api/12_adxl/adxl_engine_sample.cpp b/cplusplus/level1_single_api/12_adxl/adxl_engine_sample.cpp new file mode 100644 index 000000000..a2252b8d3 --- /dev/null +++ b/cplusplus/level1_single_api/12_adxl/adxl_engine_sample.cpp @@ -0,0 +1,271 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "adxl/adxl_engine.h" + +using namespace adxl; +namespace{ +constexpr int32_t WAIT_REG_TIME = 5; +constexpr int32_t WAIT_TRANS_TIME = 20; +constexpr int32_t CLIENT_EXPECTED_ARG_CNT = 4; +constexpr uint32_t ARG_INDEX_DEVICE_ID = 1; +constexpr uint32_t ARG_INDEX_LOCAL_ENINE = 2; +constexpr uint32_t CLIENT_ARG_INDEX_REMOTE_ENINE = 3; +constexpr int32_t SERVER_EXPECTED_ARG_CNT = 3; + +#define CHECK_ACL(x) \ + do { \ + aclError __ret = x; \ + if (__ret != ACL_ERROR_NONE) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << __ret << std::endl; \ + } \ + } while (0); +} + +int Initialize(AdxlEngine &adxlEngine, const char *localEngine) +{ + std::map options; + auto ret = adxlEngine.Initialize(localEngine, options); + if (ret != SUCCESS) { + printf("[ERROR] Initialize failed, ret = %u\n", ret); + return -1; + } + printf("[INFO] Initialize success\n"); + return 0; +} + +int Connect(AdxlEngine &adxlEngine, const char *remoteEngine) +{ + auto ret = adxlEngine.Connect(remoteEngine); + if (ret != SUCCESS) { + printf("[ERROR] Connect failed, ret = %u\n", ret); + return -1; + } + printf("[INFO] Connect success\n"); + return 0; +} + +int Disconnect(AdxlEngine &adxlEngine, const char *remoteEngine) +{ + auto ret = adxlEngine.Disconnect(remoteEngine); + if (ret != SUCCESS) { + printf("[ERROR] Disconnect failed, ret = %u\n", ret); + return -1; + } + printf("[INFO] Disconnect success\n"); + return 0; +} + +int32_t Transfer(AdxlEngine &adxlEngine, int32_t &src, const char *remoteEngine) +{ + uintptr_t dstAddr; + std::ifstream("./tmp") >> std::hex >> dstAddr; + + TransferOpDesc desc{reinterpret_cast(&src), reinterpret_cast(dstAddr), sizeof(int32_t)}; + auto ret = adxlEngine.TransferSync(remoteEngine, READ, {desc}); + if (ret != SUCCESS) { + printf("[ERROR] TransferSync read failed, ret = %u\n", ret); + return -1; + } + printf("[INFO] TransferSync read success, src = %d\n", src); + + src = 2; + ret = adxlEngine.TransferSync(remoteEngine, WRITE, {desc}); + if (ret != SUCCESS) { + printf("[ERROR] TransferSync write failed, ret = %u\n", ret); + return -1; + } + printf("[INFO] TransferSync write success, src = %d\n", src); + return 0; +} + +void ClientFinalize(AdxlEngine &adxlEngine, bool connected, const char *remoteEngine, + const std::vector handles, const std::vector hostBuffers = {}) +{ + if (connected) { + auto ret = Disconnect(adxlEngine, remoteEngine); + if (ret != 0) { + printf("[ERROR] Disconnect failed, ret = %d\n", ret); + } else { + printf("[INFO] Disconnect success\n"); + } + } + + for (auto handle : handles) { + auto ret = adxlEngine.DeregisterMem(handle); + if (ret != 0) { + printf("[ERROR] DeregisterMem failed, ret = %u\n", ret); + } else { + printf("[INFO] DeregisterMem success\n"); + } + } + for (auto buffer : hostBuffers) { + aclrtFreeHost(buffer); + } + adxlEngine.Finalize(); +} + +void ServerFinalize(AdxlEngine &adxlEngine, + const std::vector handles, + const std::vector devBuffers = {}) +{ + for (auto handle : handles) { + auto ret = adxlEngine.DeregisterMem(handle); + if (ret != 0) { + printf("[ERROR] DeregisterMem failed, ret = %u\n", ret); + } else { + printf("[INFO] DeregisterMem success\n"); + } + } + for (auto buffer : devBuffers) { + aclrtFree(buffer); + } + adxlEngine.Finalize(); +} + +int32_t RunClient(const char *localEngine, const char *remoteEngine) +{ + printf("[INFO] client start\n"); + // 1. 初始化 + AdxlEngine adxlEngine; + if (Initialize(adxlEngine, localEngine) != 0) { + printf("[ERROR] Initialize AdxlEngine failed\n"); + return -1; + } + // 2. 注册内存地址 + int32_t *src = nullptr; + CHECK_ACL(aclrtMallocHost(reinterpret_cast(&src), sizeof(int32_t))); + bool connected = false; + MemDesc desc{}; + desc.addr = reinterpret_cast(src); + desc.len = sizeof(int32_t); + MemHandle handle = nullptr; + auto ret = adxlEngine.RegisterMem(desc, MEM_HOST, handle); + if (ret != SUCCESS) { + printf("[ERROR] RegisterMem failed, ret = %u\n", ret); + ClientFinalize(adxlEngine, connected, remoteEngine, {handle}, {src}); + return -1; + } + printf("[INFO] RegisterMem success\n"); + + // 等待server注册完成 + std::this_thread::sleep_for(std::chrono::seconds(WAIT_REG_TIME)); + + // 3. 与server建链 + if (Connect(adxlEngine, remoteEngine) != 0) { + ClientFinalize(adxlEngine, connected, remoteEngine, {handle}, {src}); + return -1; + } + connected = true; + + // 4. 从server get内存,并向server put内存 + if (Transfer(adxlEngine, *src, remoteEngine) != 0) { + ClientFinalize(adxlEngine, connected, remoteEngine, {handle}, {src}); + return -1; + } + + // 5. 释放Cache与llmDataDist + ClientFinalize(adxlEngine, connected, remoteEngine, {handle}, {src}); + printf("[INFO] Finalize success\n"); + printf("[INFO] Prompt Sample end\n"); + return 0; +} + +int32_t RunServer(const char *localEngine) +{ + printf("[INFO] server start\n"); + // 1. 初始化 + AdxlEngine adxlEngine; + if (Initialize(adxlEngine, localEngine) != 0) { + printf("[ERROR] Initialize AdxlEngine failed\n"); + return -1; + } + // 2. 注册内存地址 + int32_t dst = 1; + int32_t *buffer = nullptr; + CHECK_ACL(aclrtMalloc((void **)&buffer, sizeof(int32_t), ACL_MEM_MALLOC_HUGE_ONLY)); + // init device buffer + CHECK_ACL(aclrtMemcpy(buffer, sizeof(int32_t), &dst, sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE)); + + MemDesc desc{}; + desc.addr = reinterpret_cast(buffer); + desc.len = sizeof(int32_t); + MemHandle handle = nullptr; + auto ret = adxlEngine.RegisterMem(desc, MEM_DEVICE, handle); + if (ret != SUCCESS) { + printf("[ERROR] RegisterMem failed, ret = %u\n", ret); + ServerFinalize(adxlEngine, {handle}, {buffer}); + return -1; + } + // 3. RegisterMem成功后,将地址保存到本地文件中等待client读取 + printf("[INFO] RegisterMem success, dst addr:%p\n", buffer); + std::ofstream tmp_file("./tmp"); // 默认就是 std::ios::out | std::ios::trunc + if (tmp_file) { + tmp_file << buffer << std::endl; + } + + // 4. 等待client transfer + std::this_thread::sleep_for(std::chrono::seconds(WAIT_TRANS_TIME)); + + CHECK_ACL(aclrtMemcpy(&dst, sizeof(int32_t), buffer, sizeof(int32_t), ACL_MEMCPY_DEVICE_TO_HOST)); + printf("[INFO] After transfer, dst value:%d\n", dst); + + // 5. 释放Cache与llmDataDist + ServerFinalize(adxlEngine, {handle}, {buffer}); + printf("[INFO] Finalize success\n"); + printf("[INFO] server Sample end\n"); + return 0; +} + +int main(int32_t argc, char **argv) +{ + bool isClient = false; + std::string deviceId; + std::string localEngine; + std::string remoteEngine; + if (argc == CLIENT_EXPECTED_ARG_CNT) { + isClient = true; + deviceId = argv[ARG_INDEX_DEVICE_ID]; + localEngine = argv[ARG_INDEX_LOCAL_ENINE]; + remoteEngine = argv[CLIENT_ARG_INDEX_REMOTE_ENINE]; + printf("[INFO] deviceId = %s, localEngine = %s, remoteEngine = %s\n", + deviceId.c_str(), localEngine.c_str(), remoteEngine.c_str()); + } else if (argc == SERVER_EXPECTED_ARG_CNT) { + deviceId = argv[ARG_INDEX_DEVICE_ID]; + localEngine = argv[ARG_INDEX_LOCAL_ENINE]; + printf("[INFO] deviceId = %s, localEngine = %s\n", deviceId.c_str(), localEngine.c_str()); + } else { + printf("[ERROR] client expect 3 args(deviceId, localEngine, remoteEngine), " + "server expect 2 args(deviceId, localEngine), but got %d\n", argc - 1); + } + int32_t device = std::stoi(deviceId); + CHECK_ACL(aclrtSetDevice(device)); + + int32_t ret = 0; + if (isClient) { + ret = RunClient(localEngine.c_str(), remoteEngine.c_str()); + } else { + ret = RunServer(localEngine.c_str()); + } + CHECK_ACL(aclrtResetDevice(device)); + return ret; +} \ No newline at end of file diff --git a/cplusplus/level1_single_api/12_adxl/readme.md b/cplusplus/level1_single_api/12_adxl/readme.md new file mode 100644 index 000000000..9fc957373 --- /dev/null +++ b/cplusplus/level1_single_api/12_adxl/readme.md @@ -0,0 +1,77 @@ +## 目录 + +- [样例介绍](#样例介绍) +- [目录结构](#目录结构) +- [环境要求](#环境要求) +- [程序编译](#程序编译) +- [样例运行](#样例运行) + + +## 样例介绍 + +功能:通过adxl engine接口实现Cache傳輸功能。 + + +## 目录结构 + +``` +├── adxl_engine_sample.cpp // adxl_engine样例 +├── CMakeLists.txt // 编译脚本 +``` + + +## 环境要求 + +- 操作系统及架构:Euleros x86系统、Euleros aarch64系统 +- 编译器:g++ +- 芯片:Atlas 训练系列产品、Atlas 推理系列产品(配置Ascend 310P AI处理器) +- python及依赖的库:python3.7.5 +- 已完成昇腾AI软件栈在运行环境上的部署 + +## 程序编译 + +1. 修改CMakeLists.txt文件中的安装包路径 + +2. 执行如下命令进行编译。 + + 依次执行: + + ``` + mkdir build && cd build + cmake .. && make + ``` + +3. 编译结束后,在**build**目录下生成可执行文件**adxl_engine_sample**。 + +## 样例运行 +1. 配置环境变量 + - 若运行环境上安装的“Ascend-cann-toolkit”包,环境变量设置如下: + + ``` + . ${HOME}/Ascend/ascend-toolkit/set_env.sh + ``` + + “$HOME/Ascend”请替换相关软件包的实际安装路径。 + + - 若运行环境上安装的“CANN-XXX.run”包,环境变量设置如下: + + ``` + source ${HOME}/Ascend/latest/bin/setenv.bash + ``` + + “$HOME/Ascend”请替换相关软件包的实际安装路径。 + +2. 在运行环境执行可执行文件。 + + 3.1 执行sample + + - 执行client adxl_engine_sample, 参数为device_id、local engine和remote engine, 其中device_id为client要使用的device_id,如: + ``` + HCCL_INTRA_ROCE_ENABLE=1 ./adxl_engine_sample 0 10.10.10.0 10.10.10.1:16000 + ``` + + - 执行server adxl_engine_sample, 参数为device_id、local engine, 其中device_id为server要使用的device_id, 如: + ``` + HCCL_INTRA_ROCE_ENABLE=1 ./adxl_engine_sample 1 1 10.10.10.1:16000 + ``` + **注**:HCCL_INTRA_ROCE_ENABLE=1表示使用RDMA进行传输 \ No newline at end of file diff --git a/python/level1_single_api/10_llm_data_dist/README.md b/python/level1_single_api/10_llm_data_dist/README.md index 76c3225e5..ddddf5e88 100644 --- a/python/level1_single_api/10_llm_data_dist/README.md +++ b/python/level1_single_api/10_llm_data_dist/README.md @@ -97,14 +97,14 @@ # Decoder主机: python pull_from_cache_to_blocks.py --device_id 0 --cluster_id 2 ``` - - 执行switch role样例程序,此样例程序使用单侧建链方式,首先torch自行申请内存并注册blocks, - decoder发起建链并pull blocks, 然后两侧切换角色, 并prompt发起建链, decoder进行push_blocks: + - switch_role_sample.py:执行switch role样例程序,此样例程序使用单侧建链方式,首先torch自行申请内存并注册blocks, + decoder发起建链并pull blocks, 然后两侧切换角色, 并prompt发起建链, decoder进行push_blocks,执行方式如下: 分别在Prompt主机与Decoder主机,执行样例程序: ``` # Prompt主机: - LOCAL_COMM_RES='{"status": "completed", "version": "1.0", "server_list": [{"server_id": "node_1", "device": [{"device_id": "0", "device_ip": "10.10.10.1"}]}]}' GLOO_SOCKET_IFNAME=enp67s0f5 HCCL_INTRA_ROCE_ENABLE=1 python switch_role_sample.py --device_id 0 --role p --local_host_ip 10.170.10 --remote_host_ip 10.170.10 + GLOO_SOCKET_IFNAME=enp67s0f5 HCCL_INTRA_ROCE_ENABLE=1 python switch_role_sample.py --device_id 0 --role p --local_host_ip 10.170.10.0 --remote_host_ip 10.170.10.1 # Decoder主机: - LOCAL_COMM_RES='{"status": "completed", "version": "1.0", "server_list": [{"server_id": "node_1", "device": [{"device_id": "1", "device_ip": "10.10.10.2"}]}]}' GLOO_SOCKET_IFNAME=enp67s0f5 HCCL_INTRA_ROCE_ENABLE=1 python switch_role_sample.py --device_id 1 --role d --local_host_ip 10.170.10 --remote_host_ip 10.170.10 + GLOO_SOCKET_IFNAME=enp67s0f5 HCCL_INTRA_ROCE_ENABLE=1 python switch_role_sample.py --device_id 1 --role d --local_host_ip 10.170.10.1 --remote_host_ip 10.170.10.0 ``` - **注**:**LOCAL_COMM_RES**为单侧建链方式执行所需环境变量,配置了当前进程所需的通信资源,将传递给llm_datadist作为初始化option; 配置格式与HCCL的ranktable一致,只需要配置本进程参数device_id对应的信息,其中ranktable中的rank_id和server_count字段不需要配置,当前用例配置为A2的ranktable格式,其他环境需参考对应环境的ranktable格式进行配置;**GLOO_SOCKET_IFNAME**为本地网卡名,可通过ifconfig查询;**HCCL_INTRA_ROCE_ENABLE=1**代表使用roce方式进行通信; + **注**:**GLOO_SOCKET_IFNAME**为本地网卡名,可通过ifconfig查询;**HCCL_INTRA_ROCE_ENABLE=1**代表使用roce方式进行通信; diff --git a/python/level1_single_api/10_llm_data_dist/cache_manager_api_samples/switch_role_sample.py b/python/level1_single_api/10_llm_data_dist/cache_manager_api_samples/switch_role_sample.py index 616e62eee..299f48c99 100644 --- a/python/level1_single_api/10_llm_data_dist/cache_manager_api_samples/switch_role_sample.py +++ b/python/level1_single_api/10_llm_data_dist/cache_manager_api_samples/switch_role_sample.py @@ -47,9 +47,8 @@ def init_llm_datadist(role: LLMRole, cluster_id, device_id: int, local_host_ip, datadist = LLMDataDist(role, cluster_id) llm_config = LLMConfig() llm_config.device_id = device_id - if os.getenv('LOCAL_COMM_RES') is None: - raise Exception("env:LOCAL_COMM_RES is not set") - llm_config.local_comm_res = os.getenv('LOCAL_COMM_RES') + llm_config.enable_cache_manager = True + llm_config.enable_remote_cache_accessible = True if role == LLMRole.PROMPT: llm_config.listen_ip_info = f"{local_host_ip}:26000" llm_options = llm_config.generate_options() @@ -58,7 +57,8 @@ def init_llm_datadist(role: LLMRole, cluster_id, device_id: int, local_host_ip, return datadist -def run_prompt_sample(datadist, remote_host_ip): +def run_prompt_sample(datadist, local_host_ip, remote_host_ip): + # 1. 注册内存 cache_manager = datadist.cache_manager cache_desc = CacheDesc(num_tensors=NUM_TENSORS, shape=[BLOCKS_NUM, KV_SHAPE], data_type=DataType.DT_FLOAT, placement=Placement.DEVICE) @@ -68,25 +68,34 @@ def run_prompt_sample(datadist, remote_host_ip): addr2 = int(tensor2.data_ptr()) cache = cache_manager.register_blocks_cache(cache_desc, [addr, addr2], BlocksCacheKey(PROMPT_CLUSTER_ID, 0)) logging.info('register_blocks_cache success') + + # 2. 等decoder pull blocks dist.barrier() # register end logging.info('wait decoder link and pull...') dist.barrier() # decoder unlink + # 3. 切换角色 datadist.switch_role(LLMRole.DECODER) dist.barrier() # prompt switch role end, close lisen dist.barrier() # decoder switch role end, lisen + # 4. 向decoder发起建链 cluster = LLMClusterInfo() + cluster.remote_cluster_id = DECODER_CLUSTER_ID + cluster.append_local_ip_info(local_host_ip, 26000) cluster.append_remote_ip_info(remote_host_ip, 26000) ret, _ = datadist.link_clusters([cluster], 5000) if ret != LLMStatusCode.LLM_SUCCESS: raise Exception("link failed") logging.info('link success, wait decoder push...') dist.barrier() # prompt link end + + # 5. 等decoder push blocks dist.barrier() # decoder push blocks end logging.info(f'after decoder push, {tensor=}') logging.info(f'after decoder push, {tensor2=}') + # 6. 解链 cluster = LLMClusterInfo() cluster.remote_cluster_id = DECODER_CLUSTER_ID ret, _ = datadist.unlink_clusters([cluster], 5000, force=True) @@ -99,6 +108,7 @@ def run_prompt_sample(datadist, remote_host_ip): def run_decoder_sample(datadist, local_host_ip, remote_host_ip): + # 1. 注册内存 cache_manager = datadist.cache_manager cache_desc = CacheDesc(num_tensors=NUM_TENSORS, shape=[BLOCKS_NUM, KV_SHAPE], data_type=DataType.DT_FLOAT, placement=Placement.DEVICE) @@ -110,16 +120,21 @@ def run_decoder_sample(datadist, local_host_ip, remote_host_ip): logging.info('register_blocks_cache success') dist.barrier() # register end + # 2. 向prompt建链 cluster = LLMClusterInfo() + cluster.remote_cluster_id = PROMPT_CLUSTER_ID + cluster.append_local_ip_info(local_host_ip, 26000) cluster.append_remote_ip_info(remote_host_ip, 26000) ret, _ = datadist.link_clusters([cluster], 5000) if ret != LLMStatusCode.LLM_SUCCESS: - raise Exception("unlink failed") + raise Exception("link failed") + # 3. 从prompt pull blocks cache_manager.pull_blocks(BlocksCacheKey(PROMPT_CLUSTER_ID, 0), cache, src_blocks=[0, 1], dst_blocks=[0, 2]) logging.info(f'after decoder pull, {tensor=}') logging.info(f'after decoder pull, {tensor2=}') + # 4. 断链并切换角色 cluster = LLMClusterInfo() cluster.remote_cluster_id = PROMPT_CLUSTER_ID cluster.append_remote_ip_info(remote_host_ip, 26000) @@ -134,12 +149,17 @@ def run_decoder_sample(datadist, local_host_ip, remote_host_ip): llm_options = llm_config.generate_options() datadist.switch_role(LLMRole.PROMPT, llm_options) logging.info('decoder link, pull, unlink, switch role success, wait prompt link...') + + # 5. 等待prompt发起建链 dist.barrier() # decoder switch role end, lisen dist.barrier() # prompt link end + # 6. 向prompt push blocks cache_manager.push_blocks(BlocksCacheKey(PROMPT_CLUSTER_ID, 0), cache, src_blocks=[0, 1, 2], dst_blocks=[0, 1,2], src_layer_range=range(0, 2), dst_layer_range=range(0, 2), tensor_num_per_layer=1) dist.barrier() # decoder push blocks end + + # 7. 断链 cluster = LLMClusterInfo() cluster.remote_cluster_id = PROMPT_CLUSTER_ID ret, _ = datadist.unlink_clusters([cluster], 5000, force=True) @@ -172,7 +192,7 @@ if __name__ == '__main__': cluster_id = PROMPT_CLUSTER_ID if args.role == 'p' else DECODER_CLUSTER_ID datadist = init_llm_datadist(role, cluster_id, args.device_id, args.local_host_ip, args.remote_host_ip) if role == LLMRole.PROMPT: - run_prompt_sample(datadist, args.remote_host_ip) + run_prompt_sample(datadist, args.local_host_ip, args.remote_host_ip) else: run_decoder_sample(datadist, args.local_host_ip, args.remote_host_ip) logging.info('Sample end') -- Gitee