From f426dbb5ce7a9ed290f3b48662919cd8d9d0e8ca Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Sat, 28 Nov 2020 10:33:58 +0800 Subject: [PATCH 1/5] update tf_adapter/interface_spec/api_hccl_ops.pyh. --- tf_adapter/interface_spec/api_hccl_ops.pyh | 1 + 1 file changed, 1 insertion(+) diff --git a/tf_adapter/interface_spec/api_hccl_ops.pyh b/tf_adapter/interface_spec/api_hccl_ops.pyh index 58f094595..e5f8c54a2 100644 --- a/tf_adapter/interface_spec/api_hccl_ops.pyh +++ b/tf_adapter/interface_spec/api_hccl_ops.pyh @@ -2,6 +2,7 @@ def allreduce(tensor, reduction, fusion=1, fusion_id=-1, group="hccl_world_group"): def allgather(tensor, rank_size, group="hccl_world_group"): def broadcast(tensor, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): +def reduce(tensor, reduction, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): def reduce_scatter(tensor, reduction, rank_size, group="hccl_world_group"): def send(tensor, sr_tag, dest_rank, group="hccl_world_group"): def receive(shape, data_type, sr_tag, src_rank, group="hccl_world_group"): -- Gitee From c1938eb7c66609b30b482f794b87a7c88b342349 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Sat, 28 Nov 2020 14:10:55 +0800 Subject: [PATCH 2/5] update tf_adapter/kernels/hccl_ops.cc. --- tf_adapter/kernels/hccl_ops.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tf_adapter/kernels/hccl_ops.cc b/tf_adapter/kernels/hccl_ops.cc index 2c28f2031..d5d474b1c 100644 --- a/tf_adapter/kernels/hccl_ops.cc +++ b/tf_adapter/kernels/hccl_ops.cc @@ -55,6 +55,15 @@ class HcomBroadcastOpKernel : public OpKernel { REGISTER_KERNEL_BUILDER(Name("HcomBroadcast").Device(DEVICE_CPU), HcomBroadcastOpKernel); +class HcomReduceOpKernel : public OpKernel { + public: + explicit HcomReduceOpKernel(OpKernelConstruction *context) : OpKernel(context) {} + ~HcomReduceOpKernel() {} + void Compute(OpKernelContext *context) override { LOG(INFO) << "HcomReduceOp Compute."; } +}; + +REGISTER_KERNEL_BUILDER(Name("HcomReduce").Device(DEVICE_CPU), HcomReduceOpKernel); + class HcomReduceScatterOpKernel : public OpKernel { public: explicit HcomReduceScatterOpKernel(OpKernelConstruction *context) : OpKernel(context) {} -- Gitee From 4c63403f451f6660415404852fc20c94e4aed718 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Sat, 28 Nov 2020 14:15:44 +0800 Subject: [PATCH 3/5] update tf_adapter/ops/hccl_ops.cc. --- tf_adapter/ops/hccl_ops.cc | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tf_adapter/ops/hccl_ops.cc b/tf_adapter/ops/hccl_ops.cc index b51caf5ea..17e35189d 100644 --- a/tf_adapter/ops/hccl_ops.cc +++ b/tf_adapter/ops/hccl_ops.cc @@ -115,6 +115,33 @@ input: The input to the broadcast. output: The same as input. )doc"); +REGISTER_OP("HcomReduce") + .Input("input: T") + .Output("output: T") + .Attr("T: {int8, int16, int32, float16, float32}") + .Attr("reduction: {'min', 'max', 'prod', 'sum'}") + .Attr("group: string") + .Attr("root_rank: int") + .Attr("fusion: int") + .Attr("fusion_id: int") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Outputs a tensor containing the reduction across all input tensors passed to ops. + +The graph should be constructed so if one op runs with shared_name value `c`, +then `num_devices` ops will run with shared_name value `c`. Failure to do so +will cause the graph execution to fail to complete. + +input: the input to the reduction +output: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. +group: all devices of the group participating in this reduction. +)doc"); + REGISTER_OP("HcomReduceScatter") .Input("input: T") .Output("output: T") -- Gitee From 4ef131328da723101a775ff60c6e074bfc17cdb8 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Sat, 28 Nov 2020 14:19:39 +0800 Subject: [PATCH 4/5] update tf_adapter/python/npu_bridge/hccl/hccl_ops.py. --- tf_adapter/python/npu_bridge/hccl/hccl_ops.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py index 974a51274..3926ac2c6 100644 --- a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py +++ b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py @@ -48,7 +48,7 @@ def allgather(tensor, rank_size, group="hccl_world_group"): ## 提供group内的集合通信broadcast功能 # @param tensor tensorflow的tensor类型,broadcast操作的输入; # @param root_rank int类型,作为root节点的rank_id,该id是group内的rank id; -# @param fusion int类型,算子融合标识。0: 不融合;1:按照相同fusion_id融合;其他值非法。 +# @param fusion int类型,算子融合标识。0: 不融合;2:按照相同fusion_id融合;其他值非法。 # @param fusion_id int类型,算子融合索引标识,相同fusion_id的算子将会融合。 # @param group string类型,group名称,可以为用户自定义group或者"hccl_world_group"; # @return 对输入tensor执行完broadcast操作之后的结果tensor @@ -61,6 +61,24 @@ def broadcast(tensor, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group root_rank=root_rank) return result +## 提供group内的集合通信reduce功能 +# @param tensor tensorflow的tensor类型,reduce操作的输入; +# @param reduction string类型,reduce的操作类型,可以为”max”,”min”,”prod”和”sum”; +# @param fusion int类型,算子融合标识。0: 不融合; 2: 按照相同fusion_id融合。 +# @param fusion_id int类型,算子融合索引标识,相同fusion_id的算子将会融合。 +# @param root_rank int类型,作为root节点的rank_id,该id是group内的rank id; +# @param group string类型,group名称,可以为用户自定义group或者"hccl_world_group"; +# @return 对输入tensor执行完reduce操作之后的结果tensor +def reduce(tensor, reduction, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): + result = gen_hccl_ops.hcom_reduce( + input=tensor, + reduction=reduction, + fusion=fusion, + fusion_id=fusion_id, + group=group, + root_rank=root_rank) + return result + ## 提供group内的集合通信reduce_scatter功能 # @param tensor tensorflow的tensor类型,reduce_scatter操作的输入; -- Gitee From e4f18a62e9becb885f4e3ef2479013e6130153f7 Mon Sep 17 00:00:00 2001 From: cclworkaccount Date: Fri, 8 Jan 2021 16:31:50 +0800 Subject: [PATCH 5/5] update tf_adapter/ops/hccl_ops.cc. --- tf_adapter/ops/hccl_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_adapter/ops/hccl_ops.cc b/tf_adapter/ops/hccl_ops.cc index 17e35189d..1669115f1 100644 --- a/tf_adapter/ops/hccl_ops.cc +++ b/tf_adapter/ops/hccl_ops.cc @@ -249,7 +249,7 @@ REGISTER_OP("HcomRemoteWrite") REGISTER_OP("HcomRemoteScatterWrite") .Input("remote: T") - .Input("local: dtype") + .Input("local: Ref(dtype)") .Input("local_offset: T") .Attr("T: {int64, uint64}") .Attr("dtype: {int8, int16, int32, float16, float32, int64, uint64}") -- Gitee