diff --git a/tf_adapter/interface_spec/api_hccl_ops.pyh b/tf_adapter/interface_spec/api_hccl_ops.pyh index 58f09459566704da01724080fca08b7c6511f33d..e5f8c54a2aebb50b7c1dcc4b27affbbe6136807d 100644 --- a/tf_adapter/interface_spec/api_hccl_ops.pyh +++ b/tf_adapter/interface_spec/api_hccl_ops.pyh @@ -2,6 +2,7 @@ def allreduce(tensor, reduction, fusion=1, fusion_id=-1, group="hccl_world_group"): def allgather(tensor, rank_size, group="hccl_world_group"): def broadcast(tensor, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): +def reduce(tensor, reduction, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): def reduce_scatter(tensor, reduction, rank_size, group="hccl_world_group"): def send(tensor, sr_tag, dest_rank, group="hccl_world_group"): def receive(shape, data_type, sr_tag, src_rank, group="hccl_world_group"): diff --git a/tf_adapter/kernels/hccl_ops.cc b/tf_adapter/kernels/hccl_ops.cc index 2c28f2031fd2379895dbc5991fa0808ccd80089b..d5d474b1c4cdd0e1ff2d7fdb7d531b05946684b2 100644 --- a/tf_adapter/kernels/hccl_ops.cc +++ b/tf_adapter/kernels/hccl_ops.cc @@ -55,6 +55,15 @@ class HcomBroadcastOpKernel : public OpKernel { REGISTER_KERNEL_BUILDER(Name("HcomBroadcast").Device(DEVICE_CPU), HcomBroadcastOpKernel); +class HcomReduceOpKernel : public OpKernel { + public: + explicit HcomReduceOpKernel(OpKernelConstruction *context) : OpKernel(context) {} + ~HcomReduceOpKernel() {} + void Compute(OpKernelContext *context) override { LOG(INFO) << "HcomReduceOp Compute."; } +}; + +REGISTER_KERNEL_BUILDER(Name("HcomReduce").Device(DEVICE_CPU), HcomReduceOpKernel); + class HcomReduceScatterOpKernel : public OpKernel { public: explicit HcomReduceScatterOpKernel(OpKernelConstruction *context) : OpKernel(context) {} diff --git a/tf_adapter/ops/hccl_ops.cc b/tf_adapter/ops/hccl_ops.cc index b51caf5ea3f644da01e825b613d00b319f9935d7..1669115f10c3ca12f4a863efa410d3b2344dd05c 100644 --- a/tf_adapter/ops/hccl_ops.cc +++ b/tf_adapter/ops/hccl_ops.cc @@ -115,6 +115,33 @@ input: The input to the broadcast. output: The same as input. )doc"); +REGISTER_OP("HcomReduce") + .Input("input: T") + .Output("output: T") + .Attr("T: {int8, int16, int32, float16, float32}") + .Attr("reduction: {'min', 'max', 'prod', 'sum'}") + .Attr("group: string") + .Attr("root_rank: int") + .Attr("fusion: int") + .Attr("fusion_id: int") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Outputs a tensor containing the reduction across all input tensors passed to ops. + +The graph should be constructed so if one op runs with shared_name value `c`, +then `num_devices` ops will run with shared_name value `c`. Failure to do so +will cause the graph execution to fail to complete. + +input: the input to the reduction +output: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. +group: all devices of the group participating in this reduction. +)doc"); + REGISTER_OP("HcomReduceScatter") .Input("input: T") .Output("output: T") @@ -222,7 +249,7 @@ REGISTER_OP("HcomRemoteWrite") REGISTER_OP("HcomRemoteScatterWrite") .Input("remote: T") - .Input("local: dtype") + .Input("local: Ref(dtype)") .Input("local_offset: T") .Attr("T: {int64, uint64}") .Attr("dtype: {int8, int16, int32, float16, float32, int64, uint64}") diff --git a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py index 974a51274c2bd9931223462fc4b22a7a2548905e..3926ac2c6a2799bb1d39e3aba9df662fb012d211 100644 --- a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py +++ b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py @@ -48,7 +48,7 @@ def allgather(tensor, rank_size, group="hccl_world_group"): ## 提供group内的集合通信broadcast功能 # @param tensor tensorflow的tensor类型,broadcast操作的输入; # @param root_rank int类型,作为root节点的rank_id,该id是group内的rank id; -# @param fusion int类型,算子融合标识。0: 不融合;1:按照相同fusion_id融合;其他值非法。 +# @param fusion int类型,算子融合标识。0: 不融合;2:按照相同fusion_id融合;其他值非法。 # @param fusion_id int类型,算子融合索引标识,相同fusion_id的算子将会融合。 # @param group string类型,group名称,可以为用户自定义group或者"hccl_world_group"; # @return 对输入tensor执行完broadcast操作之后的结果tensor @@ -61,6 +61,24 @@ def broadcast(tensor, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group root_rank=root_rank) return result +## 提供group内的集合通信reduce功能 +# @param tensor tensorflow的tensor类型,reduce操作的输入; +# @param reduction string类型,reduce的操作类型,可以为”max”,”min”,”prod”和”sum”; +# @param fusion int类型,算子融合标识。0: 不融合; 2: 按照相同fusion_id融合。 +# @param fusion_id int类型,算子融合索引标识,相同fusion_id的算子将会融合。 +# @param root_rank int类型,作为root节点的rank_id,该id是group内的rank id; +# @param group string类型,group名称,可以为用户自定义group或者"hccl_world_group"; +# @return 对输入tensor执行完reduce操作之后的结果tensor +def reduce(tensor, reduction, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): + result = gen_hccl_ops.hcom_reduce( + input=tensor, + reduction=reduction, + fusion=fusion, + fusion_id=fusion_id, + group=group, + root_rank=root_rank) + return result + ## 提供group内的集合通信reduce_scatter功能 # @param tensor tensorflow的tensor类型,reduce_scatter操作的输入;