From ee6ce79a5b443de67837944a8f762827563edb2e Mon Sep 17 00:00:00 2001 From: mhsong Date: Mon, 17 Nov 2025 10:00:00 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=9F=BA=E4=BA=8Egloo?= =?UTF-8?q?=E9=9B=86=E5=90=88=E9=80=9A=E4=BF=A1=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- WORKSPACE | 6 + api/cpp/BUILD.bazel | 1 + api/cpp/include/yr/collective/collective.h | 139 +++++++++++ .../include/yr/collective/data_descriptor.h | 35 +++ api/cpp/include/yr/collective/reduce_op.h | 28 +++ api/cpp/src/collective/collective.cpp | 223 ++++++++++++++++++ .../src/collective/gloo_collective_group.cpp | 205 ++++++++++++++++ .../src/collective/gloo_collective_group.h | 67 ++++++ api/cpp/src/utils/utils.h | 12 + bazel/gloo.bzl | 7 + tools/openSource.txt | 1 + 11 files changed, 724 insertions(+) create mode 100644 api/cpp/include/yr/collective/collective.h create mode 100644 api/cpp/include/yr/collective/data_descriptor.h create mode 100644 api/cpp/include/yr/collective/reduce_op.h create mode 100644 api/cpp/src/collective/collective.cpp create mode 100644 api/cpp/src/collective/gloo_collective_group.cpp create mode 100644 api/cpp/src/collective/gloo_collective_group.h create mode 100644 bazel/gloo.bzl diff --git a/WORKSPACE b/WORKSPACE index e5074a1..d79ba53 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -149,6 +149,12 @@ http_archive( build_file = "@//bazel:jacoco.bzl" ) +new_local_repository( + name = "gloo", + build_file = "//third_party:gloo.bzl", + path = "../third_party/gloo", +) + maybe( new_local_repository, name = "securec", diff --git a/api/cpp/BUILD.bazel b/api/cpp/BUILD.bazel index 6f14e44..e65fd40 100644 --- a/api/cpp/BUILD.bazel +++ b/api/cpp/BUILD.bazel @@ -41,6 +41,7 @@ cc_library( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@securec", + "@gloo//:gloo", ], alwayslink = True, ) diff --git a/api/cpp/include/yr/collective/collective.h b/api/cpp/include/yr/collective/collective.h new file mode 100644 index 0000000..5ad07c2 --- /dev/null +++ b/api/cpp/include/yr/collective/collective.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "data_descriptor.h" +#include "reduce_op.h" + +namespace YR::collective { +enum Backend : uint8_t { + GLOO = 0, + + INVALID, +}; + +class CollectiveGroup { +public: + CollectiveGroup(std::string groupName, int worldSize, int rank, Backend backend) + : rank_(rank), groupName_(std::move(groupName)), backend_(backend), worldSize_(worldSize) + { + } + + virtual void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) = 0; + + virtual void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) = 0; + + virtual void AllGather(const DataDescriptor &input, DataDescriptor &output) = 0; + + virtual void Barrier() = 0; + + virtual void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) = 0; + + virtual void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) = 0; + + virtual void Recv(DataDescriptor &output, int srcRank, int tag) = 0; + + virtual void Send(const DataDescriptor &input, int dstRank, int tag) = 0; + + int GetRank() const; + std::string GetGroupName(); + Backend GetBackend(); + int GetWorldSize() const; + +protected: + std::string groupName_; + + int rank_; + + Backend backend_; + + int worldSize_; +}; + +/** + * init collective group in actor + * + * @param worldSize + * @param rank + * @param groupName + */ +void InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend); + +/** + * create collective group with actor ids in driver + * + * @param instanceIDs + * @param worldSize + * @param ranks + * @param groupName + */ +void CreateCollectiveGroup(const std::vector &instanceIDs, int worldSize, const std::vector &ranks, + const std::string &groupName, Backend backend); +/** + * + * + * @param groupName + */ +void DestroyCollectiveGroup(const std::string &groupName); + +void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, const std::string &groupName); + +void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank, + const std::string &groupName); + +void AllGather(const DataDescriptor &input, DataDescriptor &output, const std::string &groupName); + +void Barrier(const std::string &groupName); + +void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName); + +void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName); + +void Recv(DataDescriptor &output, int srcRank, int tag, const std::string &groupName); + +void Send(const DataDescriptor &input, int dstRank, int tag, const std::string &groupName); + +class CollectiveGroupMgr { +public: + static CollectiveGroupMgr &GetInstance() + { + static CollectiveGroupMgr instance; + return instance; + } + + std::shared_ptr CheckAndCreateGroup(const std::string &groupName); + + void InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend); + + void DestroyCollectiveGroup(const std::string &groupName); + +private: + CollectiveGroupMgr() = default; + + ~CollectiveGroupMgr() = default; + + std::recursive_mutex mtx_{}; + + std::unordered_map> groups_{}; +}; + +} // namespace YR::collective diff --git a/api/cpp/include/yr/collective/data_descriptor.h b/api/cpp/include/yr/collective/data_descriptor.h new file mode 100644 index 0000000..f22f849 --- /dev/null +++ b/api/cpp/include/yr/collective/data_descriptor.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace YR { +enum DataType : uint8_t { + INT = 0, + DOUBLE, + + INVALID +}; + +struct DataDescriptor { + void *buf = nullptr; + uint64_t count = 0; + DataType dType = DataType::INVALID; +}; +} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/collective/reduce_op.h b/api/cpp/include/yr/collective/reduce_op.h new file mode 100644 index 0000000..25f1b62 --- /dev/null +++ b/api/cpp/include/yr/collective/reduce_op.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace YR { +enum ReduceOp : uint8_t { + SUM = 0, + PRODUCT = 1, + MIN = 2, + MAX = 3, +}; +} diff --git a/api/cpp/src/collective/collective.cpp b/api/cpp/src/collective/collective.cpp new file mode 100644 index 0000000..f2507b7 --- /dev/null +++ b/api/cpp/src/collective/collective.cpp @@ -0,0 +1,223 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "yr/collective/collective.h" + +#include +#include + +#include "api/cpp/src/utils/utils.h" +#include "gloo_collective_group.h" +#include "src/libruntime/err_type.h" +#include "src/utility/logger/logger.h" +#include "yr/api/exception.h" +#include "yr/api/kv_manager.h" + +namespace YR::collective { + +const std::string COLLECTIVE_GROUP_INFO_PREFIX = "/collective_group/"; + +struct CollectiveGroupInfo { + std::vector instances; + int worldSize; + std::vector ranks; + const std::string groupName; + Backend backend; + + std::string ToString() + { + nlohmann::json j = + nlohmann::json{{"instances", instances}, {"worldSize", worldSize}, {"ranks", ranks}, {"backend", backend}}; + return j.dump(); + } + + void FromJson(const std::string &str) + { + nlohmann::json j; + try { + j = nlohmann::json::parse(str); + } catch (const std::exception &e) { + YRLOG_WARN("json parse error: {}", e.what()); + return; + } + + j.at("instances").get_to(instances); + j.at("worldSize").get_to(worldSize); + j.at("ranks").get_to(ranks); + j.at("backend").get_to(backend); + } +}; + +int CollectiveGroup::GetRank() const +{ + return rank_; +} + +std::string CollectiveGroup::GetGroupName() +{ + return groupName_; +} + +Backend CollectiveGroup::GetBackend() +{ + return backend_; +} + +int CollectiveGroup::GetWorldSize() const +{ + return worldSize_; +} + +void InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend) +{ + CollectiveGroupMgr::GetInstance().InitCollectiveGroup(worldSize, rank, groupName, backend); +} + +void CreateCollectiveGroup(const std::vector &instanceIDs, int worldSize, const std::vector &ranks, + const std::string &groupName, Backend backend) +{ + // todo 前置参数校验 + CollectiveGroupInfo info{ + .instances = instanceIDs, .worldSize = worldSize, .ranks = ranks, .groupName = groupName, .backend = backend}; + + YR::KVManager::Set(COLLECTIVE_GROUP_INFO_PREFIX + groupName, info.ToString()); +} + +void DestroyCollectiveGroup(const std::string &groupName) +{ + CollectiveGroupMgr::GetInstance().DestroyCollectiveGroup(groupName); +} + +void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->AllReduce(input, output, op); +} + +void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank, + const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->Reduce(input, output, op, dstRank); +} + +void AllGather(const DataDescriptor &input, DataDescriptor &output, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->AllGather(input, output); +} + +void Barrier(const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->Barrier(); +} + +void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->Scatter(input, output, srcRank); +} + +void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->Barrier(); +} + +void Recv(DataDescriptor &output, int srcRank, int tag, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->Recv(output, srcRank, tag); +} + +void Send(const DataDescriptor &input, int dstRank, int tag, const std::string &groupName) +{ + auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); + THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "group: " + groupName + " is not init"); + group->Send(input, dstRank, tag); +} + +std::shared_ptr CollectiveGroupMgr::CheckAndCreateGroup(const std::string &groupName) +{ + std::lock_guard lock(mtx_); // todo 看下锁的范围 + if (groups_.find(groupName) != groups_.end()) { + return groups_[groupName]; + } + auto str = YR::KVManager::Get(COLLECTIVE_GROUP_INFO_PREFIX + groupName); + THROW_IF_TRUE(str.empty(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); + + CollectiveGroupInfo info; + info.FromJson(str); + + int index = 0; + for (; index < info.instances.size(); ++index) { + if (info.instances.at(index) == "GetInstanceID") { // todo 获取instanceID + break; + } + } + + THROW_IF_TRUE(index == info.instances.size(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid instance id"); + InitCollectiveGroup(info.worldSize, info.ranks[index], groupName, info.backend); + return groups_[groupName]; +} + +void CollectiveGroupMgr::InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, Backend backend) +{ + std::lock_guard lock(mtx_); + if (groups_.find(groupName) != groups_.end()) { + YRLOG_DEBUG("collective group({}) already existed", groupName); + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "collective group already existed"); + } + + std::shared_ptr group; + switch (backend) { + case Backend::GLOO: + group = std::make_shared(groupName, worldSize, rank); + break; + + case Backend::INVALID: + // fall-through + default: + YRLOG_ERROR("failed to init collective group({}), invalid backend: {}", groupName, backend); + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "invalid collective group backend"); + } + + groups_[groupName] = group; +} + +void CollectiveGroupMgr::DestroyCollectiveGroup(const std::string &groupName) +{ + std::lock_guard lock(mtx_); + groups_.erase(groupName); +} + +} // namespace YR::collective \ No newline at end of file diff --git a/api/cpp/src/collective/gloo_collective_group.cpp b/api/cpp/src/collective/gloo_collective_group.cpp new file mode 100644 index 0000000..1bc096f --- /dev/null +++ b/api/cpp/src/collective/gloo_collective_group.cpp @@ -0,0 +1,205 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gloo_collective_group.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "api/cpp/src/utils/utils.h" +#include "yr/api/kv_manager.h" + +namespace YR::collective { + +#define EXECUTE_BY_TYPE(dType, OPERATION, ...) \ + switch (dType) { \ + case DataType::INT: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::DOUBLE: \ + OPERATION(__VA_ARGS__); \ + break; \ + case DataType::INVALID: \ + default: \ + throw YR::Exception(YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, \ + "invalid dType: " + std::to_string(dType)); \ + } + +struct DsStore : public gloo::rendezvous::Store { +public: + DsStore() = default; + ~DsStore() override = default; + + void set(const std::string &key, const std::vector &data) override + { + YR::KVManager::Set(key, std::string(data.begin(), data.end())); + } + + std::vector get(const std::string &key) override + { + std::vector out; + std::string result = YR::KVManager::Get(key); + out.assign(result.begin(), result.end()); + return out; + } + + void wait(const std::vector &keys) override + { + YR::KVManager::Get(keys); + } + + void wait(const std::vector &keys, const std::chrono::milliseconds &timeout) override + { + YR::KVManager::Get(keys, static_cast(timeout.count() / 1000)); + } +}; + +GlooCollectiveGroup::GlooCollectiveGroup(std::string groupName, int worldSize, int rank) + : CollectiveGroup(std::move(groupName), worldSize, rank, Backend::GLOO) +{ + auto dsStore = std::make_shared(); + auto prefixStore = std::make_shared(groupName_, dsStore); + + gloo::transport::tcp::attr attr; + attr.hostname = ""; // todo 从环境变量里取本地地址 也支持环境变量配置使用哪种后端 + auto dev = gloo::transport::tcp::CreateDevice(attr); + + context_ = std::make_shared(rank_, worldSize_); + context_->connectFullMesh(prefixStore, dev); +} + +void GlooCollectiveGroup::AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) +{ + EXECUTE_BY_TYPE(input.dType, DoAllReduce, input, output, op); +} + +void GlooCollectiveGroup::Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) +{ + EXECUTE_BY_TYPE(input.dType, DoReduce, input, output, op, dstRank); +} + +void GlooCollectiveGroup::AllGather(const DataDescriptor &input, DataDescriptor &output) +{ + EXECUTE_BY_TYPE(input.dType, DoAllGather, input, output); +} + +void GlooCollectiveGroup::Barrier() +{ + gloo::BarrierOptions opts(context_); + gloo::barrier(opts); +} + +void GlooCollectiveGroup::Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) +{ + EXECUTE_BY_TYPE(input.dType, DoScatter, input, output, srcRank); +} + +void GlooCollectiveGroup::Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) +{ + EXECUTE_BY_TYPE(input.dType, DoBroadcast, input, output, srcRank); +} + +void GlooCollectiveGroup::Recv(DataDescriptor &output, int srcRank, int tag) +{ + auto ubuf = context_->createUnboundBuffer(output.buf, output.count); + ubuf->recv(srcRank, tag); + ubuf->waitRecv(); +} + +void GlooCollectiveGroup::Send(const DataDescriptor &input, int dstRank, int tag) +{ + auto ubuf = context_->createUnboundBuffer(input.buf, input.count); + ubuf->send(dstRank, tag); + ubuf->waitSend(); +} + +template +gloo::AllreduceOptions::Func GlooCollectiveGroup::GetReduceOp(const ReduceOp &op) +{ + void (*fn)(void *, const void *, const void *, long unsigned int) = &gloo::sum; + switch (op) { + case ReduceOp::SUM: + fn = &gloo::sum; + break; + case ReduceOp::PRODUCT: + fn = &gloo::product; + break; + case ReduceOp::MIN: + fn = &gloo::min; + break; + case ReduceOp::MAX: + fn = &gloo::max; + break; + } + return fn; +} + +template +void GlooCollectiveGroup::DoAllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) +{ + gloo::AllreduceOptions opts(context_); + opts.setInput(static_cast(input.buf), input.count); + opts.setOutput(static_cast(output.buf), output.count); + opts.setReduceFunction(GetReduceOp(op)); + gloo::allreduce(opts); +} + +template +void GlooCollectiveGroup::DoReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) +{ + gloo::ReduceOptions opts(context_); + opts.setInput(static_cast(input.buf), input.count); + opts.setOutput(static_cast(output.buf), output.count); + opts.setReduceFunction(GetReduceOp(op)); + gloo::reduce(opts); +} + +template +void GlooCollectiveGroup::DoBroadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) +{ + gloo::BroadcastOptions opts(context_); + opts.setInput(static_cast(input.buf), input.count); + opts.setOutput(static_cast(output.buf), output.count); + opts.setRoot(srcRank); + gloo::broadcast(opts); +} + +template +void GlooCollectiveGroup::DoScatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) +{ + gloo::ScatterOptions opts(context_); + std::vector inputs = {static_cast(input.buf)}; + opts.setInputs(inputs, input.count); + opts.setOutput(static_cast(output.buf), output.count); + opts.setRoot(srcRank); + gloo::scatter(opts); +} + +template +void GlooCollectiveGroup::DoAllGather(const DataDescriptor &input, DataDescriptor &output) +{ + gloo::AllgatherOptions opts(context_); + opts.setInput(static_cast(input.buf), input.count); + opts.setOutput(static_cast(output.buf), output.count); + gloo::allgather(opts); +} +} // namespace YR::collective \ No newline at end of file diff --git a/api/cpp/src/collective/gloo_collective_group.h b/api/cpp/src/collective/gloo_collective_group.h new file mode 100644 index 0000000..9a0bf98 --- /dev/null +++ b/api/cpp/src/collective/gloo_collective_group.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include "yr/collective/collective.h" + +namespace YR::collective { +class GlooCollectiveGroup : public CollectiveGroup { +public: + GlooCollectiveGroup(std::string groupName, int worldSize, int rank); + + void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) override; + + void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) override; + + void AllGather(const DataDescriptor &input, DataDescriptor &output) override; + + void Barrier() override; + + void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) override; + + void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) override; + + void Recv(DataDescriptor &output, int srcRank, int tag) override; + + void Send(const DataDescriptor &input, int dstRank, int tag) override; + +private: + template + void DoAllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op); + + template + void DoReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank); + + template + void DoAllGather(const DataDescriptor &input, DataDescriptor &output); + + template + void DoScatter(const DataDescriptor &input, DataDescriptor &output, int srcRank); + + template + void DoBroadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank); + + template + static gloo::AllreduceOptions::Func GetReduceOp(const ReduceOp &op); + + std::unordered_map map; + + std::shared_ptr context_; +}; +} // namespace YR::collective \ No newline at end of file diff --git a/api/cpp/src/utils/utils.h b/api/cpp/src/utils/utils.h index 683260d..afdb5dd 100644 --- a/api/cpp/src/utils/utils.h +++ b/api/cpp/src/utils/utils.h @@ -22,6 +22,18 @@ #include "src/dto/data_object.h" namespace YR { + +#ifdef THROW_IF_TRUE +#undef THROW_IF_TRUE +#endif +#define THROW_IF_TRUE(x, c, m) \ + do { \ + if ((x)) { \ + YRLOG_ERROR(m); \ + throw YR::Exception(c, m); \ + } \ + } while (false) + std::string ConvertFunctionUrnToId(const std::string &functionUrn); YR::Libruntime::ErrorInfo WriteDataObject(const void *data, std::shared_ptr &dataObj, uint64_t size, const std::unordered_set &nestedIds); diff --git a/bazel/gloo.bzl b/bazel/gloo.bzl new file mode 100644 index 0000000..f5d6ff2 --- /dev/null +++ b/bazel/gloo.bzl @@ -0,0 +1,7 @@ +cc_library( + name = "gloo", + srcs = [], + hdrs = glob(["gloo/**/*.h"]), + includes = ["include"], # 指定头文件根目录 + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/tools/openSource.txt b/tools/openSource.txt index 79deccd..4cdb1d0 100644 --- a/tools/openSource.txt +++ b/tools/openSource.txt @@ -17,4 +17,5 @@ libboundscheck,v1.1.16,runtime;litebus;functionsystem;metrics,https://gitee.com/ curl,8.8.0,litebus,https://gitee.com/mirrors/curl/repository/archive/curl-8_8_0.zip,73c70c94f487c5ae26f9f27094249e40bb1667ae6c0406a75c3b11f86f0c1128,test gtest_1_12_1,release-1.12.1,litebus;functionsystem;metrics;logs,https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.zip,4f1037f17462dbcc4d715f2fc8212c03431ad06e7a4b73bb49b3f534a13b21f1,test gtest_1_10_0,release-1.10.0,functionsystem,https://gitee.com/mirrors/googletest/repository/archive/release-1.10.0.zip,db55081e6b5a5f820d052f8d37e0d29835dd5a2bd5b8d4d50590b7834f9abfae,test +gloo,main,runtime,https://github.com/pytorch/gloo/archive/refs/heads/main.zip,ddfb9627304025f294c78fd75ae84d7b9a662213d1b1c955afe6f91bfeb49254, ,,,,,, \ No newline at end of file -- Gitee From 2cfc674165396b97782a274d1fb3c2acac59312b Mon Sep 17 00:00:00 2001 From: mhsong Date: Mon, 17 Nov 2025 22:41:48 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A3=80=E8=A7=86?= =?UTF-8?q?=E6=84=8F=E8=A7=81=EF=BC=8C=E8=B0=83=E6=95=B4=E5=AF=B9=E5=A4=96?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/cpp/include/yr/collective/collective.h | 32 +++++----- .../include/yr/collective/data_descriptor.h | 35 ----------- api/cpp/include/yr/collective/reduce_op.h | 8 +++ api/cpp/src/collective/collective.cpp | 27 ++++---- .../src/collective/gloo_collective_group.cpp | 61 ++++++++++--------- .../src/collective/gloo_collective_group.h | 25 ++++---- 6 files changed, 83 insertions(+), 105 deletions(-) delete mode 100644 api/cpp/include/yr/collective/data_descriptor.h diff --git a/api/cpp/include/yr/collective/collective.h b/api/cpp/include/yr/collective/collective.h index 5ad07c2..3d50df6 100644 --- a/api/cpp/include/yr/collective/collective.h +++ b/api/cpp/include/yr/collective/collective.h @@ -21,7 +21,6 @@ #include #include -#include "data_descriptor.h" #include "reduce_op.h" namespace YR::collective { @@ -38,21 +37,22 @@ public: { } - virtual void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) = 0; + virtual void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) = 0; - virtual void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) = 0; + virtual void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) = 0; - virtual void AllGather(const DataDescriptor &input, DataDescriptor &output) = 0; + virtual void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) = 0; virtual void Barrier() = 0; - virtual void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) = 0; + virtual void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) = 0; - virtual void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) = 0; + virtual void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) = 0; - virtual void Recv(DataDescriptor &output, int srcRank, int tag) = 0; + virtual void Recv(void *recvbuf, int count, int srcRank, int tag) = 0; - virtual void Send(const DataDescriptor &input, int dstRank, int tag) = 0; + virtual void Send(const void *sendbuf, int count, int dstRank, int tag) = 0; int GetRank() const; std::string GetGroupName(); @@ -95,22 +95,24 @@ void CreateCollectiveGroup(const std::vector &instanceIDs, int worl */ void DestroyCollectiveGroup(const std::string &groupName); -void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, const std::string &groupName); +void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + const std::string &groupName); -void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank, +void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, int dstRank, const std::string &groupName); -void AllGather(const DataDescriptor &input, DataDescriptor &output, const std::string &groupName); +void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, const std::string &groupName); void Barrier(const std::string &groupName); -void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName); +void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName); -void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName); +void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, + const std::string &groupName); -void Recv(DataDescriptor &output, int srcRank, int tag, const std::string &groupName); +void Recv(void *recvbuf, int count, DataType dtype, int tag, const std::string &groupName); -void Send(const DataDescriptor &input, int dstRank, int tag, const std::string &groupName); +void Send(const void *sendbuf, int count, int dstRank, int tag, const std::string &groupName); class CollectiveGroupMgr { public: diff --git a/api/cpp/include/yr/collective/data_descriptor.h b/api/cpp/include/yr/collective/data_descriptor.h deleted file mode 100644 index f22f849..0000000 --- a/api/cpp/include/yr/collective/data_descriptor.h +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace YR { -enum DataType : uint8_t { - INT = 0, - DOUBLE, - - INVALID -}; - -struct DataDescriptor { - void *buf = nullptr; - uint64_t count = 0; - DataType dType = DataType::INVALID; -}; -} // namespace YR \ No newline at end of file diff --git a/api/cpp/include/yr/collective/reduce_op.h b/api/cpp/include/yr/collective/reduce_op.h index 25f1b62..c1be9ae 100644 --- a/api/cpp/include/yr/collective/reduce_op.h +++ b/api/cpp/include/yr/collective/reduce_op.h @@ -19,6 +19,14 @@ #include namespace YR { + +enum DataType : uint8_t { + INT = 0, + DOUBLE, + + INVALID +}; + enum ReduceOp : uint8_t { SUM = 0, PRODUCT = 1, diff --git a/api/cpp/src/collective/collective.cpp b/api/cpp/src/collective/collective.cpp index f2507b7..4147a46 100644 --- a/api/cpp/src/collective/collective.cpp +++ b/api/cpp/src/collective/collective.cpp @@ -101,29 +101,30 @@ void DestroyCollectiveGroup(const std::string &groupName) CollectiveGroupMgr::GetInstance().DestroyCollectiveGroup(groupName); } -void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, const std::string &groupName) +void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); - group->AllReduce(input, output, op); + group->AllReduce(sendbuf, recvbuf, count, dtype, op); } -void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank, +void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, int dstRank, const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); - group->Reduce(input, output, op, dstRank); + group->Reduce(sendbuf, recvbuf, count, dtype, op, dstRank); } -void AllGather(const DataDescriptor &input, DataDescriptor &output, const std::string &groupName) +void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); - group->AllGather(input, output); + group->AllGather(sendbuf, recvbuf, count, dtype); } void Barrier(const std::string &groupName) @@ -134,15 +135,15 @@ void Barrier(const std::string &groupName) group->Barrier(); } -void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName) +void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); - group->Scatter(input, output, srcRank); + group->Scatter(sendbuf, recvbuf, count, dtype, srcRank); } -void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank, const std::string &groupName) +void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank, const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, @@ -150,20 +151,20 @@ void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank, group->Barrier(); } -void Recv(DataDescriptor &output, int srcRank, int tag, const std::string &groupName) +void Recv(void *recvbuf, int count, int srcRank, int tag, const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); - group->Recv(output, srcRank, tag); + group->Recv(recvbuf, count, srcRank, tag); } -void Send(const DataDescriptor &input, int dstRank, int tag, const std::string &groupName) +void Send(const void *sendbuf, int count, int dstRank, int tag, const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); - group->Send(input, dstRank, tag); + group->Send(sendbuf, count, dstRank, tag); } std::shared_ptr CollectiveGroupMgr::CheckAndCreateGroup(const std::string &groupName) diff --git a/api/cpp/src/collective/gloo_collective_group.cpp b/api/cpp/src/collective/gloo_collective_group.cpp index 1bc096f..619a8cd 100644 --- a/api/cpp/src/collective/gloo_collective_group.cpp +++ b/api/cpp/src/collective/gloo_collective_group.cpp @@ -87,19 +87,20 @@ GlooCollectiveGroup::GlooCollectiveGroup(std::string groupName, int worldSize, i context_->connectFullMesh(prefixStore, dev); } -void GlooCollectiveGroup::AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) +void GlooCollectiveGroup::AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) { - EXECUTE_BY_TYPE(input.dType, DoAllReduce, input, output, op); + EXECUTE_BY_TYPE(dtype, DoAllReduce, sendbuf, recvbuf, count, op); } -void GlooCollectiveGroup::Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) +void GlooCollectiveGroup::Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) { - EXECUTE_BY_TYPE(input.dType, DoReduce, input, output, op, dstRank); + EXECUTE_BY_TYPE(dtype, DoReduce, sendbuf, recvbuf, count, op, dstRank); } -void GlooCollectiveGroup::AllGather(const DataDescriptor &input, DataDescriptor &output) +void GlooCollectiveGroup::AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) { - EXECUTE_BY_TYPE(input.dType, DoAllGather, input, output); + EXECUTE_BY_TYPE(dtype, DoAllGather, sendbuf, recvbuf, count); } void GlooCollectiveGroup::Barrier() @@ -108,26 +109,26 @@ void GlooCollectiveGroup::Barrier() gloo::barrier(opts); } -void GlooCollectiveGroup::Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) +void GlooCollectiveGroup::Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) { - EXECUTE_BY_TYPE(input.dType, DoScatter, input, output, srcRank); + EXECUTE_BY_TYPE(dtype, DoScatter, sendbuf, recvbuf, count, srcRank); } -void GlooCollectiveGroup::Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) +void GlooCollectiveGroup::Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) { - EXECUTE_BY_TYPE(input.dType, DoBroadcast, input, output, srcRank); + EXECUTE_BY_TYPE(dtype, DoBroadcast, sendbuf, recvbuf, count, srcRank); } -void GlooCollectiveGroup::Recv(DataDescriptor &output, int srcRank, int tag) +void GlooCollectiveGroup::Recv(void *recvbuf, int count, int srcRank, int tag) { - auto ubuf = context_->createUnboundBuffer(output.buf, output.count); + auto ubuf = context_->createUnboundBuffer(recvbuf, count); ubuf->recv(srcRank, tag); ubuf->waitRecv(); } -void GlooCollectiveGroup::Send(const DataDescriptor &input, int dstRank, int tag) +void GlooCollectiveGroup::Send(const void *sendbuf, int count, int dstRank, int tag) { - auto ubuf = context_->createUnboundBuffer(input.buf, input.count); + auto ubuf = context_->createUnboundBuffer(const_cast(sendbuf), count); ubuf->send(dstRank, tag); ubuf->waitSend(); } @@ -154,52 +155,52 @@ gloo::AllreduceOptions::Func GlooCollectiveGroup::GetReduceOp(const ReduceOp &op } template -void GlooCollectiveGroup::DoAllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) +void GlooCollectiveGroup::DoAllReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op) { gloo::AllreduceOptions opts(context_); - opts.setInput(static_cast(input.buf), input.count); - opts.setOutput(static_cast(output.buf), output.count); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); opts.setReduceFunction(GetReduceOp(op)); gloo::allreduce(opts); } template -void GlooCollectiveGroup::DoReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) +void GlooCollectiveGroup::DoReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op, int dstRank) { gloo::ReduceOptions opts(context_); - opts.setInput(static_cast(input.buf), input.count); - opts.setOutput(static_cast(output.buf), output.count); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); opts.setReduceFunction(GetReduceOp(op)); gloo::reduce(opts); } template -void GlooCollectiveGroup::DoBroadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) +void GlooCollectiveGroup::DoBroadcast(const void *sendbuf, void *recvbuf, int count, int srcRank) { gloo::BroadcastOptions opts(context_); - opts.setInput(static_cast(input.buf), input.count); - opts.setOutput(static_cast(output.buf), output.count); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); opts.setRoot(srcRank); gloo::broadcast(opts); } template -void GlooCollectiveGroup::DoScatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) +void GlooCollectiveGroup::DoScatter(const void *sendbuf, void *recvbuf, int count, int srcRank) { gloo::ScatterOptions opts(context_); - std::vector inputs = {static_cast(input.buf)}; - opts.setInputs(inputs, input.count); - opts.setOutput(static_cast(output.buf), output.count); + std::vector inputs = {const_cast((T *)sendbuf)}; + opts.setInputs(inputs, count); + opts.setOutput(static_cast(recvbuf), count); opts.setRoot(srcRank); gloo::scatter(opts); } template -void GlooCollectiveGroup::DoAllGather(const DataDescriptor &input, DataDescriptor &output) +void GlooCollectiveGroup::DoAllGather(const void *sendbuf, void *recvbuf, int count) { gloo::AllgatherOptions opts(context_); - opts.setInput(static_cast(input.buf), input.count); - opts.setOutput(static_cast(output.buf), output.count); + opts.setInput(const_cast((T *)sendbuf), count); + opts.setOutput(static_cast(recvbuf), count); gloo::allgather(opts); } } // namespace YR::collective \ No newline at end of file diff --git a/api/cpp/src/collective/gloo_collective_group.h b/api/cpp/src/collective/gloo_collective_group.h index 9a0bf98..1682d35 100644 --- a/api/cpp/src/collective/gloo_collective_group.h +++ b/api/cpp/src/collective/gloo_collective_group.h @@ -25,37 +25,38 @@ class GlooCollectiveGroup : public CollectiveGroup { public: GlooCollectiveGroup(std::string groupName, int worldSize, int rank); - void AllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op) override; + void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op) override; - void Reduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank) override; + void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const ReduceOp &op, + int dstRank) override; - void AllGather(const DataDescriptor &input, DataDescriptor &output) override; + void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype) override; void Barrier() override; - void Scatter(const DataDescriptor &input, DataDescriptor &output, int srcRank) override; + void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) override; - void Broadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank) override; + void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, int srcRank) override; - void Recv(DataDescriptor &output, int srcRank, int tag) override; + void Recv(void *recvbuf, int count, int srcRank, int tag) override; - void Send(const DataDescriptor &input, int dstRank, int tag) override; + void Send(const void *sendbuf, int count, int dstRank, int tag) override; private: template - void DoAllReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op); + void DoAllReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op); template - void DoReduce(const DataDescriptor &input, DataDescriptor &output, const ReduceOp &op, int dstRank); + void DoReduce(const void *sendbuf, void *recvbuf, int count, const ReduceOp &op, int dstRank); template - void DoAllGather(const DataDescriptor &input, DataDescriptor &output); + void DoAllGather(const void *sendbuf, void *recvbuf, int count); template - void DoScatter(const DataDescriptor &input, DataDescriptor &output, int srcRank); + void DoScatter(const void *sendbuf, void *recvbuf, int count, int srcRank); template - void DoBroadcast(const DataDescriptor &input, DataDescriptor &output, int srcRank); + void DoBroadcast(const void *sendbuf, void *recvbuf, int count, int srcRank); template static gloo::AllreduceOptions::Func GetReduceOp(const ReduceOp &op); -- Gitee From 15e91dcda250bb7b3d3306517652f2b821f400e0 Mon Sep 17 00:00:00 2001 From: mhsong Date: Wed, 19 Nov 2025 11:02:20 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A3=80=E8=A7=86?= =?UTF-8?q?=E6=84=8F=E8=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CMakeLists.txt | 20 +++++++++++++++++--- WORKSPACE | 8 +++++--- api/cpp/BUILD.bazel | 3 +++ api/cpp/src/collective/collective.cpp | 22 +++++++++++++--------- bazel/gloo.bzl | 18 +++++++++++++++--- tools/openSource.txt | 1 - 6 files changed, 53 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a2fb20..0ae97b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,10 +20,22 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_VERBOSE_MAKEFILE ON) # cpp -file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/cpp/src/*.cpp") +file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/api/cpp/src/*.cpp") list(APPEND SOURCES ${SOURCE}) -file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/cpp/src/main/*.cpp") +file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/api/cpp/src/main/*.cpp") +list(APPEND SOURCES ${SOURCE}) + +file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/api/cpp/src/executor/*.cpp") +list(APPEND SOURCES ${SOURCE}) + +file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/api/cpp/src/utils/*.cpp") +list(APPEND SOURCES ${SOURCE}) + +file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/api/cpp/src/collective/*.cpp") +list(APPEND SOURCES ${SOURCE}) + +file(GLOB SOURCE "${CMAKE_CURRENT_LIST_DIR}/api/cpp/src/parallel_for/*.cpp") list(APPEND SOURCES ${SOURCE}) # src @@ -77,7 +89,9 @@ add_library(runtime SHARED ${SOURCES} api/java/function-common/src/main/cpp/jni_types.cpp) # cpp -list(APPEND INCLUDE_PATH ${CMAKE_CURRENT_LIST_DIR}/cpp/include) +list(APPEND INCLUDE_PATH ${CMAKE_CURRENT_LIST_DIR}/api/cpp/include) +list(APPEND INCLUDE_PATH ${CMAKE_CURRENT_LIST_DIR}/build/output/external/gloo) +list(APPEND INCLUDE_PATH ${CMAKE_CURRENT_LIST_DIR}/build/output/external/nlohmann_json/include/nlohmann) # src list(APPEND INCLUDE_PATH ${CMAKE_CURRENT_LIST_DIR}) diff --git a/WORKSPACE b/WORKSPACE index d79ba53..9eb015b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -149,10 +149,12 @@ http_archive( build_file = "@//bazel:jacoco.bzl" ) -new_local_repository( +http_archive( name = "gloo", - build_file = "//third_party:gloo.bzl", - path = "../third_party/gloo", + urls = ["https://github.com/pytorch/gloo/archive/refs/heads/main.zip"], + strip_prefix = "gloo-main", + sha256 = "ddfb9627304025f294c78fd75ae84d7b9a662213d1b1c955afe6f91bfeb49254", + build_file = "@//bazel:gloo.bzl" ) maybe( diff --git a/api/cpp/BUILD.bazel b/api/cpp/BUILD.bazel index e65fd40..6238872 100644 --- a/api/cpp/BUILD.bazel +++ b/api/cpp/BUILD.bazel @@ -23,11 +23,14 @@ cc_library( "src/utils/*.cpp", "src/executor/*.h", "src/executor/*.cpp", + "src/collective/*.h", + "src/collective/*.cpp", ]) + ["src/utils/version.h"], hdrs = glob([ "include/yr/*.h", "include/yr/api/*.h", "include/yr/parallel/*.h", + "include/yr/collective/*.h", "include/yr/parallel/detail/*.h", ]), copts = COPTS, diff --git a/api/cpp/src/collective/collective.cpp b/api/cpp/src/collective/collective.cpp index 4147a46..93aeefc 100644 --- a/api/cpp/src/collective/collective.cpp +++ b/api/cpp/src/collective/collective.cpp @@ -89,7 +89,9 @@ void InitCollectiveGroup(int worldSize, int rank, const std::string &groupName, void CreateCollectiveGroup(const std::vector &instanceIDs, int worldSize, const std::vector &ranks, const std::string &groupName, Backend backend) { - // todo 前置参数校验 + THROW_IF_TRUE(instanceIDs.size() != worldSize || worldSize != ranks.size(), + YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "failed to create collective group, unequal actor, rank or world size, please check"); CollectiveGroupInfo info{ .instances = instanceIDs, .worldSize = worldSize, .ranks = ranks, .groupName = groupName, .backend = backend}; @@ -106,7 +108,7 @@ void AllReduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, co { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, - "group: " + groupName + " is not init"); + "please create group " + groupName + "first"); group->AllReduce(sendbuf, recvbuf, count, dtype, op); } @@ -115,7 +117,7 @@ void Reduce(const void *sendbuf, void *recvbuf, int count, DataType dtype, const { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, - "group: " + groupName + " is not init"); + "please create group " + groupName + "first"); group->Reduce(sendbuf, recvbuf, count, dtype, op, dstRank); } @@ -123,7 +125,7 @@ void AllGather(const void *sendbuf, void *recvbuf, int count, DataType dtype, co { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, - "group: " + groupName + " is not init"); + "please create group " + groupName + "first"); group->AllGather(sendbuf, recvbuf, count, dtype); } @@ -131,7 +133,7 @@ void Barrier(const std::string &groupName) { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, - "group: " + groupName + " is not init"); + "please create group " + groupName + "first"); group->Barrier(); } @@ -139,7 +141,7 @@ void Scatter(const void *sendbuf, void *recvbuf, int count, DataType dtype, int { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, - "group: " + groupName + " is not init"); + "please create group " + groupName + "first"); group->Scatter(sendbuf, recvbuf, count, dtype, srcRank); } @@ -147,7 +149,7 @@ void Broadcast(const void *sendbuf, void *recvbuf, int count, DataType dtype, in { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, - "group: " + groupName + " is not init"); + "please create group " + groupName + "first"); group->Barrier(); } @@ -163,7 +165,7 @@ void Send(const void *sendbuf, int count, int dstRank, int tag, const std::strin { auto group = CollectiveGroupMgr::GetInstance().CheckAndCreateGroup(groupName); THROW_IF_TRUE(group == nullptr, YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, - "group: " + groupName + " is not init"); + "please create group " + groupName + "first"); group->Send(sendbuf, count, dstRank, tag); } @@ -174,7 +176,8 @@ std::shared_ptr CollectiveGroupMgr::CheckAndCreateGroup(const s return groups_[groupName]; } auto str = YR::KVManager::Get(COLLECTIVE_GROUP_INFO_PREFIX + groupName); - THROW_IF_TRUE(str.empty(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, "group: " + groupName + " is not init"); + THROW_IF_TRUE(str.empty(), YR::Libruntime::ErrorCode::ERR_PARAM_INVALID, + "please create group " + groupName + "first"); CollectiveGroupInfo info; info.FromJson(str); @@ -218,6 +221,7 @@ void CollectiveGroupMgr::InitCollectiveGroup(int worldSize, int rank, const std: void CollectiveGroupMgr::DestroyCollectiveGroup(const std::string &groupName) { std::lock_guard lock(mtx_); + YR::KVManager::Del(COLLECTIVE_GROUP_INFO_PREFIX + groupName); groups_.erase(groupName); } diff --git a/bazel/gloo.bzl b/bazel/gloo.bzl index f5d6ff2..827fffd 100644 --- a/bazel/gloo.bzl +++ b/bazel/gloo.bzl @@ -1,7 +1,19 @@ cc_library( name = "gloo", - srcs = [], - hdrs = glob(["gloo/**/*.h"]), - includes = ["include"], # 指定头文件根目录 + srcs = glob([ + "gloo/**/*.cc", + "gloo/**/*.cpp", + ]), + hdrs = [":headers"], + includes = ["."], + deps = [ + ], + linkopts = [ + "-lpthread", + ], + copts = [ + "-DGLOO_USE_MPI=0", + "-DUSE_IBVERBS", + ], visibility = ["//visibility:public"], ) \ No newline at end of file diff --git a/tools/openSource.txt b/tools/openSource.txt index 4cdb1d0..79deccd 100644 --- a/tools/openSource.txt +++ b/tools/openSource.txt @@ -17,5 +17,4 @@ libboundscheck,v1.1.16,runtime;litebus;functionsystem;metrics,https://gitee.com/ curl,8.8.0,litebus,https://gitee.com/mirrors/curl/repository/archive/curl-8_8_0.zip,73c70c94f487c5ae26f9f27094249e40bb1667ae6c0406a75c3b11f86f0c1128,test gtest_1_12_1,release-1.12.1,litebus;functionsystem;metrics;logs,https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.zip,4f1037f17462dbcc4d715f2fc8212c03431ad06e7a4b73bb49b3f534a13b21f1,test gtest_1_10_0,release-1.10.0,functionsystem,https://gitee.com/mirrors/googletest/repository/archive/release-1.10.0.zip,db55081e6b5a5f820d052f8d37e0d29835dd5a2bd5b8d4d50590b7834f9abfae,test -gloo,main,runtime,https://github.com/pytorch/gloo/archive/refs/heads/main.zip,ddfb9627304025f294c78fd75ae84d7b9a662213d1b1c955afe6f91bfeb49254, ,,,,,, \ No newline at end of file -- Gitee