diff --git a/WORKSPACE b/WORKSPACE index 72dcfb865565d66fb11a8a2999c0ec693d50a4b1..862b00b360bb851b1c0ab073ac02f387f4a1796f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -176,6 +176,12 @@ http_archive( build_file = "@//bazel:jacoco.bzl" ) +new_local_repository( + name = "gloo", + path = "../thirdparty/gloo/install", + build_file = "@//bazel:gloo.bzl", +) + maybe( new_local_repository, name = "securec", diff --git a/api/cpp/BUILD.bazel b/api/cpp/BUILD.bazel index 8d9ea9f85665f93f405dc66acab46cd747cc9ba0..68ac502aa0dd2b159614afcf78e80fcedc0a8438 100644 --- a/api/cpp/BUILD.bazel +++ b/api/cpp/BUILD.bazel @@ -23,12 +23,15 @@ cc_library( "src/utils/*.cpp", "src/executor/*.h", "src/executor/*.cpp", + "src/collective/*.h", + "src/collective/*.cpp", ]) + ["src/utils/version.h"], hdrs = glob([ "include/yr/*.h", "include/yr/api/*.h", "include/yr/parallel/*.h", "include/yr/parallel/detail/*.h", + "include/yr/collective/*.h", ]), copts = COPTS, linkopts = ["-lstdc++fs"], @@ -43,6 +46,7 @@ cc_library( "@securec", "@nlohmann_json", "@yaml-cpp", + "@gloo//:gloo", ], alwayslink = True, ) @@ -64,6 +68,7 @@ cc_library( hdrs = glob([ "include/yr/*.h", "include/yr/api/*.h", + "include/yr/collective/*.h", "include/yr/parallel/*.h", "include/yr/parallel/detail/*.h", ]), diff --git a/api/cpp/include/yr/collective/collective.h b/api/cpp/include/yr/collective/collective.h new file mode 100644 index 0000000000000000000000000000000000000000..58da7128977df967d92952e3e26283be9f03adee --- /dev/null +++ b/api/cpp/include/yr/collective/collective.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "reduce_op.h" + +namespace YR::collective { +enum Backend : uint8_t { + GLOO = 0, + + INVALID, +}; + +class CollectiveGroup { +public: + CollectiveGroup(std::string groupName, int worldSize, int rank, Backend backend) + : groupName_(std::move(groupName)), rank_(rank), backend_(backend), worldSize_(worldSize) + { + } + + virtual void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) = 0; + + virtual void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) = 0; + + virtual void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) = 0; + + virtual void Barrier() = 0; + + virtual void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) = 0; + + virtual void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) = 0; + + virtual void Recv(void *recvbuf, int count, int srcRank, int tag) = 0; + + virtual void Send(const void *sendbuf, int count, int dstRank, int tag) = 0; + + int GetRank() const; + std::string GetGroupName(); + Backend GetBackend(); + int GetWorldSize() const; + +protected: + std::string groupName_; + + int rank_; + + Backend backend_; + + int worldSize_; +}; + +/** + * init collective group in actor + * + * @param worldSize + * @param rank + * @param groupName + */ +void InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend); + +/** + * create collective group with actor ids in driver + * + * @param instanceIDs + * @param worldSize + * @param ranks + * @param groupName + */ +void CreateCollectiveGroup(const std::vector &instanceIDs, int worldSize, const std::vector &ranks, + const std::string &groupName, Backend backend); +/** + * + * + * @param groupName + */ +void DestroyCollectiveGroup(const std::string &groupName); + +void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + const std::string &groupName); + +void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, int dstRank, + const std::string &groupName); + +void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, const std::string &groupName); + +void Barrier(const std::string &groupName); + +void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName); + +void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, + const std::string &groupName); + +void Recv(void *recvbuf, int count, int srcRank, int tag, const std::string &groupName); + +void Send(const void *sendbuf, int count, int dstRank, int tag, const std::string &groupName); + +class CollectiveGroupMgr { +public: + static CollectiveGroupMgr &GetInstance() + { + static CollectiveGroupMgr instance; + return instance; + } + + std::shared_ptr CheckAndCreateGroup(const std::string &groupName); + + void InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend); + + void DestroyCollectiveGroup(const std::string &groupName); + +private: + CollectiveGroupMgr() = default; + + ~CollectiveGroupMgr() = default; + + std::recursive_mutex mtx_{}; + + std::unordered_map> groups_{}; +}; + +} // namespace YR::collective diff --git a/api/cpp/include/yr/collective/reduce_op.h b/api/cpp/include/yr/collective/reduce_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c1be9aebca2800d5fa28f5fa01c356af4dac82da --- /dev/null +++ b/api/cpp/include/yr/collective/reduce_op.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace YR { + +enum DataType : uint8_t { + INT = 0, + DOUBLE, + + INVALID +}; + +enum ReduceOp : uint8_t { + SUM = 0, + PRODUCT = 1, + MIN = 2, + MAX = 3, +}; +} diff --git a/api/cpp/src/collective/collective.cpp b/api/cpp/src/collective/collective.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24a553fcd5d30755941787025f134a043065fe2f --- /dev/null +++ b/api/cpp/src/collective/collective.cpp @@ -0,0 +1,232 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "yr/collective/collective.h" + +#include +#include +#include + +#include "api/cpp/src/utils/utils.h" +#include "gloo_collective_group.h" +#include "src/dto/config.h" +#include "src/libruntime/err_type.h" +#include "src/utility/logger/logger.h" +#include "yr/api/exception.h" +#include "yr/yr.h" + + +namespace YR::collective { + +const std::string COLLECTIVE_GROUP_INFO_PREFIX = "/collective_group/"; + +struct CollectiveGroupInfo { + std::vector instances; + int worldSize; + std::vector ranks; + const std::string groupName; + Backend backend; + + std::string ToString() + { + nlohmann::json j = + nlohmann::json{{"instances", instances}, {"worldSize", worldSize}, {"ranks", ranks}, {"backend", backend}}; + return j.dump(); + } + + void FromJson(const std::string &str) + { + nlohmann::json j; + try { + j = nlohmann::json::parse(str); + } catch (const std::exception &e) { + YRLOG_WARN("json parse error: {}", e.what()); + return; + } + + j.at("instances").get_to(instances); + j.at("worldSize").get_to(worldSize); + j.at("ranks").get_to(ranks); + j.at("backend").get_to(backend); + } +}; + +int CollectiveGroup::GetRank() const +{ + return rank_; +} + +std::string CollectiveGroup::GetGroupName() +{ + return groupName_; +} + +Backend CollectiveGroup::GetBackend() +{ + return backend_; +} + +int CollectiveGroup::GetWorldSize() const +{ + return worldSize_; +} + +void InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend) +{ + CollectiveGroupMgr::GetInstance().InitCollectiveGroup(worldSize, rank, groupName, backend); +} + +void CreateCollectiveGroup(const std::vector &instanceIDs, int worldSize, const std::vector &ranks, + const std::string &groupName, Backend backend) +{ + THROW_IF_TRUE( + instanceIDs.size() != static_cast(worldSize) || static_cast(worldSize) != ranks.size(), + YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "failed to create collective group, unequal actor, rank or world size, please check"); + CollectiveGroupInfo info{ + .instances = instanceIDs, .worldSize = worldSize, .ranks = ranks, .groupName = groupName, .backend = backend}; + + YR::KVManager::Set(COLLECTIVE_GROUP_INFO_PREFIX + groupName, info.ToString()); +} + +void DestroyCollectiveGroup(const std::string &groupName) +{ + CollectiveGroupMgr::GetInstance().DestroyCollectiveGroup(groupName); +} + +void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->AllReduce(sendbuf, recvbuf, count, dtype, op); +} + +void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, int dstRank, + const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->Reduce(sendbuf, recvbuf, count, dtype, op, dstRank); +} + +void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->AllGather(sendbuf, recvbuf, count, dtype); +} + +void Barrier(const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->Barrier(); +} + +void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->Scatter(sendbuf, recvbuf, count, dtype, srcRank); +} + +void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->Broadcast(sendbuf, recvbuf, count, dtype, srcRank); +} + +void Recv(void *recvbuf, int count, int srcRank, int tag, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->Recv(recvbuf, count, srcRank, tag); +} + +void Send(const void *sendbuf, int count, int dstRank, int tag, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + group->Send(sendbuf, count, dstRank, tag); +} + +std::shared_ptr CollectiveGroupMgr::CheckAndCreateGroup(const std::string &groupName) +{ + std::lock_guard lock(mtx_); + if (groups_.find(groupName) != groups_.end()) { + return groups_[groupName]; + } + auto str = YR::KVManager::Get(COLLECTIVE_GROUP_INFO_PREFIX + groupName); + THROW_IF_TRUE(str.empty(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); + + CollectiveGroupInfo info; + info.FromJson(str); + + size_t index = 0; + for (; index < info.instances.size(); ++index) { + if (info.instances.at(index) == YR::Libruntime::Config::Instance().INSTANCE_ID()) { + break; + } + } + + THROW_IF_TRUE(index == info.instances.size(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid instance id"); + InitCollectiveGroup(info.worldSize, info.ranks[index], groupName, info.backend); + return groups_[groupName]; +} + +void CollectiveGroupMgr::InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend) +{ + std::lock_guard lock(mtx_); + if (groups_.find(groupName) != groups_.end()) { + YRLOG_DEBUG("collective group({}) already existed", groupName); + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "collective group already existed"); + } + + std::shared_ptr group; + switch (backend) { + case Backend::GLOO: + group = std::make_shared(groupName, worldSize, rank); + break; + + case Backend::INVALID: + // fall-through + default: + YRLOG_ERROR("failed to init collective group({}), invalid backend: {}", groupName, backend); + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid collective group backend"); + } + + groups_[groupName] = group; +} + +void CollectiveGroupMgr::DestroyCollectiveGroup(const std::string &groupName) +{ + std::lock_guard lock(mtx_); + YR::KVManager::Del(COLLECTIVE_GROUP_INFO_PREFIX + groupName); + groups_.erase(groupName); +} + +} // namespace YR::collective \ No newline at end of file diff --git a/api/cpp/src/collective/gloo_collective_group.cpp b/api/cpp/src/collective/gloo_collective_group.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf67f09cf360e84d2b088dffd904f2d87dc95f21 --- /dev/null +++ b/api/cpp/src/collective/gloo_collective_group.cpp @@ -0,0 +1,219 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gloo_collective_group.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "api/cpp/src/utils/utils.h" +#include "src/dto/config.h" +#include "yr/yr.h" + +namespace YR::collective { + +#define EXECUTE_BY_TYPE(dType, OPERATION, ...) \ + switch (dType) { \ + case DataType::INT: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::DOUBLE: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::INVALID: \ + default: \ + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, \ + "invalid dType: " + std::to_string(dType)); \ + } + +struct DsStore : public gloo::rendezvous::Store { +public: + DsStore() = default; + ~DsStore() override = default; + + void set(const std::string &key, const std::vector &data) override + { + YR::KVManager::Set(key, std::string(data.begin(), data.end())); + } + + std::vector get(const std::string &key) override + { + std::vector out; + std::string result = YR::KVManager::Get(key); + out.assign(result.begin(), result.end()); + return out; + } + + void wait(const std::vector &keys) override + { + YR::KVManager::Get(keys); + } + + void wait(const std::vector &keys, const std::chrono::milliseconds &timeout) override + { + YR::KVManager::Get(keys, static_cast(timeout.count() / 1000)); + } +}; + +GlooCollectiveGroup::GlooCollectiveGroup(std::string groupName, int worldSize, int rank) + : CollectiveGroup(std::move(groupName), worldSize, rank, Backend::GLOO) +{ + auto dsStore = std::make_shared(); + auto prefixStore = std::make_shared(groupName_, dsStore); + + std::shared_ptr dev; + auto backend = GetEnv("GLOO_BACKEND_TYPE"); + if (backend.empty() || backend == "TCP") { + gloo::transport::tcp::attr attr; + attr.hostname = YR::Libruntime::Config::Instance().HOST_IP(); + dev = gloo::transport::tcp::CreateDevice(attr); + } else if (backend == "IBVERBS") { + gloo::transport::ibverbs::attr attr; + attr.name = GetEnv("GLOO_IBVERBS_NAME"); + dev = gloo::transport::ibverbs::CreateDevice(attr); + } else { + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "invalid backend type: " + backend); + } + + context_ = std::make_shared(rank_, worldSize_); + context_->connectFullMesh(prefixStore, dev); +} + +void GlooCollectiveGroup::AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) +{ + EXECUTE_BY_TYPE(dtype, DoAllReduce, sendbuf, recvbuf, count, op); +} + +void GlooCollectiveGroup::Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) +{ + EXECUTE_BY_TYPE(dtype, DoReduce, sendbuf, recvbuf, count, op, dstRank); +} + +void GlooCollectiveGroup::AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) +{ + EXECUTE_BY_TYPE(dtype, DoAllGather, sendbuf, recvbuf, count); +} + +void GlooCollectiveGroup::Barrier() +{ + gloo::BarrierOptions opts(context_); + gloo::barrier(opts); +} + +void GlooCollectiveGroup::Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) +{ + EXECUTE_BY_TYPE(dtype, DoScatter, sendbuf, recvbuf, count, srcRank); +} + +void GlooCollectiveGroup::Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) +{ + EXECUTE_BY_TYPE(dtype, DoBroadcast, sendbuf, recvbuf, count, srcRank); +} + +void GlooCollectiveGroup::Recv(void *recvbuf, int count, int srcRank, int tag) +{ + auto ubuf = context_->createUnboundBuffer(recvbuf, count); + ubuf->recv(srcRank, tag); + ubuf->waitRecv(); +} + +void GlooCollectiveGroup::Send(const void *sendbuf, int count, int dstRank, int tag) +{ + auto ubuf = context_->createUnboundBuffer(const_cast(sendbuf), count); + ubuf->send(dstRank, tag); + ubuf->waitSend(); +} + +template +gloo::AllreduceOptions::Func GlooCollectiveGroup::GetReduceOp(const ReduceOp &op) +{ + void (*fn)(void *, const void *, const void *, long unsigned int) = &gloo::sum; + switch (op) { + case ReduceOp::SUM: + fn = &gloo::sum; + break; + case ReduceOp::PRODUCT: + fn = &gloo::product; + break; + case ReduceOp::MIN: + fn = &gloo::min; + break; + case ReduceOp::MAX: + fn = &gloo::max; + break; + } + return fn; +} + +template +void GlooCollectiveGroup::DoAllReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op) +{ + gloo::AllreduceOptions opts(context_); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); + opts.setReduceFunction(GetReduceOp(op)); + gloo::allreduce(opts); +} + +template +void GlooCollectiveGroup::DoReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op, int dstRank) +{ + gloo::ReduceOptions opts(context_); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); + opts.setReduceFunction(GetReduceOp(op)); + gloo::reduce(opts); +} + +template +void GlooCollectiveGroup::DoBroadcast(const void *sendbuf, void *recvbuf, int count, int srcRank) +{ + gloo::BroadcastOptions opts(context_); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); + opts.setRoot(srcRank); + gloo::broadcast(opts); +} + +template +void GlooCollectiveGroup::DoScatter(const void *sendbuf, void *recvbuf, int count, int srcRank) +{ + gloo::ScatterOptions opts(context_); + std::vector inputs = {const_cast((T *)sendbuf)}; + opts.setInputs(inputs, count); + opts.setOutput(static_cast(recvbuf), count); + opts.setRoot(srcRank); + gloo::scatter(opts); +} + +template +void GlooCollectiveGroup::DoAllGather(const void *sendbuf, void *recvbuf, int count) +{ + gloo::AllgatherOptions opts(context_); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); + gloo::allgather(opts); +} +} // namespace YR::collective \ No newline at end of file diff --git a/api/cpp/src/collective/gloo_collective_group.h b/api/cpp/src/collective/gloo_collective_group.h new file mode 100644 index 0000000000000000000000000000000000000000..1682d352bd00ecc310bb7a1aad5abe643d408f51 --- /dev/null +++ b/api/cpp/src/collective/gloo_collective_group.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include "yr/collective/collective.h" + +namespace YR::collective { +class GlooCollectiveGroup : public CollectiveGroup { +public: + GlooCollectiveGroup(std::string groupName, int worldSize, int rank); + + void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) override; + + void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) override; + + void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) override; + + void Barrier() override; + + void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) override; + + void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) override; + + void Recv(void *recvbuf, int count, int srcRank, int tag) override; + + void Send(const void *sendbuf, int count, int dstRank, int tag) override; + +private: + template + void DoAllReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op); + + template + void DoReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op, int dstRank); + + template + void DoAllGather(const void *sendbuf, void *recvbuf, int count); + + template + void DoScatter(const void *sendbuf, void *recvbuf, int count, int srcRank); + + template + void DoBroadcast(const void *sendbuf, void *recvbuf, int count, int srcRank); + + template + static gloo::AllreduceOptions::Func GetReduceOp(const ReduceOp &op); + + std::unordered_map map; + + std::shared_ptr context_; +}; +} // namespace YR::collective \ No newline at end of file diff --git a/api/cpp/src/utils/utils.h b/api/cpp/src/utils/utils.h index 683260df43dae728eb3ec1aee86ada3eb4e026b9..afdb5dd17e5f7b5ff74b99f8cc9fa069029116de 100644 --- a/api/cpp/src/utils/utils.h +++ b/api/cpp/src/utils/utils.h @@ -22,6 +22,18 @@ #include "src/dto/data_object.h" namespace YR { + +#ifdef THROW_IF_TRUE +#undef THROW_IF_TRUE +#endif +#define THROW_IF_TRUE(x, c, m) \ + do { \ + if ((x)) { \ + YRLOG_ERROR(m); \ + throw YR::Exception(c, m); \ + } \ + } while (false) + std::string ConvertFunctionUrnToId(const std::string &functionUrn); YR::Libruntime::ErrorInfo WriteDataObject(const void *data, std::shared_ptr &dataObj, uint64_t size, const std::unordered_set &nestedIds); diff --git a/bazel/gloo.bzl b/bazel/gloo.bzl new file mode 100644 index 0000000000000000000000000000000000000000..a72f795862f51645c1643c5485b3c9aa9e9d4867 --- /dev/null +++ b/bazel/gloo.bzl @@ -0,0 +1,6 @@ +cc_import( + name = "gloo", + hdrs = glob(["include/gloo/**/*.h"]), + shared_library = "lib/libgloo.so", + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/install_tools.sh b/install_tools.sh index 699c1e98561a2e8d830b0f7a23837c64659fa33e..2e26c2e44113acf51c06af8ed62dd5f5b5c94638 100644 --- a/install_tools.sh +++ b/install_tools.sh @@ -35,7 +35,7 @@ sudo yum install -y \ sqlite-devel bison gawk texinfo glibc glibc-devel wget bzip2-devel sudo \ rsync nfs-utils xz libuuid unzip util-linux-devel cpio libcap-devel libatomic \ chrpath numactl-devel openeuler-lsb libasan dos2unix net-tools pigz cmake \ - protobuf protobuf-devel patchelf + protobuf protobuf-devel patchelf libibverbs-devel echo "✅ Detected architecture: $ARCH → using package suffix: $PKG_ARCH" diff --git a/tools/download_dependency.sh b/tools/download_dependency.sh index 44e72c80bf64592f1fa8ce3187043659f7732912..5b6d85e4d6e3b34e96b749b7dce92640fc6974f9 100644 --- a/tools/download_dependency.sh +++ b/tools/download_dependency.sh @@ -152,6 +152,16 @@ function compile_all(){ fi popd fi +if [ ! -d "${THIRD_PARTY_DIR}/gloo/build" ]; then + pushd "${THIRD_PARTY_DIR}/gloo/" + chmod -R 700 "${THIRD_PARTY_DIR}/gloo/" + mkdir build + cd build + cmake .. -DBUILD_SHARED_LIBS=ON -DUSE_IBVERBS=ON -DCMAKE_BUILD_TYPE=Release -DGLOO_INSTALL=ON -DCMAKE_INSTALL_PREFIX=${THIRD_PARTY_DIR}/gloo/install + make -j4 + make install + popd +fi } function download_third_party_cache() { diff --git a/tools/openSource.txt b/tools/openSource.txt index 41c17b82d565d187eef4309c3eb1008d237aef67..8ad4c8b22792231591561032a6ced94edc1d60ae 100644 --- a/tools/openSource.txt +++ b/tools/openSource.txt @@ -17,4 +17,5 @@ libboundscheck,v1.1.16,runtime;litebus;functionsystem;metrics,https://gitee.com/ curl,8.8.0,litebus,https://gitee.com/mirrors/curl/repository/archive/curl-8_8_0.zip,73c70c94f487c5ae26f9f27094249e40bb1667ae6c0406a75c3b11f86f0c1128,test gtest_1_12_1,release-1.12.1,litebus;functionsystem;metrics;logs,https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.zip,4f1037f17462dbcc4d715f2fc8212c03431ad06e7a4b73bb49b3f534a13b21f1,test gtest_1_10_0,release-1.10.0,functionsystem,https://gitee.com/mirrors/googletest/repository/archive/release-1.10.0.zip,db55081e6b5a5f820d052f8d37e0d29835dd5a2bd5b8d4d50590b7834f9abfae,test +gloo,main,runtime,https://github.com/pytorch/gloo/archive/refs/heads/main.zip,ddfb9627304025f294c78fd75ae84d7b9a662213d1b1c955afe6f91bfeb49254, ,,,,,, \ No newline at end of file