diff --git a/WORKSPACE b/WORKSPACE index 72dcfb865565d66fb11a8a2999c0ec693d50a4b1..fdd5e4429e2a757856968db8108d74d287616233 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -176,6 +176,15 @@ http_archive( build_file = "@//bazel:jacoco.bzl" ) +local_patched_repository( + name = "gloo", + path = "../thirdparty/gloo", + build_file = "@//bazel:gloo.bzl", + patch_files = [ + "@//patch:gloo-fix-sign-compare.patch", + ], +) + maybe( new_local_repository, name = "securec", diff --git a/api/cpp/BUILD.bazel b/api/cpp/BUILD.bazel index 8d9ea9f85665f93f405dc66acab46cd747cc9ba0..28522f41a207d65df480aeef3f559e3cdbebc417 100644 --- a/api/cpp/BUILD.bazel +++ b/api/cpp/BUILD.bazel @@ -23,12 +23,15 @@ cc_library( "src/utils/*.cpp", "src/executor/*.h", "src/executor/*.cpp", + "src/collective/*.h", + "src/collective/*.cpp", ]) + ["src/utils/version.h"], hdrs = glob([ "include/yr/*.h", "include/yr/api/*.h", "include/yr/parallel/*.h", "include/yr/parallel/detail/*.h", + "include/yr/collective/*.h", ]), copts = COPTS, linkopts = ["-lstdc++fs"], @@ -43,6 +46,7 @@ cc_library( "@securec", "@nlohmann_json", "@yaml-cpp", + "@gloo//:gloo", ], alwayslink = True, ) @@ -64,6 +68,7 @@ cc_library( hdrs = glob([ "include/yr/*.h", "include/yr/api/*.h", + "include/yr/collective/*.h", "include/yr/parallel/*.h", "include/yr/parallel/detail/*.h", ]), @@ -160,6 +165,7 @@ genrule( ":cpp_strip", "//:grpc_strip", "@boringssl//:shared", + "@gloo//:shared", ":cpp_include", ], outs = ["yr_cpp_pkg.out"], @@ -214,6 +220,7 @@ genrule( chrpath -d $(locations :cpp_strip) $(locations //:grpc_strip) && cp -rf $(locations :cpp_strip) $(locations //:grpc_strip) $$CPP_SDK_DIR/lib/ && cp -rf $(locations :cpp_strip) $$CPP_SDK_DIR/bin/ && + cp -rf $(locations @gloo//:shared) $$CPP_SDK_DIR/lib/ && rm -rf $$CPP_SDK_DIR/bin/libyr-api.so $$CPP_SDK_DIR/bin/libfunctionsdk.so cp -rf $$BASE_DIR/api/cpp/include $$CPP_SDK_DIR/ && DATASYSTEM_DIR=$$BASE_DIR/external/datasystem_sdk/cpp && diff --git a/api/cpp/example/collective_example.cpp 
b/api/cpp/example/collective_example.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3431354eca030fc1c7bb8a7b1a13926c89b66676 --- /dev/null +++ b/api/cpp/example/collective_example.cpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "yr/api/exception.h" +#include "yr/collective/collective.h" +#include "yr/yr.h" + +class CollectiveActor { +public: + int count; + + CollectiveActor() = default; + + ~CollectiveActor() = default; + + static CollectiveActor *FactoryCreate() + { + return new CollectiveActor(); + } + + int Compute(std::vector in, std::string &groupName, uint8_t op) + { + //! [allreduce example] + std::vector output(in.size()); + + // AllReduce: perform reduction across all processes and broadcast result to all + YR::Collective::AllReduce(in.data(), output.data(), in.size(), YR::DataType::INT, YR::ReduceOp(op), groupName); + + // Barrier: synchronize all processes + YR::Collective::Barrier(groupName); + YR::Collective::DestroyCollectiveGroup(groupName); + int result = 0; + for (int i = 0; i < in.size(); ++i) { + result += output[i]; + } + return result; + //! [allreduce example] + } + + // Example: Initialize collective group in actor + void InitCollectiveGroupExample() + { + //! 
[init collective group] + YR::Collective::CollectiveGroupSpec spec; + spec.worldSize = 4; + spec.groupName = "my_group"; + spec.backend = YR::Collective::Backend::GLOO; + spec.timeout = 60000; + + int rank = 0; // Current process rank + YR::Collective::InitCollectiveGroup(spec, rank); + + // Use collective communication operations... + //! [init collective group] + } + + // Example: Get world size + void GetWorldSizeExample() + { + //! [get world size] + // Initialize collective communication group... + std::string groupName = "my_group"; + int worldSize = YR::Collective::GetWorldSize(groupName); + std::cout << "World size: " << worldSize << std::endl; + //! [get world size] + } + + // Example: Get rank + void GetRankExample() + { + //! [get rank] + // Initialize collective communication group... + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + std::cout << "My rank: " << rank << std::endl; + //! [get rank] + } + + // Example: Reduce operation + void ReduceExample() + { + //! [reduce example] + std::string groupName = "my_group"; + std::vector localData(100, 1); + std::vector result(100); + + int rootRank = 0; // Result sent to rank 0 + YR::Collective::Reduce(localData.data(), result.data(), localData.size(), YR::DataType::INT, YR::ReduceOp::SUM, + rootRank, groupName); + + int rank = YR::Collective::GetRank(groupName); + if (rank == rootRank) { + // Only rootRank's result contains valid result + std::cout << "Reduced result: " << result[0] << std::endl; + } + //! [reduce example] + } + + // Example: AllGather operation + void AllGatherExample() + { + //! 
[allgather example] + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int worldSize = YR::Collective::GetWorldSize(groupName); + + // Each process sends different data + std::vector localData(10, static_cast(rank)); + std::vector gatheredData(10 * worldSize); + + YR::Collective::AllGather(localData.data(), gatheredData.data(), localData.size(), YR::DataType::FLOAT, + groupName); + //! [allgather example] + } + + // Example: Broadcast operation + void BroadcastExample() + { + //! [broadcast example] + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int srcRank = 0; + + std::vector data(100); + if (rank == srcRank) { + // Source process initializes data + for (int i = 0; i < 100; i++) { + data[i] = i; + } + } + + YR::Collective::Broadcast(data.data(), data.data(), data.size(), YR::DataType::INT, srcRank, groupName); + // All processes' data now contains the same content (from srcRank) + //! [broadcast example] + } + + // Example: Scatter operation + void ScatterExample() + { + //! [scatter example] + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int worldSize = YR::Collective::GetWorldSize(groupName); + int srcRank = 0; + + std::vector recvData(10); + + if (rank == srcRank) { + // Source process prepares data to scatter + std::vector> sendData(worldSize, std::vector(10)); + for (int i = 0; i < worldSize; i++) { + for (int j = 0; j < 10; j++) { + sendData[i][j] = i * 10 + j; + } + } + + std::vector sendbuf; + for (auto &vec : sendData) { + sendbuf.push_back(vec.data()); + } + + YR::Collective::Scatter(sendbuf, recvData.data(), 10, YR::DataType::INT, srcRank, groupName); + } else { + // Non-source processes + std::vector sendbuf; // Can be empty + YR::Collective::Scatter(sendbuf, recvData.data(), 10, YR::DataType::INT, srcRank, groupName); + } + + // Each process's recvData now contains corresponding data from srcRank + //! 
[scatter example] + } + + // Example: Barrier synchronization + void BarrierExample() + { + //! [barrier example] + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + + // Perform some asynchronous operations... + std::cout << "Rank " << rank << " before barrier" << std::endl; + + YR::Collective::Barrier(groupName); + + // All processes will wait here until all processes arrive + std::cout << "Rank " << rank << " after barrier" << std::endl; + //! [barrier example] + } + + // Example: Send Recv operation + void SendRecvExample() + { + //! [send recv example] + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int worldSize = YR::Collective::GetWorldSize(groupName); + + if (rank == 0) { + // Rank 0 sends data to Rank 1 + std::vector data(100, 1.0f); + YR::Collective::Send(data.data(), data.size(), YR::DataType::FLOAT, 1, 0, groupName); + } else if (rank == 1) { + // Rank 1 receives data from Rank 0 + std::vector recvData(100); + YR::Collective::Recv(recvData.data(), recvData.size(), YR::DataType::FLOAT, 0, 0, groupName); + } + //! [send recv example] + } +}; + +YR_INVOKE(CollectiveActor::FactoryCreate, &CollectiveActor::Compute, &CollectiveActor::InitCollectiveGroupExample, + &CollectiveActor::GetWorldSizeExample, &CollectiveActor::GetRankExample, &CollectiveActor::ReduceExample, + &CollectiveActor::AllGatherExample, &CollectiveActor::BroadcastExample, &CollectiveActor::ScatterExample, + &CollectiveActor::BarrierExample, &CollectiveActor::SendRecvExample) + +int main(void) +{ + //! 
[create collective group] + YR::Config conf; + YR::Init(conf); + + // Create collective communication group with 4 instances + std::vector> instances; + std::vector instanceIDs; + for (int i = 0; i < 4; ++i) { + auto ins = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + instances.push_back(ins); + instanceIDs.push_back(ins.GetInstanceId()); + } + + std::string groupName = "test-group"; + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = 4, + .groupName = groupName, + .backend = YR::Collective::Backend::GLOO, + .timeout = YR::Collective::DEFAULT_COLLECTIVE_TIMEOUT, + }; + YR::Collective::CreateCollectiveGroup(spec, instanceIDs, {0, 1, 2, 3}); + + // Invoke collective operations on all instances + std::vector input = {1, 2, 3, 4}; + std::vector> res; + for (int i = 0; i < 4; ++i) { + res.push_back(instances[i] + .Function(&CollectiveActor::Compute) + .Invoke(input, groupName, static_cast(YR::ReduceOp::SUM))); + } + + // Get results + auto res0 = *YR::Get(res[0]); // allreduce result returned by the rank-0 instance + auto res1 = *YR::Get(res[1]); // allreduce result returned by the rank-1 instance + std::cout << "AllReduce result: " << res0 << ", Recv result: " << res1 << std::endl; + + // Destroy the collective group + YR::Collective::DestroyCollectiveGroup(groupName); + YR::Finalize(); + return 0; + //!
[create collective group] +} diff --git a/api/cpp/include/yr/api/kv.h b/api/cpp/include/yr/api/kv.h index 92eb65c6b42a21739ca744e425204000eb458446..5d0a2163cb520168f1202f6ce28cf22180a7ff40 100644 --- a/api/cpp/include/yr/api/kv.h +++ b/api/cpp/include/yr/api/kv.h @@ -15,7 +15,7 @@ */ #pragma once -#include "yr/api/check_initialized.h" + #include "yr/api/kv_manager.h" namespace YR { diff --git a/api/cpp/include/yr/api/kv_manager.h b/api/cpp/include/yr/api/kv_manager.h index 7ae114e9c825c1163d0500dd9104614e5e7e55f5..fe387d8546a2efd17f89905a710d5465036e048e 100644 --- a/api/cpp/include/yr/api/kv_manager.h +++ b/api/cpp/include/yr/api/kv_manager.h @@ -21,7 +21,9 @@ #include "constant.h" #include "runtime_manager.h" +#include "yr/api/check_initialized.h" #include "yr/api/exception.h" +#include "yr/api/local_mode_runtime.h" #include "yr/api/serdes.h" namespace YR { @@ -32,6 +34,7 @@ public: static KVManager kvManager; return kvManager; } + /** * @brief Sets the value of a key. * diff --git a/api/cpp/include/yr/api/named_instance.h b/api/cpp/include/yr/api/named_instance.h index 1cc6e7baad3634f23ccb8380777818f8de523297..7bf545c9b23ff602517553eb7a9a73551fa46125 100644 --- a/api/cpp/include/yr/api/named_instance.h +++ b/api/cpp/include/yr/api/named_instance.h @@ -135,6 +135,8 @@ public: const std::string GetInstanceId(); + const std::string GetObjectId(); + void SetAlwaysLocalMode(bool isLocalMode); void SetName(const std::string &instanceName); @@ -434,6 +436,12 @@ std::vector> NamedInstance::BuildInsta template const std::string NamedInstance::GetInstanceId() +{ + return YR::internal::GetRuntime()->GetRealInstanceId(instanceId); +} + +template +const std::string NamedInstance::GetObjectId() { return instanceId; } diff --git a/api/cpp/include/yr/collective/collective.h b/api/cpp/include/yr/collective/collective.h new file mode 100644 index 0000000000000000000000000000000000000000..05170b081f998984b82fcb5138c6860e1d3ce638 --- /dev/null +++ 
b/api/cpp/include/yr/collective/collective.h @@ -0,0 +1,413 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "reduce_op.h" + +namespace YR::Collective { +/*! + * @enum Backend + * @brief Backend type for collective communication. + */ +enum Backend : uint8_t { + GLOO = 0, ///< Use GLOO backend for collective communication. + INVALID, ///< Invalid backend type. +}; + +/*! + * @var const int DEFAULT_COLLECTIVE_TIMEOUT + * @brief Default timeout for collective communication operations in milliseconds. + */ +const int DEFAULT_COLLECTIVE_TIMEOUT = 60 * 1000; // ms + +/*! + * @var const std::string DEFAULT_GROUP_NAME + * @brief Default name for collective communication groups. + */ +const std::string DEFAULT_GROUP_NAME = "default"; + +/*! + * @struct CollectiveGroupSpec + * @brief Configuration specification for collective communication groups. + */ +struct CollectiveGroupSpec { + int worldSize; ///< Total number of processes in the group (world size). + std::string groupName = + "default"; ///< Name of the group, default is "default". Must match regex: ^[a-zA-Z0-9\-_!#%\^\*\(\)\+\=\:;]*$ + Backend backend = Backend::GLOO; ///< Backend type to use, default is GLOO. + int timeout = DEFAULT_COLLECTIVE_TIMEOUT; ///< Operation timeout in milliseconds, default is 60000 ms (60 seconds). 
+}; + +/*! + * @class CollectiveGroup + * @brief Abstract base class for collective communication groups, used to manage collective communication operations in + * distributed environments. + * + * The CollectiveGroup class provides a unified interface for collective communication operations, including AllReduce, + * Reduce, AllGather, Broadcast, etc. Different backend implementations (such as GLOO) can inherit from this class and + * provide specific implementations. + */ +class CollectiveGroup { +public: + CollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, std::string storePrefix) + : groupName_(groupSpec.groupName), + rank_(rank), + backend_(groupSpec.backend), + worldSize_(groupSpec.worldSize), + timeout_(groupSpec.timeout), + storePrefix_(std::move(storePrefix)) + { + } + + virtual ~CollectiveGroup() = default; + + virtual void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) = 0; + + virtual void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) = 0; + + virtual void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) = 0; + + virtual void Barrier() = 0; + + virtual void Scatter(const std::vector sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) = 0; + + virtual void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) = 0; + + virtual void Recv(void *recvbuf, int count, DataType dtype, int srcRank, int tag) = 0; + + virtual void Send(const void *sendbuf, int count, DataType dtype, int dstRank, int tag) = 0; + + int GetRank() const; + std::string GetGroupName(); + Backend GetBackend(); + int GetWorldSize() const; + +protected: + std::string groupName_; + + int rank_; + + Backend backend_; + + int worldSize_; + + int timeout_; + + std::string storePrefix_; +}; + +/*! + * @brief Initializes a collective communication group in an actor instance. 
+ * + * This function is used to initialize a collective communication group in the current actor instance. + * Typically in distributed training or parallel computing scenarios, each process needs to call this function to join a + * collective communication group. + * + * @param groupSpec Configuration specification for the collective communication group, including worldSize, groupName, + * backend, and timeout. + * @param rank Rank of the current process in the group, should be in the range [0, worldSize-1]. + * @param prefix Storage prefix for key-value storage used by backend communication. Default is an empty string. + * @note Must be called after YR::Init(). The same groupName cannot be initialized repeatedly, otherwise an exception + * will be thrown. groupName must match the regex: ^[a-zA-Z0-9\-_!#%\^\*\(\)\+\=\:;]*$ + * @throws Exception Thrown if called before initialization, if groupName is invalid, or if the collective group already + * exists. + * + * @snippet{trimleft} collective_example.cpp init collective group + */ +void InitCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, const std::string &prefix = ""); + +/*! + * @brief Creates a collective communication group in the driver using actor instance IDs. + * + * This function is used to create a collective communication group in the driver process, specifying the actor + * instances participating in collective communication and their corresponding ranks. Typically in distributed training + * scenarios, the driver process calls this function to create the group, and then each actor instance joins the group + * through InitCollectiveGroup. + * + * @param groupSpec Configuration specification for the collective communication group. + * @param instanceIDs List of actor instance IDs, size must equal worldSize. + * @param ranks List of ranks corresponding to each instance, size must equal worldSize, and rank values should be in + * the range [0, worldSize-1]. 
+ * @note The size of instanceIDs must equal worldSize. The size of ranks must equal worldSize. + * If groupName already exists, an exception will be thrown. You need to call DestroyCollectiveGroup first to + * destroy the existing group. + * @throws Exception Thrown if instanceIDs, ranks, and worldSize don't match, if groupName already exists, or if + * groupName is invalid. + * + * @snippet{trimleft} collective_example.cpp create collective group + */ +void CreateCollectiveGroup(const CollectiveGroupSpec &groupSpec, const std::vector &instanceIDs, + const std::vector &ranks); + +/*! + * @brief Destroys the specified collective communication group. + * + * This function is used to destroy a created collective communication group and release related resources. + * + * @param groupName Name of the group to destroy. + * @note If the group doesn't exist, this function won't throw an exception and will handle it silently. + */ +void DestroyCollectiveGroup(const std::string &groupName); + +/*! + * @brief Performs a reduction operation across all processes and broadcasts the result to all processes. + * + * This function performs a reduction operation (such as sum, max, etc.) across all processes and writes the result back + * to all processes' recvbuf. All processes' recvbuf will eventually contain the same result. + * + * @param sendbuf Send buffer containing local input data. All processes' sendbuf should have the same size. + * @param recvbuf Receive buffer for storing the reduction result. Size should be the same as sendbuf. + * @param count Number of data elements. + * @param dtype Data type (DataType::INT, DataType::FLOAT, DataType::DOUBLE, DataType::LONG). + * @param op Reduction operator (ReduceOp::SUM, ReduceOp::PRODUCT, ReduceOp::MIN, ReduceOp::MAX). + * @param groupName Name of the group, default is "default". + * @note sendbuf and recvbuf can point to the same memory (in-place operation). All processes must call this function + * with consistent parameters. 
Must be called after the group is created and initialized. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp allreduce example + */ +void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Performs a reduction operation across all processes and sends the result to the specified destination process. + * + * This function performs a reduction operation across all processes, but only the dstRank process's recvbuf will + * contain the reduction result. The recvbuf content of other processes is undefined. + * + * @param sendbuf Send buffer containing local input data. + * @param recvbuf Receive buffer, only used in the dstRank process, for storing the reduction result. Other processes' + * recvbuf content is undefined. + * @param count Number of data elements. + * @param dtype Data type. + * @param op Reduction operator. + * @param dstRank Rank of the destination process where the reduction result will be sent. + * @param groupName Name of the group, default is "default". + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp reduce example + */ +void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, int dstRank, + const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Gathers data from all processes and broadcasts the result to all processes. + * + * This function gathers data from all processes and writes the gathered data back to all processes' recvbuf in rank + * order. The size of recvbuf should be count * worldSize. + * + * @param sendbuf Send buffer containing local data to send. + * @param recvbuf Receive buffer for storing data gathered from all processes. Size should be count * worldSize. 
+ * @param count Number of data elements sent by each process. + * @param dtype Data type. + * @param groupName Name of the group, default is "default". + * @note The size of recvbuf must be at least count * worldSize. Gathered data is arranged in rank order in recvbuf. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp allgather example + */ +void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, + const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Synchronization barrier that blocks until all processes in the group reach this point. + * + * This function is used to synchronize all processes, ensuring that all processes reach the Barrier call point before + * executing subsequent code. + * + * @param groupName Name of the group, default is "default". + * @note All processes must call this function, otherwise it will cause a deadlock. This function blocks until all + * processes reach the Barrier call point. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp barrier example + */ +void Barrier(const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Scatters data from the source process to all processes. + * + * This function scatters data from the srcRank process's sendbuf vector to all processes. + * The sendbuf vector of the srcRank process should contain worldSize buffers, each buffer corresponding to a target + * rank. + * + * @param sendbuf Send buffer vector, only used in the srcRank process. Size should be worldSize, each element points to + * data to be sent to the corresponding rank. + * @param recvbuf Receive buffer for storing data received from srcRank. + * @param count Number of data elements received by each process. + * @param dtype Data type. + * @param srcRank Rank of the source process. 
+ * @param groupName Name of the group, default is "default". + * @note The sendbuf vector is only used in the srcRank process; other processes can pass an empty vector. On srcRank, + * the size of the sendbuf vector must equal worldSize. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp scatter example + */ +void Scatter(const std::vector sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, + const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Broadcasts data from the source process to all processes. + * + * This function broadcasts data from the srcRank process to all processes in the group. + * All processes' recvbuf will eventually contain the same data (from srcRank's sendbuf). + * + * @param sendbuf Send buffer, only used in the srcRank process, containing data to broadcast. + * @param recvbuf Receive buffer for storing the broadcast data. All processes' recvbuf will eventually contain the same + * data. + * @param count Number of data elements. + * @param dtype Data type. + * @param srcRank Rank of the source process. + * @param groupName Name of the group, default is "default". + * @note sendbuf and recvbuf can point to the same memory (in-place operation). All processes must call this function + * with consistent parameters. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp broadcast example + */ +void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, + const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Receives data from the specified process. + * + * This function is used for point-to-point communication, receiving data from the specified source process. + * Must be paired with the Send function, matched through the tag parameter. + * + * @param recvbuf Receive buffer for storing received data.
+ * @param count Number of data elements to receive. + * @param dtype Data type. + * @param srcRank Rank of the source process. + * @param tag Message tag for matching the corresponding Send operation, default is 0. + * @param groupName Name of the group, default is "default". + * @note Recv and Send must be paired, and tags must match. Recv operation blocks until the corresponding Send operation + * is called. count and dtype must match the corresponding Send operation. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp send recv example + */ +void Recv(void *recvbuf, int count, DataType dtype, int srcRank, int tag = 0, + const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Sends data to the specified process. + * + * This function is used for point-to-point communication, sending data to the specified destination process. + * Must be paired with the Recv function, matched through the tag parameter. + * + * @param sendbuf Send buffer containing data to send. + * @param count Number of data elements to send. + * @param dtype Data type. + * @param dstRank Rank of the destination process. + * @param tag Message tag for matching the corresponding Recv operation, default is 0. + * @param groupName Name of the group, default is "default". + * @note Send and Recv must be paired, and tags must match. Send operation blocks until the corresponding Recv operation + * is called. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp send recv example + */ +void Send(const void *sendbuf, int count, DataType dtype, int dstRank, int tag = 0, + const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Gets the total number of processes in the specified group. + * + * This function returns the total number of processes (world size) in the collective communication group. 
+ * + * @param groupName Name of the group, default is "default". + * @return The total number of processes (world size) in the group. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp get world size + */ +int GetWorldSize(const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @brief Gets the rank of the current process in the specified group. + * + * This function returns the rank of the current process in the collective communication group. + * + * @param groupName Name of the group, default is "default". + * @return The rank of the current process in the group, in the range [0, worldSize-1]. + * @throws Exception Thrown if the group doesn't exist or hasn't been created yet. + * + * @snippet{trimleft} collective_example.cpp get rank + */ +int GetRank(const std::string &groupName = DEFAULT_GROUP_NAME); + +/*! + * @class CollectiveGroupMgr + * @brief Manager class for collective communication groups (internal use). + * + * This class manages the lifecycle of collective communication groups using a singleton pattern. + * It is used internally by the collective communication API functions. + */ +class CollectiveGroupMgr { +public: + /*! + * @brief Gets the singleton instance of CollectiveGroupMgr. + * @return Reference to the singleton instance. + */ + static CollectiveGroupMgr &GetInstance() + { + static CollectiveGroupMgr instance; + return instance; + } + + /*! + * @brief Checks if a group exists, and creates it if it doesn't exist. + * @param groupName Name of the group. + * @return Shared pointer to the CollectiveGroup instance. + */ + std::shared_ptr CheckAndCreateGroup(const std::string &groupName); + + /*! + * @brief Initializes a collective communication group (internal use). + * @param groupSpec Configuration specification for the collective communication group. + * @param rank Rank of the current process. + * @param prefix Storage prefix for backend communication. 
+ */ + void InitCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, const std::string &prefix); + + /*! + * @brief Destroys a collective communication group (internal use). + * @param groupName Name of the group to destroy. + */ + void DestroyCollectiveGroup(const std::string &groupName); + +private: + CollectiveGroupMgr() = default; + + ~CollectiveGroupMgr(); + + std::recursive_mutex mtx_{}; + + std::unordered_map> groups_{}; +}; + +} // namespace YR::Collective diff --git a/api/cpp/include/yr/collective/reduce_op.h b/api/cpp/include/yr/collective/reduce_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ce60810eafa85055850ddd48847f56faee6f39f3 --- /dev/null +++ b/api/cpp/include/yr/collective/reduce_op.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace YR { + +/*! + * @enum DataType + * @brief Data type enumeration for collective communication operations. + */ +enum DataType : uint8_t { + INT = 0, ///< Integer type (int). + DOUBLE, ///< Double precision floating point type (double). + LONG, ///< Long integer type (long). + FLOAT, ///< Single precision floating point type (float). + INVALID ///< Invalid data type. +}; + +/*! + * @var const std::unordered_map DATA_TYPE_SIZE_MAP + * @brief Mapping from DataType to its size in bytes.
+ */ +const static std::unordered_map DATA_TYPE_SIZE_MAP = {{DataType::INT, sizeof(int)}, + {DataType::DOUBLE, sizeof(double)}, + {DataType::LONG, sizeof(long)}, + {DataType::FLOAT, sizeof(float)}}; + +/*! + * @enum ReduceOp + * @brief Reduction operator enumeration for collective communication operations. + */ +enum ReduceOp : uint8_t { + SUM = 0, ///< Sum operation. + PRODUCT = 1, ///< Product operation. + MIN = 2, ///< Minimum operation. + MAX = 3, ///< Maximum operation. +}; +} // namespace YR diff --git a/api/cpp/include/yr/yr.h b/api/cpp/include/yr/yr.h index c557733d8e33ccc799179a1f385c4ecfd4ae47bd..1a9dc19ba4ab345d33ec6b3c79d9469f629def62 100644 --- a/api/cpp/include/yr/yr.h +++ b/api/cpp/include/yr/yr.h @@ -42,6 +42,7 @@ #include "yr/api/wait_result.h" #include "yr/api/yr_core.h" #include "yr/api/yr_invoke.h" +#include "yr/collective/collective.h" extern thread_local std::unordered_set localNestedObjList; namespace YR { diff --git a/api/cpp/src/collective/collective.cpp b/api/cpp/src/collective/collective.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d14820d061b403d19f2011acbdb3de2f8b9463f3 --- /dev/null +++ b/api/cpp/src/collective/collective.cpp @@ -0,0 +1,290 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "yr/collective/collective.h" + +#include +#include +#include +#include + +#include "api/cpp/src/utils/utils.h" +#include "gloo_collective_group.h" +#include "src/dto/config.h" +#include "src/libruntime/err_type.h" +#include "src/utility/id_generator.h" +#include "src/utility/logger/logger.h" +#include "yr/api/exception.h" +#include "yr/api/kv_manager.h" + +namespace YR::Collective { + +const std::string COLLECTIVE_GROUP_INFO_PREFIX = "collective-group-"; + +struct CollectiveGroupInfo { + CollectiveGroupSpec groupSpec; + std::vector instances; + std::vector ranks; + std::string prefix; + + std::string ToString() + { + nlohmann::json j = + nlohmann::json{{"instances", instances}, {"worldSize", groupSpec.worldSize}, {"ranks", ranks}, + {"backend", groupSpec.backend}, {"timeout", groupSpec.timeout}, {"prefix", prefix}}; + return j.dump(); + } + + void FromJson(const std::string &str) + { + nlohmann::json j; + try { + j = nlohmann::json::parse(str); + } catch (const std::exception &e) { + YRLOG_WARN("json parse error: {}", e.what()); + return; + } + + j.at("instances").get_to(instances); + j.at("worldSize").get_to(groupSpec.worldSize); + j.at("ranks").get_to(ranks); + j.at("backend").get_to(groupSpec.backend); + j.at("timeout").get_to(groupSpec.timeout); + j.at("prefix").get_to(prefix); + } +}; + +int CollectiveGroup::GetRank() const +{ + return rank_; +} + +std::string CollectiveGroup::GetGroupName() +{ + return groupName_; +} + +Backend CollectiveGroup::GetBackend() +{ + return backend_; +} + +int CollectiveGroup::GetWorldSize() const +{ + return worldSize_; +} + +void CheckGroupNameValid(const std::string &groupName) +{ + static const std::regex GROUP_NAME_REGEX(R"(^[a-zA-Z0-9\-_!#%\^\*\(\)\+\=\:;]*$)"); + ThrowIfTrue(!std::regex_match(groupName, GROUP_NAME_REGEX), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + R"(groupName is invalid. 
It should match the regex: ^[a-zA-Z0-9\-_!#%^\*\(\)\+\=\:;]*$)"); +} + +void InitCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, const std::string &prefix) +{ + CheckInitialized(); + CollectiveGroupMgr::GetInstance().InitCollectiveGroup(groupSpec, rank, prefix); +} + +void CreateCollectiveGroup(const CollectiveGroupSpec &groupSpec, const std::vector &instanceIDs, + const std::vector &ranks) +{ + CheckInitialized(); + CheckGroupNameValid(groupSpec.groupName); + + ThrowIfTrue(instanceIDs.size() != static_cast(groupSpec.worldSize) || + static_cast(groupSpec.worldSize) != ranks.size(), + YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "failed to create collective group, unequal actor, rank or world size, please check"); + + CollectiveGroupInfo info{.groupSpec = groupSpec, + .instances = instanceIDs, + .ranks = ranks, + .prefix = YR::utility::IDGenerator::GenObjectId()}; + + try { + YR::KVManager::Set(COLLECTIVE_GROUP_INFO_PREFIX + groupSpec.groupName, info.ToString(), YR::ExistenceOpt::NX); + } catch (YR::Exception &e) { + throw YR::Exception("collective group " + groupSpec.groupName + " already existed, please destroy it first"); + } +} + +void DestroyCollectiveGroup(const std::string &groupName) +{ + CheckInitialized(); + CollectiveGroupMgr::GetInstance().DestroyCollectiveGroup(groupName); +} + +void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->AllReduce(sendbuf, recvbuf, count, dtype, op); +} + +void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, int dstRank, + const std::string &groupName) +{ + CheckInitialized(); + auto group = 
CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->Reduce(sendbuf, recvbuf, count, dtype, op, dstRank); +} + +void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->AllGather(sendbuf, recvbuf, count, dtype); +} + +void Barrier(const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->Barrier(); +} + +void Scatter(const std::vector sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, + const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->Scatter(sendbuf, recvbuf, count, dtype, srcRank); +} + +void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->Broadcast(sendbuf, recvbuf, count, dtype, srcRank); +} + +void Recv(void *recvbuf, int count, DataType dtype, int srcRank, int tag, const std::string &groupName) +{ + CheckInitialized(); + auto group = 
CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->Recv(recvbuf, count, dtype, srcRank, tag); +} + +void Send(const void *sendbuf, int count, DataType dtype, int dstRank, int tag, const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + group->Send(sendbuf, count, dtype, dstRank, tag); +} + +int GetWorldSize(const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + return group->GetWorldSize(); +} + +int GetRank(const std::string &groupName) +{ + CheckInitialized(); + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + ThrowIfTrue(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + return group->GetRank(); +} + +std::shared_ptr CollectiveGroupMgr::CheckAndCreateGroup(const std::string &groupName) +{ + std::lock_guard lock(mtx_); + if (groups_.find(groupName) != groups_.end()) { + return groups_[groupName]; + } + + auto str = YR::KVManager::Get(COLLECTIVE_GROUP_INFO_PREFIX + groupName); + ThrowIfTrue(str.empty(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + " first"); + + CollectiveGroupInfo info; + info.FromJson(str); + info.groupSpec.groupName = groupName; + + size_t index = 0; + for (; index < info.instances.size(); ++index) { + if (info.instances.at(index) == YR::Libruntime::Config::Instance().INSTANCE_ID()) { + break; + } + } + + ThrowIfTrue(index == info.instances.size(), 
YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid instance id"); + InitCollectiveGroup(info.groupSpec, info.ranks[index], info.prefix); + return groups_[groupName]; +} + +void CollectiveGroupMgr::InitCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, const std::string &prefix) +{ + CheckGroupNameValid(groupSpec.groupName); + std::lock_guard lock(mtx_); + if (groups_.find(groupSpec.groupName) != groups_.end()) { + YRLOG_DEBUG("collective group({}) already existed", groupSpec.groupName); + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "collective group already existed"); + } + + std::shared_ptr group; + switch (groupSpec.backend) { + case Backend::GLOO: + group = std::make_shared(groupSpec, rank, prefix); + break; + + case Backend::INVALID: + // fall-through + default: + YRLOG_ERROR("failed to init collective group({}), invalid backend: {}", groupSpec.groupName, + groupSpec.backend); + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid collective group backend"); + } + groups_[groupSpec.groupName] = group; +} + +void CollectiveGroupMgr::DestroyCollectiveGroup(const std::string &groupName) +{ + std::lock_guard lock(mtx_); + YR::KVManager::Del(COLLECTIVE_GROUP_INFO_PREFIX + groupName); + groups_.erase(groupName); +} + +CollectiveGroupMgr::~CollectiveGroupMgr() +{ + std::lock_guard lock(mtx_); + groups_.clear(); +} + +} // namespace YR::Collective \ No newline at end of file diff --git a/api/cpp/src/collective/gloo_collective_group.cpp b/api/cpp/src/collective/gloo_collective_group.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba5e5a1ab1b0b09bc29d39a6d790e6a65a690c5c --- /dev/null +++ b/api/cpp/src/collective/gloo_collective_group.cpp @@ -0,0 +1,306 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gloo_collective_group.h" + +#include "api/cpp/src/utils/utils.h" +#include "gloo/algorithm.h" +#include "gloo/allgather.h" +#include "gloo/barrier.h" +#include "gloo/broadcast.h" +#include "gloo/reduce.h" +#include "gloo/rendezvous/prefix_store.h" +#include "gloo/scatter.h" +#include "gloo/transport/ibverbs/device.h" +#include "gloo/transport/tcp/device.h" +#include "src/dto/config.h" +#include "yr/api/kv_manager.h" + +namespace YR::Collective { + +#define EXECUTE_BY_TYPE(dType, OPERATION, ...) \ + switch (dType) { \ + case DataType::INT: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::DOUBLE: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::LONG: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::FLOAT: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::INVALID: \ + default: \ + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, \ + "invalid dType: " + std::to_string(dType)); \ + } + +const int DS_STORE_CHECK_INTERVAL = 100; // ms + +std::string ReplaceDsStoreKey(const std::string &str) +{ + std::string editedKey = str; + // datasystem doesn't support key that contains '/', replace it with '@' + std::replace(editedKey.begin(), editedKey.end(), '/', '@'); + return editedKey; +} + +DsStore::~DsStore() +{ + clear(); +} + +void DsStore::set(const std::string &key, const std::vector &data) +{ + auto str = std::string(data.begin(), data.end()); + YR::KVManager::Set(ReplaceDsStoreKey(key), std::string(data.begin(), data.end())); + keys_.emplace(ReplaceDsStoreKey(key)); + 
std::this_thread::sleep_for(std::chrono::milliseconds(DS_STORE_CHECK_INTERVAL)); +} + +std::vector DsStore::get(const std::string &key) +{ + std::string result = YR::KVManager::Get(ReplaceDsStoreKey(key)); + std::vector out; + out.assign(result.begin(), result.end()); + return out; +} + +void DsStore::wait(const std::vector &keys) +{ + std::vector editedKeys; + for (const auto &key : keys) { + editedKeys.push_back(ReplaceDsStoreKey(key)); + } + YR::KVManager::Get(editedKeys); +} + +void DsStore::wait(const std::vector &keys, const std::chrono::milliseconds &timeout) +{ + std::vector editedKeys; + for (const auto &key : keys) { + editedKeys.push_back(ReplaceDsStoreKey(key)); + } + // Convert timeout to seconds for KVManager, ensuring at least 1s if timeout > 0 + int timeoutSec = static_cast(timeout.count() / S_TO_MS); + if (timeout.count() > 0 && timeoutSec == 0) { + timeoutSec = 1; + } + YR::KVManager::Get(editedKeys, timeoutSec); +} + +void DsStore::clear() +{ + // clear ds kv + std::vector dels(keys_.begin(), keys_.end()); + YR::KVManager::Del(dels); +} + +GlooCollectiveGroup::GlooCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, std::string storePrefix) + : CollectiveGroup(groupSpec, rank, std::move(storePrefix)) +{ + auto dsStore = std::make_shared(); + std::string prefixKey = groupName_; + if (!storePrefix_.empty()) { + prefixKey = groupName_ + "-" + storePrefix_; + } + store_ = dsStore; + auto prefixStore = std::make_shared(prefixKey, dsStore); + std::shared_ptr dev; + auto backend = GetEnv("GLOO_BACKEND_TYPE"); + + // Enable Gloo internal logs + setenv("GLOO_LOG_LEVEL", "INFO", 1); + if (backend.empty() || backend == "TCP") { + gloo::transport::tcp::attr attr; + attr.hostname = YR::Libruntime::Config::Instance().HOST_IP(); + std::cout << "[DEBUG] Gloo TCP Device binding to HOST_IP: " << attr.hostname << std::endl; + dev = gloo::transport::tcp::CreateDevice(attr); + } else if (backend == "IBVERBS") { + gloo::transport::ibverbs::attr attr; + 
attr.name = GetEnv("GLOO_IBVERBS_NAME"); + + auto indexStr = GetEnv("GLOO_IBVERBS_INDEX"); + auto portStr = GetEnv("GLOO_IBVERBS_PORT"); + try { + attr.index = indexStr.empty() ? 0 : std::stoi(indexStr); + attr.port = portStr.empty() ? 0 : std::stoi(portStr); + } catch (const std::exception &e) { + throw YR::Exception( + YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "invalid GLOO_IBVERBS_INDEX: " + indexStr + ", or invalid GLOO_IBVERBS_PORT: " + portStr); + } + dev = gloo::transport::ibverbs::CreateDevice(attr); + } else { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "invalid backend type: " + backend); + } + + context_ = std::make_shared(rank_, worldSize_); + context_->setTimeout(std::chrono::milliseconds(timeout_)); + context_->connectFullMesh(prefixStore, dev); +} + +GlooCollectiveGroup::~GlooCollectiveGroup() +{ + if (store_) { + store_->clear(); + } +} + +void GlooCollectiveGroup::AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) +{ + EXECUTE_BY_TYPE(dtype, DoAllReduce, sendbuf, recvbuf, count, op); +} + +void GlooCollectiveGroup::Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) +{ + EXECUTE_BY_TYPE(dtype, DoReduce, sendbuf, recvbuf, count, op, dstRank); +} + +void GlooCollectiveGroup::AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) +{ + EXECUTE_BY_TYPE(dtype, DoAllGather, sendbuf, recvbuf, count); +} + +void GlooCollectiveGroup::Barrier() +{ + std::lock_guard lock(mtx_); + gloo::BarrierOptions opts(context_); + gloo::barrier(opts); +} + +void GlooCollectiveGroup::Scatter(const std::vector sendbuf, void *recvbuf, int count, DataType dtype, + int srcRank) +{ + EXECUTE_BY_TYPE(dtype, DoScatter, sendbuf, recvbuf, count, srcRank); +} + +void GlooCollectiveGroup::Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) +{ + EXECUTE_BY_TYPE(dtype, DoBroadcast, sendbuf, recvbuf, count, 
srcRank); +} + +void GlooCollectiveGroup::Recv(void *recvbuf, int count, DataType dtype, int srcRank, int tag) +{ + ThrowIfTrue(DATA_TYPE_SIZE_MAP.find(dtype) == DATA_TYPE_SIZE_MAP.end(), + YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid dtype: " + std::to_string(dtype)); + std::lock_guard lock(mtx_); + auto ubuf = context_->createUnboundBuffer(recvbuf, count * DATA_TYPE_SIZE_MAP.at(dtype)); + ubuf->recv(srcRank, tag); + ubuf->waitRecv(std::chrono::milliseconds(timeout_)); +} + +void GlooCollectiveGroup::Send(const void *sendbuf, int count, DataType dtype, int dstRank, int tag) +{ + ThrowIfTrue(DATA_TYPE_SIZE_MAP.find(dtype) == DATA_TYPE_SIZE_MAP.end(), + YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid dtype: " + std::to_string(dtype)); + std::lock_guard lock(mtx_); + auto ubuf = context_->createUnboundBuffer(const_cast(sendbuf), count * DATA_TYPE_SIZE_MAP.at(dtype)); + ubuf->send(dstRank, tag); + ubuf->waitSend(std::chrono::milliseconds(timeout_)); +} + +template +gloo::AllreduceOptions::Func GlooCollectiveGroup::GetReduceOp(const ReduceOp &op) +{ + void (*fn)(void *, const void *, const void *, long unsigned int) = &gloo::sum; + switch (op) { + case ReduceOp::SUM: + fn = &gloo::sum; + break; + case ReduceOp::PRODUCT: + fn = &gloo::product; + break; + case ReduceOp::MIN: + fn = &gloo::min; + break; + case ReduceOp::MAX: + fn = &gloo::max; + break; + } + return fn; +} + +template +void GlooCollectiveGroup::DoAllReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op) +{ + std::lock_guard lock(mtx_); + gloo::AllreduceOptions opts(context_); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); + opts.setReduceFunction(GetReduceOp(op)); + gloo::allreduce(opts); +} + +template +void GlooCollectiveGroup::DoReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op, int dstRank) +{ + std::lock_guard lock(mtx_); + gloo::ReduceOptions opts(context_); + opts.setInput(const_cast((T 
*)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); + opts.setRoot(dstRank); + opts.setReduceFunction(GetReduceOp(op)); + gloo::reduce(opts); +} + +template +void GlooCollectiveGroup::DoBroadcast(const void *sendbuf, void *recvbuf, int count, int srcRank) +{ + std::lock_guard lock(mtx_); + gloo::BroadcastOptions opts(context_); + if (srcRank == GetRank()) { + opts.setInput(const_cast((T *)sendbuf), count); + } + + opts.setOutput(static_cast(recvbuf), count); + opts.setRoot(srcRank); + gloo::broadcast(opts); +} + +template +void GlooCollectiveGroup::DoScatter(const std::vector sendbuf, void *recvbuf, int count, int srcRank) +{ + std::lock_guard lock(mtx_); + gloo::ScatterOptions opts(context_); + if (srcRank == GetRank()) { + std::vector inputs; + for (auto item : sendbuf) { + inputs.push_back(const_cast((T *)item)); + } + opts.setInputs(inputs, count); + } + + opts.setOutput(static_cast(recvbuf), count); + opts.setRoot(srcRank); + gloo::scatter(opts); +} + +template +void GlooCollectiveGroup::DoAllGather(const void *sendbuf, void *recvbuf, int count) +{ + std::lock_guard lock(mtx_); + gloo::AllgatherOptions opts(context_); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count * GetWorldSize()); + gloo::allgather(opts); +} +} // namespace YR::Collective \ No newline at end of file diff --git a/api/cpp/src/collective/gloo_collective_group.h b/api/cpp/src/collective/gloo_collective_group.h new file mode 100644 index 0000000000000000000000000000000000000000..639c5af63ea71e35f6058c0b0fbf309a67d767bb --- /dev/null +++ b/api/cpp/src/collective/gloo_collective_group.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include "gloo/allreduce.h" +#include "gloo/rendezvous/context.h" +#include "yr/collective/collective.h" + +namespace YR::Collective { + +const int DS_STORE_CHECK_DEFAULT_TIMEOUT = 5 * 60 * 1000; // ms + +struct DsStore : public gloo::rendezvous::Store { +public: + DsStore() = default; + ~DsStore() override; + + void set(const std::string &key, const std::vector &data) override; + + std::vector get(const std::string &key) override; + + void wait(const std::vector &keys) override; + + void wait(const std::vector &keys, const std::chrono::milliseconds &timeout) override; + + void clear(); + +private: + std::unordered_set keys_; +}; + +class GlooCollectiveGroup : public CollectiveGroup { +public: + GlooCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, std::string storePrefix); + + ~GlooCollectiveGroup() override; + + void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) override; + + void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) override; + + void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) override; + + void Barrier() override; + + void Scatter(const std::vector sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) override; + + void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) override; + + void Recv(void *recvbuf, int count, DataType dtype, int srcRank, int tag) override; + + void Send(const void *sendbuf, 
int count, DataType dtype, int dstRank, int tag) override; + +private: + template + void DoAllReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op); + + template + void DoReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op, int dstRank); + + template + void DoAllGather(const void *sendbuf, void *recvbuf, int count); + + template + void DoScatter(const std::vector sendbuf, void *recvbuf, int count, int srcRank); + + template + void DoBroadcast(const void *sendbuf, void *recvbuf, int count, int srcRank); + + template + static gloo::AllreduceOptions::Func GetReduceOp(const ReduceOp &op); + + std::unordered_map map; + + std::shared_ptr context_; + std::shared_ptr store_; + std::recursive_mutex mtx_{}; +}; +} // namespace YR::Collective \ No newline at end of file diff --git a/api/cpp/src/utils/utils.cpp b/api/cpp/src/utils/utils.cpp index 197f725f52a2537e370d911ff66375c0c58da90b..f5de23183cb2b426f04bdcff6efc063a8fcdd9be 100644 --- a/api/cpp/src/utils/utils.cpp +++ b/api/cpp/src/utils/utils.cpp @@ -27,6 +27,14 @@ const int FUNCTION_NAME_INDEX = 5; const int FUNCTION_VERSION_INDEX = 6; const int URN_CUT_NUM = 7; +void ThrowIfTrue(bool condition, Libruntime::ErrorCode code, const std::string &msg) +{ + if (condition) { + YRLOG_ERROR(msg); + throw YR::Exception(code, msg); + } +} + std::string ConvertFunctionUrnToId(const std::string &functionUrn) { char sep = ':'; diff --git a/api/cpp/src/utils/utils.h b/api/cpp/src/utils/utils.h index 683260df43dae728eb3ec1aee86ada3eb4e026b9..70ba5394fdfa83a1648cc2b1252d6e17ec695237 100644 --- a/api/cpp/src/utils/utils.h +++ b/api/cpp/src/utils/utils.h @@ -22,6 +22,8 @@ #include "src/dto/data_object.h" namespace YR { +void ThrowIfTrue(bool condition, Libruntime::ErrorCode code, const std::string &msg); + std::string ConvertFunctionUrnToId(const std::string &functionUrn); YR::Libruntime::ErrorInfo WriteDataObject(const void *data, std::shared_ptr &dataObj, uint64_t size, const std::unordered_set 
&nestedIds); diff --git a/bazel/gloo.bzl b/bazel/gloo.bzl new file mode 100644 index 0000000000000000000000000000000000000000..1ca3aa6500c8017e4ce567e315845f991b3558bb --- /dev/null +++ b/bazel/gloo.bzl @@ -0,0 +1,105 @@ +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") +load("@yuanrong_multi_language_runtime//bazel:yr.bzl", "filter_files_with_suffix") + +# Files to exclude from the build +GLOO_EXCLUDES = [ + "gloo/benchmark/**", + "gloo/examples/**", + "gloo/test/**", + "gloo/cuda/**", + "gloo/cuda*", + "gloo/hip/**", + "gloo/nccl/**", + "gloo/mpi/**", + "gloo/rendezvous/redis_store*", + "gloo/transport/uv/**", + "gloo/transport/tcp/tls/**", + "gloo/common/win.*", +] + +# All source files - use glob to include all .cc files except excluded ones +GLOO_ALL_SRCS = glob( + ["gloo/**/*.cc"], + exclude = GLOO_EXCLUDES, +) + +# All header files - use glob to include all .h files except excluded ones +GLOO_ALL_HDRS = glob( + ["gloo/**/*.h"], + exclude = GLOO_EXCLUDES, +) + +genrule( + name = "config_header", + outs = ["gloo/config.h"], + cmd = """ + cat <<'EOF' > $@ +#pragma once + +#define GLOO_VERSION_MAJOR 0 +#define GLOO_VERSION_MINOR 5 +#define GLOO_VERSION_PATCH 0 + +static_assert(GLOO_VERSION_MINOR < 100, "Invalid GLOO minor version"); +static_assert(GLOO_VERSION_PATCH < 100, "Invalid GLOO patch version"); + +#define GLOO_VERSION (GLOO_VERSION_MAJOR * 10000 + GLOO_VERSION_MINOR * 100 + GLOO_VERSION_PATCH) + +#define GLOO_USE_CUDA 0 +#define GLOO_USE_NCCL 0 +#define GLOO_USE_ROCM 0 +#define GLOO_USE_RCCL 0 +#define GLOO_USE_REDIS 0 +#define GLOO_USE_IBVERBS 1 +#define GLOO_USE_MPI 0 +#define GLOO_USE_AVX 0 +#define GLOO_USE_LIBUV 0 + +#define GLOO_HAVE_TRANSPORT_TCP 1 +#define GLOO_HAVE_TRANSPORT_TCP_TLS 0 +#define GLOO_HAVE_TRANSPORT_IBVERBS 1 +#define GLOO_HAVE_TRANSPORT_UV 0 + +#define GLOO_USE_TORCH_DTYPES 0 +EOF + """, +) + +cc_library( + name = "gloo", + srcs = GLOO_ALL_SRCS, + hdrs = GLOO_ALL_HDRS + [ + ":config_header", + ], + includes = ["."], + 
copts = [ + "-std=c++17", + "-fPIC", + ], + linkopts = [ + "-lpthread", + "-libverbs", + ], + visibility = ["//visibility:public"], +) + +cc_binary( + name = "gloo_shared_lib", + linkshared = True, + deps = [":gloo"], +) + +genrule( + name = "libgloo_so", + srcs = [":gloo_shared_lib"], + outs = ["libgloo_shared.so"], + cmd = "cp $(location :gloo_shared_lib) $@", + visibility = ["//visibility:public"], +) + +filter_files_with_suffix( + name = "shared", + srcs = [":libgloo_so"], + suffix = ".so", + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/docs/multi_language_function_programming_interface/api/distributed_programming/C++/Collective.md b/docs/multi_language_function_programming_interface/api/distributed_programming/C++/Collective.md new file mode 100644 index 0000000000000000000000000000000000000000..b0852dc3d83376b7d92cceb38bdb821c4c70aade --- /dev/null +++ b/docs/multi_language_function_programming_interface/api/distributed_programming/C++/Collective.md @@ -0,0 +1,73 @@ +# Collective + +## CollectiveGroup + +The fields of the `CollectiveGroupSpec` parameter structure are described below: + +```{doxygenstruct} YR::Collective::CollectiveGroupSpec + :members: +``` + +Type and enumeration definitions: + +```{doxygenenum} YR::Collective::Backend +``` + +```{doxygenenum} YR::DataType +``` + +```{doxygenenum} YR::ReduceOp +``` + +Constants: + +```{doxygenvariable} YR::Collective::DEFAULT_COLLECTIVE_TIMEOUT +``` + +```{doxygenvariable} YR::Collective::DEFAULT_GROUP_NAME +``` +## Collective-GroupOps + +Collective communication group management interfaces for creating, initializing, and destroying collective communication groups. 
+ +```{doxygenfunction} YR::Collective::InitCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, const std::string &prefix = "") +``` + +```{doxygenfunction} YR::Collective::CreateCollectiveGroup(const CollectiveGroupSpec &groupSpec, const std::vector &instanceIDs, const std::vector &ranks) +``` + +```{doxygenfunction} YR::Collective::DestroyCollectiveGroup(const std::string &groupName) +``` + +```{doxygenfunction} YR::Collective::GetWorldSize(const std::string &groupName = DEFAULT_GROUP_NAME) +``` + +```{doxygenfunction} YR::Collective::GetRank(const std::string &groupName = DEFAULT_GROUP_NAME) +``` +## Collective-CommOps + +Collective communication operation interfaces, providing distributed communication primitives such as AllReduce, Reduce, AllGather, Broadcast, etc. + +```{doxygenfunction} YR::Collective::AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, const std::string &groupName = DEFAULT_GROUP_NAME) +``` + +```{doxygenfunction} YR::Collective::Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, int dstRank, const std::string &groupName = DEFAULT_GROUP_NAME) +``` + +```{doxygenfunction} YR::Collective::AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, const std::string &groupName = DEFAULT_GROUP_NAME) +``` + +```{doxygenfunction} YR::Collective::Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName = DEFAULT_GROUP_NAME) +``` + +```{doxygenfunction} YR::Collective::Scatter(const std::vector sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName = DEFAULT_GROUP_NAME) +``` + +```{doxygenfunction} YR::Collective::Barrier(const std::string &groupName = DEFAULT_GROUP_NAME) +``` + +```{doxygenfunction} YR::Collective::Send(const void *sendbuf, int count, DataType dtype, int dstRank, int tag = 0, const std::string &groupName = DEFAULT_GROUP_NAME) +``` + 
+```{doxygenfunction} YR::Collective::Recv(void *recvbuf, int count, DataType dtype, int srcRank, int tag = 0, const std::string &groupName = DEFAULT_GROUP_NAME) +``` diff --git a/docs/multi_language_function_programming_interface/api/distributed_programming/C++/index.rst b/docs/multi_language_function_programming_interface/api/distributed_programming/C++/index.rst index 3f07ce190c512ce3818e636280ff3bac83b3fa5a..49bc5f9c853e2df8466803aebfd2ba8dfe0ab8e1 100644 --- a/docs/multi_language_function_programming_interface/api/distributed_programming/C++/index.rst +++ b/docs/multi_language_function_programming_interface/api/distributed_programming/C++/index.rst @@ -57,7 +57,14 @@ C++ :maxdepth: 1 Stream - + +.. toctree:: + :glob: + :hidden: + :maxdepth: 1 + + Collective + .. toctree:: :glob: :hidden: @@ -206,7 +213,18 @@ Stream API :widths: 30 70 * - :doc:`Stream` - - Passing ordered unbounded data between openYuanrong functions. + - Passing ordered unbounded data between openYuanrong functions. + + +Collective Communication API +----------------------------- + +.. list-table:: + :header-rows: 0 + :widths: 30 70 + + * - :doc:`Collective` + - Complete documentation for collective communication, covering type definitions, group management APIs, and collective operators such as AllReduce, Reduce, AllGather, Broadcast, Scatter, Barrier, Send, Recv. Function Interoperation API diff --git a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/Collective.rst b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/Collective.rst new file mode 100644 index 0000000000000000000000000000000000000000..029a48c955f542a2f79b73b834bb56030c252e0b --- /dev/null +++ b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/Collective.rst @@ -0,0 +1,648 @@ +Collective +========== + +CollectiveGroup +--------------- + +类型和枚举定义 +---------------- + +.. 
cpp:enum:: YR::Collective::Backend : uint8_t + + 集合通信后端类型。 + + .. cpp:enumerator:: GLOO = 0 + + 使用 GLOO 后端进行集合通信。 + + .. cpp:enumerator:: INVALID + + 无效的后端类型。 + +.. cpp:enum:: YR::DataType : uint8_t + + 数据类型枚举,定义在 ``reduce_op.h`` 中。 + + .. cpp:enumerator:: INT = 0 + + 整数类型(int)。 + + .. cpp:enumerator:: DOUBLE + + 双精度浮点数类型(double)。 + + .. cpp:enumerator:: LONG + + 长整数类型(long)。 + + .. cpp:enumerator:: FLOAT + + 单精度浮点数类型(float)。 + + .. cpp:enumerator:: INVALID + + 无效的数据类型。 + +.. cpp:enum:: YR::ReduceOp : uint8_t + + 规约操作符枚举,定义在 ``reduce_op.h`` 中。 + + .. cpp:enumerator:: SUM = 0 + + 求和操作。 + + .. cpp:enumerator:: PRODUCT = 1 + + 乘积操作。 + + .. cpp:enumerator:: MIN = 2 + + 最小值操作。 + + .. cpp:enumerator:: MAX = 3 + + 最大值操作。 + +参数结构补充说明如下: + +.. cpp:struct:: YR::Collective::CollectiveGroupSpec + + 集合通信组的配置规格。 + + **公共成员** + + .. cpp:member:: int worldSize + + 组中的进程总数(world size)。 + + .. cpp:member:: std::string groupName = "default" + + 组的名称,默认为 "default"。组名必须匹配正则表达式:``^[a-zA-Z0-9\-_!#%\^\*\(\)\+\=\:;]*$`` + + .. cpp:member:: Backend backend = Backend::GLOO + + 使用的后端类型,默认为 GLOO。 + + .. cpp:member:: int timeout = DEFAULT_COLLECTIVE_TIMEOUT + + 操作超时时间(毫秒),默认为 60000 毫秒(60 秒)。 + +常量定义 +--------- + +.. cpp:var:: const int YR::Collective::DEFAULT_COLLECTIVE_TIMEOUT + + 默认的集合通信超时时间,值为 60000 毫秒(60 秒)。 + +.. cpp:var:: const std::string YR::Collective::DEFAULT_GROUP_NAME + + 默认的组名称,值为 "default"。 + +Collective-GroupOps +------------------- + +集合通信组管理接口,用于创建、初始化和销毁集合通信组。 + +InitCollectiveGroup +-------------------- + +.. cpp:function:: void YR::Collective::InitCollectiveGroup(const CollectiveGroupSpec &groupSpec, int rank, \ + const std::string &prefix = "") + + 在 actor 实例中初始化集合通信组。 + + 此函数用于在当前 actor 实例中初始化一个集合通信组。通常在分布式训练或并行计算场景中,每个进程可以调用此函数来加入集合通信组。 + + .. 
code-block:: cpp + + YR::Collective::CollectiveGroupSpec spec; + spec.worldSize = 4; + spec.groupName = "my_group"; + spec.backend = YR::Collective::Backend::GLOO; + spec.timeout = 60000; + + int rank = 0; // 当前进程的 rank + YR::Collective::InitCollectiveGroup(spec, rank); + + // 使用集合通信操作... + return 0; + + + 参数: + - **groupSpec** - 集合通信组的配置规格,包含 worldSize、groupName、backend 和 timeout。 + - **rank** - 当前进程在组中的排名(rank),范围应为 [0, worldSize-1]。 + - **prefix** - 存储前缀,用于后端通信的键值存储。默认为空字符串。 + + .. note:: + 限制条件 + + - 必须在调用 ``YR::Init()`` 之后才能调用此函数。 + - 同一个 groupName 不能重复初始化,否则会抛出异常。 + - groupName 必须匹配正则表达式:``^[a-zA-Z0-9\-_!#%\^\*\(\)\+\=\:;]*$`` + + 抛出: + :cpp:class:`Exception` - 以下情形会抛出异常: + + - 如果在初始化前调用此函数,将抛出异常。 + - 如果 groupName 无效,将抛出异常。 + - 如果集合通信组已存在,将抛出异常。 + +CreateCollectiveGroup +---------------------- + +.. cpp:function:: void YR::Collective::CreateCollectiveGroup(const CollectiveGroupSpec &groupSpec, \ + const std::vector &instanceIDs, \ + const std::vector &ranks) + + 在 driver 中创建集合通信组,使用 actor 实例 ID 列表。 + + 此函数用于在 driver 进程中创建集合通信组,指定参与集合通信的 actor 实例及其对应的 rank。通常在分布式训练场景中,driver 进程会调用此函数来创建组,然后各个 actor 实例不用通过 InitCollectiveGroup 加入组。 + + .. 
code-block:: cpp + + #include "yr/collective/collective.h" + + int main() { + YR::Config conf; + YR::Init(conf); + + // 创建包含4个actor的集合通信组 + std::vector> instances; + std::vector instanceIDs; + for (int i = 0; i < 4; ++i) { + auto ins = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + instances.push_back(ins); + instanceIDs.push_back(ins.GetInstanceId()); + } + + std::string groupName = "test-group"; + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = 4, + .groupName = groupName, + .backend = YR::Collective::Backend::GLOO, + .timeout = YR::Collective::DEFAULT_COLLECTIVE_TIMEOUT, + }; + YR::Collective::CreateCollectiveGroup(spec, instanceIDs, {0, 1, 2, 3}); + + // 调用每个actor上的计算方法 + std::vector input = {1, 2, 3, 4}; + std::vector> res; + for (int i = 0; i < 4; ++i) { + res.push_back(instances[i] + .Function(&CollectiveActor::Compute) + .Invoke(input, groupName, static_cast(YR::ReduceOp::SUM))); + } + + // Get 结果 + auto res0 = *YR::Get(res[0]); + auto res1 = *YR::Get(res[1]); + std::cout << "AllReduce result (rank 0): " << res0 << ", AllReduce result (rank 1): " << res1 << std::endl; + + // 销毁组 + YR::Collective::DestroyCollectiveGroup(groupName); + YR::Finalize(); + return 0; + } + + 参数: + - **groupSpec** - 集合通信组的配置规格。 + - **instanceIDs** - actor 实例 ID 列表,大小必须等于 worldSize。 + - **ranks** - 每个实例对应的 rank 列表,大小必须等于 worldSize,且 ranks 的值应在 [0, worldSize-1] 范围内。 + + .. note:: + 限制条件 + + - instanceIDs 的大小必须等于 worldSize。 + - ranks 的大小必须等于 worldSize。 + - 如果 groupName 已存在,将抛出异常,需要先调用 DestroyCollectiveGroup 销毁已存在的组。 + + 抛出: + :cpp:class:`Exception` - 以下情形会抛出异常: + + - 如果 instanceIDs、ranks 和 worldSize 不匹配,将抛出异常。 + - 如果 groupName 已存在,将抛出异常。 + - 如果 groupName 无效,将抛出异常。 + +DestroyCollectiveGroup +----------------------- + +.. cpp:function:: void YR::Collective::DestroyCollectiveGroup(const std::string &groupName) + + 销毁指定的集合通信组。 + + 此函数用于销毁一个已创建的集合通信组,释放相关资源。 + + 参数: + - **groupName** - 要销毁的组的名称。 + + .. note:: + 限制条件 + + - 如果组不存在,此函数不会抛出异常,会静默处理。 + +GetWorldSize +------------ + +.. 
cpp:function:: int YR::Collective::GetWorldSize(const std::string &groupName = DEFAULT_GROUP_NAME) + + 获取指定组中的进程总数。 + + 此函数返回集合通信组中的进程总数(world size)。 + + .. code-block:: cpp + + // 初始化集合通信组... + std::string groupName = "my_group"; + int worldSize = YR::Collective::GetWorldSize(groupName); + std::cout << "World size: " << worldSize << std::endl; + + 参数: + - **groupName** - 组的名称,默认为 "default"。 + + 返回: + 组中的进程总数(world size)。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +GetRank +------- + +.. cpp:function:: int YR::Collective::GetRank(const std::string &groupName = DEFAULT_GROUP_NAME) + + 获取当前进程在指定组中的排名。 + + 此函数返回当前进程在集合通信组中的排名(rank)。 + + .. code-block:: cpp + + // 初始化集合通信组... + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + std::cout << "My rank: " << rank << std::endl; + + 参数: + - **groupName** - 组的名称,默认为 "default"。 + + 返回: + 当前进程在组中的排名(rank),范围在 [0, worldSize-1]。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 +Collective-CommOps +------------------ + +集合通信操作接口,提供 AllReduce、Reduce、AllGather、Broadcast 等分布式通信原语。 + +AllReduce +---------- + +.. cpp:function:: void YR::Collective::AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, \ + const ReduceOp &op, const std::string &groupName = DEFAULT_GROUP_NAME) + + 在所有进程中执行规约操作,并将结果广播到所有进程。 + + 此函数在所有进程中执行规约操作(如求和、求最大值等),并将结果写回到所有进程的 recvbuf 中。所有进程的 recvbuf 最终包含相同的结果。 + + .. 
code-block:: cpp + + int Compute(std::vector in, std::string &groupName, uint8_t op) + { + // groupName 由调用方传入(例如 "my_group"),此处不能重复声明同名局部变量 + std::vector localData(100, 1); + std::vector result(100); + + // AllReduce:在所有进程执行规约并广播结果 + YR::Collective::AllReduce(localData.data(), result.data(), localData.size(), YR::DataType::INT, YR::ReduceOp::SUM, groupName); + + // Barrier:同步所有进程 + YR::Collective::Barrier(groupName); + YR::Collective::DestroyCollectiveGroup(groupName); + int res = 0; + for (int i = 0; i < localData.size(); ++i) { + res += result[i]; + } + return res; + } + + 参数: + - **sendbuf** - 发送缓冲区,包含本地输入数据。所有进程的 sendbuf 大小应相同。 + - **recvbuf** - 接收缓冲区,用于存储规约结果。大小应与 sendbuf 相同。 + - **count** - 数据元素的数量。 + - **dtype** - 数据类型(DataType::INT、DataType::FLOAT、DataType::DOUBLE、DataType::LONG)。 + - **op** - 规约操作符(ReduceOp::SUM、ReduceOp::PRODUCT、ReduceOp::MIN、ReduceOp::MAX)。 + - **groupName** - 组的名称,默认为 "default"。 + + .. note:: + 注意事项 + + - sendbuf 和 recvbuf 可以指向同一块内存(in-place 操作)。 + - 所有进程必须调用此函数,且参数 count、dtype、op 必须一致。 + - 必须在组创建并初始化后才能调用此函数。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +Reduce +------ + +.. cpp:function:: void YR::Collective::Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, \ + const ReduceOp &op, int dstRank, \ + const std::string &groupName = DEFAULT_GROUP_NAME) + + 在所有进程中执行规约操作,并将结果发送到指定的目标进程。 + + 此函数在所有进程中执行规约操作,但只有 dstRank 进程的 recvbuf 会包含规约结果,其他进程的 recvbuf 内容未定义。 + + .. 
code-block:: cpp + + std::string groupName = "my_group"; + std::vector localData(100, 1); + std::vector result(100); + + int rootRank = 0; // 规约结果发送到 rank 0 + YR::Collective::Reduce(localData.data(), result.data(), localData.size(), YR::DataType::INT, YR::ReduceOp::SUM, + rootRank, groupName); + + int rank = YR::Collective::GetRank(groupName); + if (rank == rootRank) { + // 只有 rootRank 的 result 包含有效结果 + std::cout << "Reduced result: " << result[0] << std::endl; + } + + 参数: + - **sendbuf** - 发送缓冲区,包含本地输入数据。 + - **recvbuf** - 接收缓冲区,仅在 dstRank 进程中使用,用于存储规约结果。其他进程的 recvbuf 内容未定义。 + - **count** - 数据元素的数量。 + - **dtype** - 数据类型。 + - **op** - 规约操作符。 + - **dstRank** - 目标进程的 rank,规约结果将发送到此进程。 + - **groupName** - 组的名称,默认为 "default"。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +AllGather +--------- + +.. cpp:function:: void YR::Collective::AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, \ + const std::string &groupName = DEFAULT_GROUP_NAME) + + 从所有进程收集数据,并将结果广播到所有进程。 + + 此函数从所有进程收集数据,并将收集到的数据按 rank 顺序排列后写回到所有进程的 recvbuf 中。recvbuf 的大小应为 count * worldSize。 + + .. code-block:: cpp + + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int worldSize = YR::Collective::GetWorldSize(groupName); + + // 每个进程发送不同的数据 + std::vector localData(10, static_cast(rank)); + std::vector gatheredData(10 * worldSize); + + YR::Collective::AllGather(localData.data(), gatheredData.data(), localData.size(), YR::DataType::FLOAT, + groupName); + + 参数: + - **sendbuf** - 发送缓冲区,包含本地要发送的数据。 + - **recvbuf** - 接收缓冲区,用于存储从所有进程收集的数据。大小应为 count * worldSize。 + - **count** - 每个进程发送的数据元素数量。 + - **dtype** - 数据类型。 + - **groupName** - 组的名称,默认为 "default"。 + + .. note:: + 注意事项 + + - recvbuf 的大小必须至少为 count * worldSize。 + - 收集的数据按 rank 顺序排列在 recvbuf 中。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +Broadcast +--------- + +.. 
cpp:function:: void YR::Collective::Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, \ + int srcRank, const std::string &groupName = DEFAULT_GROUP_NAME) + + 从源进程广播数据到所有进程。 + + 此函数从 srcRank 进程广播数据到组中的所有进程。所有进程的 recvbuf 最终包含相同的数据(来自 srcRank 的 sendbuf)。 + + .. code-block:: cpp + + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int srcRank = 0; + + std::vector data(100); + if (rank == srcRank) { + // 源进程初始化数据 + for (int i = 0; i < 100; i++) { + data[i] = i; + } + } + + YR::Collective::Broadcast(data.data(), data.data(), data.size(), YR::DataType::INT, srcRank, groupName); + // All processes' data now contains the same content (from srcRank) + + 参数: + - **sendbuf** - 发送缓冲区,仅在 srcRank 进程中使用,包含要广播的数据。 + - **recvbuf** - 接收缓冲区,用于存储广播的数据。所有进程的 recvbuf 最终包含相同的数据。 + - **count** - 数据元素的数量。 + - **dtype** - 数据类型。 + - **srcRank** - 源进程的 rank。 + - **groupName** - 组的名称,默认为 "default"。 + + .. note:: + 注意事项 + + - sendbuf 和 recvbuf 可以指向同一块内存(in-place 操作)。 + - 所有进程必须调用此函数,且参数 count、dtype、srcRank 必须一致。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +Scatter +------- + +.. cpp:function:: void YR::Collective::Scatter(const std::vector sendbuf, void *recvbuf, int count, \ + DataType dtype, int srcRank, \ + const std::string &groupName = DEFAULT_GROUP_NAME) + + 从源进程分散数据到所有进程。 + + 此函数从 srcRank 进程的 sendbuf 向量中分散数据到所有进程。srcRank 进程的 sendbuf 向量应包含 worldSize 个缓冲区,每个缓冲区对应一个目标 rank。 + + .. 
code-block:: cpp + + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int worldSize = YR::Collective::GetWorldSize(groupName); + int srcRank = 0; + + std::vector recvData(10); + + if (rank == srcRank) { + // 源进程准备要分发的数据 + std::vector> sendData(worldSize, std::vector(10)); + for (int i = 0; i < worldSize; i++) { + for (int j = 0; j < 10; j++) { + sendData[i][j] = i * 10 + j; + } + } + + std::vector sendbuf; + for (auto &vec : sendData) { + sendbuf.push_back(vec.data()); + } + + YR::Collective::Scatter(sendbuf, recvData.data(), 10, YR::DataType::INT, srcRank, groupName); + } else { + // 非源进程传入空向量 + std::vector sendbuf; // Can be empty + YR::Collective::Scatter(sendbuf, recvData.data(), 10, YR::DataType::INT, srcRank, groupName); + } + + // 每个进程的 recvData 包含来自 srcRank 的数据 + + 参数: + - **sendbuf** - 发送缓冲区向量,仅在 srcRank 进程中使用。大小应为 worldSize,每个元素指向要发送给对应 rank 的数据。 + - **recvbuf** - 接收缓冲区,用于存储从 srcRank 接收的数据。 + - **count** - 每个进程接收的数据元素数量。 + - **dtype** - 数据类型。 + - **srcRank** - 源进程的 rank。 + - **groupName** - 组的名称,默认为 "default"。 + + .. note:: + 注意事项 + + - sendbuf 向量仅在 srcRank 进程中使用,其他进程可以传入空向量。 + - sendbuf 向量的大小必须等于 worldSize。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +Barrier +------- + +.. cpp:function:: void YR::Collective::Barrier(const std::string &groupName = DEFAULT_GROUP_NAME) + + 同步屏障,阻塞直到组中的所有进程都到达此点。 + + 此函数用于同步所有进程,确保所有进程在执行后续代码之前都到达了 Barrier 调用点。 + + .. code-block:: cpp + + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + + // 执行一些异步操作... + std::cout << "Rank " << rank << " before barrier" << std::endl; + + YR::Collective::Barrier(groupName); + + // 所有进程在此同步等待 + std::cout << "Rank " << rank << " after barrier" << std::endl; + + 参数: + - **groupName** - 组的名称,默认为 "default"。 + + .. note:: + 注意事项 + + - 所有进程必须调用此函数,否则会导致死锁。 + - 此函数会阻塞,直到所有进程都到达 Barrier 调用点。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +Send +---- + +.. 
cpp:function:: void YR::Collective::Send(const void *sendbuf, int count, DataType dtype, int dstRank, \ + int tag = 0, const std::string &groupName = DEFAULT_GROUP_NAME) + + 向指定进程发送数据。 + + 此函数用于点对点通信,向指定的目标进程发送数据。需要与 Recv 函数配对使用,通过 tag 参数匹配对应的接收操作。 + + .. code-block:: cpp + + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int worldSize = YR::Collective::GetWorldSize(groupName); + + if (rank == 0) { + // Rank 0 向 Rank 1 发送数据 + std::vector data(100, 1.0f); + YR::Collective::Send(data.data(), data.size(), YR::DataType::FLOAT, 1, 0, groupName); + } else if (rank == 1) { + // Rank 1 接收 Rank 0 发送的数据 + std::vector recvData(100); + YR::Collective::Recv(recvData.data(), recvData.size(), YR::DataType::FLOAT, 0, 0, groupName); + } + + 参数: + - **sendbuf** - 发送缓冲区,包含要发送的数据。 + - **count** - 要发送的数据元素数量。 + - **dtype** - 数据类型。 + - **dstRank** - 目标进程的 rank。 + - **tag** - 消息标签,用于匹配对应的 Recv 操作,默认为 0。 + - **groupName** - 组的名称,默认为 "default"。 + + .. note:: + 注意事项 + + - Send 和 Recv 必须配对使用,且 tag 必须匹配。 + - Send 操作会阻塞,直到对应的 Recv 操作被调用。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + +Recv +---- + +.. cpp:function:: void YR::Collective::Recv(void *recvbuf, int count, DataType dtype, int srcRank, \ + int tag = 0, const std::string &groupName = DEFAULT_GROUP_NAME) + + 从指定进程接收数据。 + + 此函数用于点对点通信,从指定的源进程接收数据。需要与 Send 函数配对使用,通过 tag 参数匹配对应的发送操作。 + + .. 
code-block:: cpp + + std::string groupName = "my_group"; + int rank = YR::Collective::GetRank(groupName); + int worldSize = YR::Collective::GetWorldSize(groupName); + + if (rank == 0) { + // Rank 0 sends data to Rank 1 + std::vector data(100, 1.0f); + YR::Collective::Send(data.data(), data.size(), YR::DataType::FLOAT, 1, 0, groupName); + } else if (rank == 1) { + // Rank 1 receives data from Rank 0 + std::vector recvData(100); + YR::Collective::Recv(recvData.data(), recvData.size(), YR::DataType::FLOAT, 0, 0, groupName); + } + + 参数: + - **recvbuf** - 接收缓冲区,用于存储接收到的数据。 + - **count** - 要接收的数据元素数量。 + - **dtype** - 数据类型。 + - **srcRank** - 源进程的 rank。 + - **tag** - 消息标签,用于匹配对应的 Send 操作,默认为 0。 + - **groupName** - 组的名称,默认为 "default"。 + + .. note:: + 注意事项 + + - Recv 和 Send 必须配对使用,且 tag 必须匹配。 + - Recv 操作会阻塞,直到对应的 Send 操作被调用。 + - count 和 dtype 必须与对应的 Send 操作一致。 + + 抛出: + :cpp:class:`Exception` - 如果组不存在或尚未创建,将抛出异常。 + diff --git a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/index.rst b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/index.rst index 88d123ee3c0fb71f991156156ce4b5db64f32411..2d004a0ec66e7f7befcc9bed8445087901a78896 100644 --- a/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/index.rst +++ b/docs/multi_language_function_programming_interface/api/distributed_programming/zh_cn/C++/index.rst @@ -57,7 +57,14 @@ C++ :maxdepth: 1 Stream - + +.. toctree:: + :glob: + :hidden: + :maxdepth: 1 + + Collective + .. toctree:: :glob: :hidden: @@ -209,6 +216,17 @@ C++ - 在 openYuanrong 函数间传递共享的有序无界内存数据集。 +集合通信 API +----------------- + +.. 
list-table:: + :header-rows: 0 + :widths: 30 70 + + * - :doc:`Collective` + - 集合通信 API 全量说明,涵盖类型定义、组管理接口与集合通信算子。 + + 函数互调 API ----------------------- diff --git a/install_tools.sh b/install_tools.sh index 699c1e98561a2e8d830b0f7a23837c64659fa33e..5b623b820179bcb51a19656ce38e77e62b6eb9c0 100644 --- a/install_tools.sh +++ b/install_tools.sh @@ -35,7 +35,7 @@ sudo yum install -y \ sqlite-devel bison gawk texinfo glibc glibc-devel wget bzip2-devel sudo \ rsync nfs-utils xz libuuid unzip util-linux-devel cpio libcap-devel libatomic \ chrpath numactl-devel openeuler-lsb libasan dos2unix net-tools pigz cmake \ - protobuf protobuf-devel patchelf + protobuf protobuf-devel patchelf libibverbs-devel libibverbs echo "✅ Detected architecture: $ARCH → using package suffix: $PKG_ARCH" diff --git a/patch/gloo-fix-sign-compare.patch b/patch/gloo-fix-sign-compare.patch new file mode 100644 index 0000000000000000000000000000000000000000..2b0506abd1b93c00776376c9e59184b19efce604 --- /dev/null +++ b/patch/gloo-fix-sign-compare.patch @@ -0,0 +1,40 @@ +diff -ruN gloo/math.h gloo/math.h +--- gloo/math.h 2025-01-01 00:00:00.000000000 +0800 ++++ gloo/math.h 2025-01-01 00:00:00.000000000 +0800 +@@ -17,7 +17,7 @@ + T* c = static_cast(c_); + const T* a = static_cast(a_); + const T* b = static_cast(b_); +- for (auto i = 0; i < n; i++) { ++ for (size_t i = 0; i < n; i++) { + c[i] = a[i] + b[i]; + } + } +@@ -32,7 +32,7 @@ + T* c = static_cast(c_); + const T* a = static_cast(a_); + const T* b = static_cast(b_); +- for (auto i = 0; i < n; i++) { ++ for (size_t i = 0; i < n; i++) { + c[i] = a[i] * b[i]; + } + } +@@ -47,7 +47,7 @@ + T* c = static_cast(c_); + const T* a = static_cast(a_); + const T* b = static_cast(b_); +- for (auto i = 0; i < n; i++) { ++ for (size_t i = 0; i < n; i++) { + c[i] = std::max(a[i], b[i]); + } + } +@@ -62,7 +62,7 @@ + T* c = static_cast(c_); + const T* a = static_cast(a_); + const T* b = static_cast(b_); +- for (auto i = 0; i < n; i++) { ++ for (size_t i = 0; i < n; i++) 
{ + c[i] = std::min(a[i], b[i]); + } + } + diff --git a/test/api/local_mode_test.cpp b/test/api/local_mode_test.cpp index 6432ad06c42b3954fb628ba77bb4e04bdb716c66..914792fc2077e5f915618fd14ec1572677919589 100644 --- a/test/api/local_mode_test.cpp +++ b/test/api/local_mode_test.cpp @@ -240,7 +240,7 @@ TEST_F(LocalTest, When_Invoke_Actor_Should_Return_Final_Correct_Result) TEST_F(LocalTest, When_Invoke_Actor_Should_Not_Contain_WorkerId) { auto counter = YR::Instance(Counter::FactoryCreate).Invoke(1); - auto id = counter.GetInstanceId(); + auto id = counter.GetObjectId(); EXPECT_EQ(id.size(), 20); auto res = counter.Function(&Counter::Add).Invoke(3); auto v = *YR::Get(res); diff --git a/test/st/cpp/src/base/actor_test.cpp b/test/st/cpp/src/base/actor_test.cpp index 9987cc4a7a9771d923e371846c7873c01a86071b..548d0c0a46d84bfd0ecf9405e06e50d12de3b4e2 100644 --- a/test/st/cpp/src/base/actor_test.cpp +++ b/test/st/cpp/src/base/actor_test.cpp @@ -925,7 +925,7 @@ TEST_F(ActorTest, GroupInvokeAfterTerminate) TEST_F(ActorTest, CheckActorObjIdSuccessfully) { auto creator = YR::Instance(Counter::FactoryCreate).Invoke(1); - auto id = creator.GetInstanceId(); + auto id = creator.GetObjectId(); EXPECT_EQ(id.size(), 20); auto ret = creator.Function(&Counter::Add).Invoke(1); EXPECT_EQ(*YR::Get(ret), 2); diff --git a/test/st/cpp/src/base/collective_test.cpp b/test/st/cpp/src/base/collective_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..996c946cf4f1a190d6aee1f6a2e35069c01fdf1b --- /dev/null +++ b/test/st/cpp/src/base/collective_test.cpp @@ -0,0 +1,369 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "user_common_func.h" +#include "utils.h" +#include "yr/yr.h" + +using testing::HasSubstr; + +class CollectiveTest : public testing::Test { +public: + CollectiveTest() {}; + ~CollectiveTest() {}; + static void SetUpTestCase() {}; + static void TearDownTestCase() {}; + + void SetUp() + { + YR::Config config; + config.mode = YR::Config::Mode::CLUSTER_MODE; + auto info = YR::Init(config); + std::cout << "job id: " << info.jobId << std::endl; + }; + + void TearDown() + { + YR::Finalize(); + }; +}; + +/** + * @title: 试图创建非法groupName group + * @step: 1.创建CollectiveActor实例 + * @step: 2.调用CollectiveActor::InitCollectiveGroup,在actor中创建非法group + * @step: 3.调用CreateCollectiveGroup,创建非法group + * + * @expect: 2.预期异常抛出 + * @expect: 3.预期异常抛出 + */ +TEST_F(CollectiveTest, InvalidGroupNameTest) +{ + auto ins = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + std::string groupName1 = "test@group1"; + std::string groupName2 = "test/group1"; + auto res = ins.Function(&CollectiveActor::InitCollectiveGroup).Invoke(groupName1, 0, 1); + EXPECT_THROW_WITH_CODE_AND_MSG(YR::Get(res), 2002, "groupName is invalid. It should match the regex"); + + res = ins.Function(&CollectiveActor::InitCollectiveGroup).Invoke(groupName2, 0, 1); + EXPECT_THROW_WITH_CODE_AND_MSG(YR::Get(res), 2002, "groupName is invalid. 
It should match the regex"); + + YR::Collective::CollectiveGroupSpec spec1{ + .worldSize = 1, + .groupName = groupName1, + }; + EXPECT_THROW_WITH_CODE_AND_MSG(YR::Collective::CreateCollectiveGroup(spec1, {ins.GetInstanceId()}, {0}), 1001, + "groupName is invalid. It should match the regex"); + + YR::Collective::CollectiveGroupSpec spec2{ + .worldSize = 1, + .groupName = groupName2, + }; + EXPECT_THROW_WITH_CODE_AND_MSG(YR::Collective::CreateCollectiveGroup(spec2, {ins.GetInstanceId()}, {0}), 1001, + "groupName is invalid. It should match the regex"); +} + +/** + * @title: InitCollectiveGroup函数调用成功 + * @step: 1.创建CollectiveActor实例 + * @step: 2.调用CollectiveActor::InitCollectiveGroup,在actor中创建对应group + * @step: 3.调用CollectiveActor::Compute sum 函数,传入{1, 2, 3, 4, 5} + * @step: 4.调用CollectiveActor::Compute min 函数,传入{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1} + * @step: 5.调用CollectiveActor::Compute max 函数,传入{5, 4, 3, 2, 1}, {1, 2, 3, 4, 5} + * @step: 6.调用CollectiveActor::Compute product 函数,传入{5, 4, 3, 2, 1}, {1, 2, 3, 4, 5} + * + * @expect: 3.预期无异常抛出,返回值为两个actor中各项之和,30 + * @expect: 4.预期无异常抛出,返回值为两个actor中各项取最小值之和,9 + * @expect: 5.预期无异常抛出,返回值为两个actor中各项取最大值之和,21 + * @expect: 6.预期无异常抛出,返回值为两个actor中各项相乘之和,35 + */ +TEST_F(CollectiveTest, InitGroupInActorTest) +{ + auto ins1 = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + auto ins2 = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + + std::string groupName = "test-group1"; + YR::Collective::DestroyCollectiveGroup(groupName); + + ins1.Function(&CollectiveActor::InitCollectiveGroup).Invoke(groupName, 0, 2); + auto ret = ins2.Function(&CollectiveActor::InitCollectiveGroup).Invoke(groupName, 1, 2); + YR::Get(ret); + + std::vector input = {1, 2, 3, 4, 5}; + auto res1 = + ins1.Function(&CollectiveActor::Compute).Invoke(input, groupName, static_cast(YR::ReduceOp::SUM)); + auto res2 = + ins2.Function(&CollectiveActor::Compute).Invoke(input, groupName, static_cast(YR::ReduceOp::SUM)); + EXPECT_EQ(*YR::Get(res1), 30); + 
EXPECT_EQ(*YR::Get(res2), 30); + + input = {1, 2, 3, 4, 5}; + std::vector input2 = {5, 4, 3, 2, 1}; + res1 = ins1.Function(&CollectiveActor::Compute).Invoke(input, groupName, static_cast(YR::ReduceOp::MIN)); + res2 = ins2.Function(&CollectiveActor::Compute).Invoke(input2, groupName, static_cast(YR::ReduceOp::MIN)); + EXPECT_EQ(*YR::Get(res1), 9); + EXPECT_EQ(*YR::Get(res2), 9); + + res1 = ins1.Function(&CollectiveActor::Compute).Invoke(input2, groupName, static_cast(YR::ReduceOp::MAX)); + res2 = ins2.Function(&CollectiveActor::Compute).Invoke(input, groupName, static_cast(YR::ReduceOp::MAX)); + EXPECT_EQ(*YR::Get(res1), 21); + EXPECT_EQ(*YR::Get(res2), 21); + + res1 = + ins1.Function(&CollectiveActor::Compute).Invoke(input2, groupName, static_cast(YR::ReduceOp::PRODUCT)); + res2 = + ins2.Function(&CollectiveActor::Compute).Invoke(input, groupName, static_cast(YR::ReduceOp::PRODUCT)); + EXPECT_EQ(*YR::Get(res1), 35); + EXPECT_EQ(*YR::Get(res2), 35); + + res1 = ins1.Function(&CollectiveActor::DestroyCollectiveGroup).Invoke(groupName); + res2 = ins2.Function(&CollectiveActor::DestroyCollectiveGroup).Invoke(groupName); + YR::Get(res1); + YR::Get(res2); +} + +/** + * @title: CreateCollectiveGroup函数调用成功 + * @step: 1.创建CollectiveActor实例 + * @step: 2.调用CreateCollectiveGroup,在driver中创建对应group + * @step: 3.调用CollectiveActor::Compute函数,传入{1, 2, 3, 4} + * @step: 4.调用CollectiveActor::ComputeDouble函数,传入{1.1, 2.2, 3.3, 4.4} + * @step: 5.调用DestroyCollectiveGroup,清理CollectiveGroup + * + * @expect: 预期无异常抛出,返回值为各个actor中各项之和,40 和 44 + */ +TEST_F(CollectiveTest, CreateGroupInDriverTest) +{ + std::vector> instances; + std::vector instanceIDs; + for (int i = 0; i < 4; ++i) { + auto ins = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + instances.push_back(ins); + instanceIDs.push_back(ins.GetInstanceId()); + } + + std::string groupName = "test-group2"; + YR::Collective::DestroyCollectiveGroup(groupName); + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = 4, + .groupName = 
groupName, + }; + YR::Collective::CreateCollectiveGroup(spec, instanceIDs, {0, 1, 2, 3}); + + std::vector input = {1, 2, 3, 4}; + std::vector> res; + for (int i = 0; i < 4; ++i) { + res.push_back(instances[i] + .Function(&CollectiveActor::Compute) + .Invoke(input, groupName, static_cast(YR::ReduceOp::SUM))); + } + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(*YR::Get(res[i]), 40); + } + + std::vector input2 = {1.1, 2.2, 3.3, 4.4}; + std::vector> res2; + for (int i = 0; i < 4; ++i) { + res2.push_back(instances[i] + .Function(&CollectiveActor::ComputeDouble) + .Invoke(input2, groupName, static_cast(YR::ReduceOp::SUM))); + } + + for (int i = 0; i < 4; ++i) { + EXPECT_DOUBLE_EQ(*YR::Get(res2[i]), 44.0); // summed doubles: compare within ULP tolerance instead of bitwise equality + } + + YR::Collective::DestroyCollectiveGroup(groupName); +} + +/** + * @title: Send Recv函数调用成功 + * @step: 1.创建CollectiveActor实例 + * @step: 2.调用CreateCollectiveGroup,在driver中创建对应group + * @step: 3.调用Send函数,传入{1, 2, 3, 4},另一个actor调用Recv + * + * @expect: 预期无异常抛出,返回值为各项之和,10 + */ +TEST_F(CollectiveTest, SendRecvTest) +{ + auto ins1 = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + auto ins2 = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + + std::string groupName = "test-group3"; + YR::Collective::DestroyCollectiveGroup(groupName); + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = 2, + .groupName = groupName, + }; + YR::Collective::CreateCollectiveGroup(spec, {ins1.GetInstanceId(), ins2.GetInstanceId()}, {0, 1}); + std::vector input = {1, 2, 3, 4}; + ins1.Function(&CollectiveActor::Send).Invoke(groupName, input, 1, 1234); + auto ret = ins2.Function(&CollectiveActor::Recv).Invoke(groupName, 0, 1234, 4); + EXPECT_EQ(*YR::Get(ret), 10); + + std::vector input2 = {2, 2, 3, 4, 5}; + ret = ins1.Function(&CollectiveActor::Recv).Invoke(groupName, 1, 1234, 5); + ins2.Function(&CollectiveActor::Send).Invoke(groupName, input2, 0, 1234); + EXPECT_EQ(*YR::Get(ret), 16); + + YR::Collective::DestroyCollectiveGroup(groupName); +} + +/** + * @title: AllGather函数调用成功 + * 
@step: 1.创建CollectiveActor实例 + * @step: 2.调用CreateCollectiveGroup,在driver中创建对应group + * @step: 3.调用CollectiveActor::AllGather函数,传入{1, 2, 3, 4} + * + * @expect: 预期无异常抛出,返回值为各个actor中各项之和40, 总数 16个数据 + */ +TEST_F(CollectiveTest, AllGatherTest) +{ + std::vector> instances; + std::vector instanceIDs; + for (int i = 0; i < 4; ++i) { + auto ins = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + instances.push_back(ins); + instanceIDs.push_back(ins.GetInstanceId()); + } + + std::string groupName = "test-group4"; + YR::Collective::DestroyCollectiveGroup(groupName); + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = 4, + .groupName = groupName, + }; + YR::Collective::CreateCollectiveGroup(spec, instanceIDs, {0, 1, 2, 3}); + + std::vector input = {1, 2, 3, 4}; + std::vector>> res; + for (int i = 0; i < 4; ++i) { + res.push_back(instances[i].Function(&CollectiveActor::AllGather).Invoke(groupName, input)); + } + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(YR::Get(res[i])->first, 40); + EXPECT_EQ(YR::Get(res[i])->second, 16); + } + + YR::Collective::DestroyCollectiveGroup(groupName); +} + +/** + * @title: Broadcast函数调用成功 + * @step: 1.创建CollectiveActor实例 + * @step: 2.调用CreateCollectiveGroup,在driver中创建对应group + * @step: 3.调用CollectiveActor::Broadcast,ins1 传入{1, 2, 3, 4, 5} + * + * @expect: 预期无异常抛出,返回值为各个actor中各项之和15 + */ +TEST_F(CollectiveTest, BroadcastTest) +{ + std::vector> instances; + std::vector instanceIDs; + for (int i = 0; i < 4; ++i) { + auto ins = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + instances.push_back(ins); + instanceIDs.push_back(ins.GetInstanceId()); + } + + std::string groupName = "test-group5"; + YR::Collective::DestroyCollectiveGroup(groupName); + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = 4, + .groupName = groupName, + .timeout = 10000, + }; + YR::Collective::CreateCollectiveGroup(spec, instanceIDs, {0, 1, 2, 3}); + + std::vector input = {1, 2, 3, 4, 5}; + std::vector> res; + for (int i = 0; i < 4; ++i) { + 
res.push_back(instances[i].Function(&CollectiveActor::Broadcast).Invoke(groupName, input, 0)); + } + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(*YR::Get(res[i]), 15); + } + + YR::Collective::DestroyCollectiveGroup(groupName); +} + +/** + * @title: Scatter函数调用成功 + * @step: 1.创建CollectiveActor实例 + * @step: 2.调用CreateCollectiveGroup,在driver中创建对应group + * @step: 3.调用CollectiveActor::Scatter,ins1 传入{{1}, {2}, {3}, {4}} + * @step: 4.调用CollectiveActor::Scatter,ins1 传入{{1, 1}, {2, 2}, {3, 3}, {4, 4}} + * + * @expect: 预期无异常抛出,返回值为各个actor中各自分到一项 + */ +TEST_F(CollectiveTest, ScatterTest) +{ + std::vector> instances; + std::vector instanceIDs; + for (int i = 0; i < 4; ++i) { + auto ins = YR::Instance(CollectiveActor::FactoryCreate).Invoke(); + instances.push_back(ins); + instanceIDs.push_back(ins.GetInstanceId()); + } + + std::string groupName = "test-group6"; + YR::Collective::DestroyCollectiveGroup(groupName); + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = 4, + .groupName = groupName, + .timeout = 1000, + }; + YR::Collective::CreateCollectiveGroup(spec, instanceIDs, {0, 1, 2, 3}); + + std::vector> input = {{1}, {2}, {3}, {4}}; + std::vector> res; + for (int i = 0; i < 4; ++i) { + res.push_back(instances[i].Function(&CollectiveActor::Scatter).Invoke(groupName, input, 0, 1)); + } + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(*YR::Get(res[i]), i + 1); + } + + input = {{1, 1}, {2, 2}, {3, 3}, {4, 4}}; + res = {}; + for (int i = 0; i < 4; ++i) { + res.push_back(instances[i].Function(&CollectiveActor::Scatter).Invoke(groupName, input, 0, 2)); + } + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(*YR::Get(res[i]), (i + 1) * 2); + } + + YR::Collective::DestroyCollectiveGroup(groupName); +} + +/** + * todo + * 6. IBVERBS + * 9. 重复创建报错? + * 10. driver也作为一个rank,如何获取自己的instanceID + * 11. 能否动态增删rank + * 12. 
ds中的值会删不掉,现在一定要显式调destroy才行 + */ diff --git a/test/st/cpp/src/user_common_func.cpp b/test/st/cpp/src/user_common_func.cpp index 209526a5ef8f8a5a339cc892dd7217c522338ac4..5a67d2192614bfdf5d2156f4bdb78d08e393a4fc 100644 --- a/test/st/cpp/src/user_common_func.cpp +++ b/test/st/cpp/src/user_common_func.cpp @@ -630,3 +630,127 @@ int CallCluster(int x) } YR_INVOKE(CallLocal, CallCluster) + +int CollectiveActor::InitCollectiveGroup(std::string &groupName, int rank, int worldSize) +{ + YR::Collective::CollectiveGroupSpec spec{ + .worldSize = worldSize, + .groupName = groupName, + }; + YR::Collective::InitCollectiveGroup(spec, rank); + YR::Collective::Barrier(groupName); + return 0; +} + +int CollectiveActor::Barrier(std::string &groupName) +{ + YR::Collective::Barrier(groupName); + return 0; +} + +int CollectiveActor::Compute(std::vector in, std::string &groupName, uint8_t op) +{ + std::vector<int> output(in.size()); // owning receive buffer: freed on every exit path, including when AllReduce throws + YR::Collective::AllReduce(in.data(), output.data(), in.size(), YR::DataType::INT, YR::ReduceOp(op), groupName); + YR::Collective::Barrier(groupName); + + int result = 0; + for (size_t i = 0; i < output.size(); ++i) { + result += output[i]; + } + return result; +} + +double CollectiveActor::ComputeDouble(std::vector in, std::string &groupName, uint8_t op) +{ + std::vector<double> output(in.size()); // owning receive buffer: freed on every exit path, including when AllReduce throws + YR::Collective::AllReduce(in.data(), output.data(), in.size(), YR::DataType::DOUBLE, YR::ReduceOp(op), groupName); + YR::Collective::Barrier(groupName); + + double result = 0; + for (size_t i = 0; i < output.size(); ++i) { + result += output[i]; + } + return result; +} + +int CollectiveActor::Recv(std::string &groupName, int from, int tag, int count) +{ + if (count <= 0 || count > 1000) { // reject negative counts too: they are not a valid buffer size + return 1; + } + std::vector<int> output(count); + YR::Collective::Recv(output.data(), count, YR::DataType::INT, from, tag, groupName); + int result = 0; + for (int i = 0; i < count; ++i) { + result += output[i]; + } + return result; +} + +int CollectiveActor::Send(std::string &groupName, std::vector in, int dest, 
int tag) +{ + YR::Collective::Send(in.data(), in.size(), YR::DataType::INT, dest, tag, groupName); + return 0; +} + +std::pair CollectiveActor::AllGather(std::string &groupName, std::vector in) +{ + int inputSize = in.size(); + // Gather input size from all ranks first + std::vector recvData(inputSize * YR::Collective::GetWorldSize(groupName)); + YR::Collective::AllGather(in.data(), recvData.data(), inputSize, YR::DataType::INT, groupName); + YR::Collective::Barrier(groupName); + + int sum = 0; + int size = 0; + for (int v : recvData) { + sum += v; + size++; + } + return {sum, size}; +} + +int CollectiveActor::Broadcast(std::string &groupName, std::vector in, int srcRank) +{ + std::vector out(in.size(), 0); + YR::Collective::Broadcast(in.data(), out.data(), in.size(), YR::DataType::INT, srcRank, groupName); + YR::Collective::Barrier(groupName); + + int sum = 0; + for (auto v : out) { + sum += v; + } + return sum; +} + +int CollectiveActor::Scatter(std::string &groupName, std::vector> in, int srcRank, int count) +{ + std::vector out(count, 0); + + std::vector input; + if (YR::Collective::GetRank(groupName) == srcRank) { + for (const auto &item : in) { + input.push_back(const_cast(static_cast(item.data()))); + } + } + YR::Collective::Scatter(input, out.data(), count, YR::DataType::INT, srcRank, groupName); + YR::Collective::Barrier(groupName); + + int sum = 0; + for (auto v : out) { + sum += v; + } + return sum; +} + +int CollectiveActor::DestroyCollectiveGroup(std::string &groupName) +{ + YR::Collective::DestroyCollectiveGroup(groupName); + return 0; +} + +YR_INVOKE(CollectiveActor::FactoryCreate, &CollectiveActor::InitCollectiveGroup, &CollectiveActor::Barrier, + &CollectiveActor::Compute, &CollectiveActor::ComputeDouble, &CollectiveActor::AllGather, + &CollectiveActor::Broadcast, &CollectiveActor::Scatter, &CollectiveActor::DestroyCollectiveGroup, + &CollectiveActor::Recv, &CollectiveActor::Send); \ No newline at end of file diff --git 
a/test/st/cpp/src/user_common_func.h b/test/st/cpp/src/user_common_func.h
index 1cb04d18dea41487d901b5e067111d6224e1744e..6ec3eff317dcec94fd2e1e7c66c0e96ae93af016 100644
--- a/test/st/cpp/src/user_common_func.h
+++ b/test/st/cpp/src/user_common_func.h
@@ -313,4 +313,54 @@ int FunctionNotRegistered();
 int FunctionRegistered();
 
 int CallLocal(int x);
-int CallCluster(int x);
\ No newline at end of file
+int CallCluster(int x);
+
+// Test actor whose member functions wrap the YR::Collective calls (defined in user_common_func.cpp).
+class CollectiveActor {
+public:
+    // Registered as actor state through YR_STATE below; initialized so it never holds an indeterminate value.
+    int count = 0;
+
+    CollectiveActor() = default;
+
+    ~CollectiveActor() = default;
+
+    // Factory registered through YR_INVOKE; returns a heap-allocated actor.
+    // NOTE(review): ownership of the returned pointer presumably passes to the YR runtime - confirm.
+    static CollectiveActor *FactoryCreate()
+    {
+        return new CollectiveActor();
+    }
+
+    // Initializes collective group `groupName` with this actor as `rank` of `worldSize`; returns 0.
+    int InitCollectiveGroup(std::string &groupName, int rank, int worldSize);
+
+    // Calls Barrier on `groupName`; returns 0.
+    int Barrier(std::string &groupName);
+
+    // AllReduce over `in` with reduce operation `op`; returns the sum of the reduced elements.
+    int Compute(std::vector<int> in, std::string &groupName, uint8_t op);
+
+    // Double-precision variant of Compute.
+    double ComputeDouble(std::vector<double> in, std::string &groupName, uint8_t op);
+
+    // Receives `count` ints from rank `from`; returns their sum, or 1 when `count` is rejected.
+    int Recv(std::string &groupName, int from, int tag, int count);
+
+    // Sends `in` to rank `dest` with tag `tag`; returns 0.
+    int Send(std::string &groupName, std::vector<int> in, int dest, int tag);
+
+    // AllGather of `in`; returns {sum of the gathered elements, number of gathered elements}.
+    std::pair<int, int> AllGather(std::string &groupName, std::vector<int> in);
+
+    // Broadcast from `srcRank`; returns the sum of the received elements.
+    int Broadcast(std::string &groupName, std::vector<int> in, int srcRank);
+
+    // Scatter from `srcRank`; returns the sum of the `count` ints received locally.
+    int Scatter(std::string &groupName, std::vector<std::vector<int>> in, int srcRank, int count);
+
+    // Destroys collective group `groupName`; returns 0.
+    int DestroyCollectiveGroup(std::string &groupName);
+
+    YR_STATE(count);
+};
diff --git a/tools/openSource.txt b/tools/openSource.txt
index 41c17b82d565d187eef4309c3eb1008d237aef67..8ad4c8b22792231591561032a6ced94edc1d60ae 100644
--- a/tools/openSource.txt
+++ b/tools/openSource.txt
@@ -17,4 +17,5 @@ libboundscheck,v1.1.16,runtime;litebus;functionsystem;metrics,https://gitee.com/
 curl,8.8.0,litebus,https://gitee.com/mirrors/curl/repository/archive/curl-8_8_0.zip,73c70c94f487c5ae26f9f27094249e40bb1667ae6c0406a75c3b11f86f0c1128,test
gtest_1_12_1,release-1.12.1,litebus;functionsystem;metrics;logs,https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.zip,4f1037f17462dbcc4d715f2fc8212c03431ad06e7a4b73bb49b3f534a13b21f1,test gtest_1_10_0,release-1.10.0,functionsystem,https://gitee.com/mirrors/googletest/repository/archive/release-1.10.0.zip,db55081e6b5a5f820d052f8d37e0d29835dd5a2bd5b8d4d50590b7834f9abfae,test +gloo,main,runtime,https://github.com/pytorch/gloo/archive/refs/heads/main.zip,ddfb9627304025f294c78fd75ae84d7b9a662213d1b1c955afe6f91bfeb49254, ,,,,,, \ No newline at end of file