From e3eac40048e1c690747ff8d9a7bb8b6d603a1ea8 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Wed, 4 Nov 2020 15:14:19 +0800 Subject: [PATCH 1/7] =?UTF-8?q?update=20tf=5Fadapter/ops/hccl=5Fops.cc.=20?= =?UTF-8?q?HCCL=20broadcast=E8=9E=8D=E5=90=88=EF=BC=8C=E8=AE=BE=E8=AE=A1tf?= =?UTF-8?q?adaptor=20=E6=8E=A5=E5=8F=A3=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/ops/hccl_ops.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tf_adapter/ops/hccl_ops.cc b/tf_adapter/ops/hccl_ops.cc index fb6d6f0a6..a17d67311 100644 --- a/tf_adapter/ops/hccl_ops.cc +++ b/tf_adapter/ops/hccl_ops.cc @@ -96,6 +96,8 @@ REGISTER_OP("HcomBroadcast") .Input("input: T") .Output("output: T") .Attr("T: list(type) >= 0") + .Attr("fusion: int = 0") + .Attr("fusion_id: int = -1") .Attr("group: string") .Attr("root_rank: int") .SetIsStateful() @@ -184,7 +186,7 @@ REGISTER_OP("HcomRemoteRead") .Attr("dtype: {int8, int16, int32, float16, float32, int64, uint64}") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->UnknownShape()); // 第一维shape确定,第二维unknown + c->set_output(0, c->UnknownShape()); // output shape is not inferred statically (unknown shape) return Status::OK(); }) .Doc(R"doc( -- Gitee From 5defad8e82f11e93113a5cbef231eed3186a6ec6 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Wed, 4 Nov 2020 15:17:58 +0800 Subject: [PATCH 2/7] =?UTF-8?q?update=20tf=5Fadapter/python/npu=5Fbridge/h?= =?UTF-8?q?ccl/hccl=5Fops.py.=20HCCL=20broadcast=E8=9E=8D=E5=90=88?= =?UTF-8?q?=EF=BC=8Ctfadptor=20python=20=E6=8E=A5=E5=8F=A3=E6=9B=B4?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/python/npu_bridge/hccl/hccl_ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py 
b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py index 970da13c3..0f66c5671 100644 --- a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py +++ b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py @@ -48,11 +48,15 @@ def allgather(tensor, rank_size, group="hccl_world_group"): ## 鎻愪緵group鍐呯殑闆嗗悎閫氫俊broadcast鍔熻兘 # @param tensor tensorflow鐨則ensor绫诲瀷锛宐roadcast鎿嶄綔鐨勮緭鍏ワ紱 # @param root_rank int绫诲瀷锛屼綔涓簉oot鑺傜偣鐨剅ank_id锛岃id鏄痝roup鍐呯殑rank id; +# @param fusion int; operator fusion flag. 0: do not fuse; 1: fuse with operators that share the same fusion_id; any other value is invalid. +# @param fusion_id int; operator fusion index. Operators with the same fusion_id are fused together. # @param group string绫诲瀷锛実roup鍚嶇О锛屽彲浠ヤ负鐢ㄦ埛鑷畾涔塯roup鎴栬"hccl_world_group"; # @return 瀵硅緭鍏ensor鎵ц瀹宐roadcast鎿嶄綔涔嬪悗鐨勭粨鏋渢ensor -def broadcast(tensor, root_rank, group="hccl_world_group"): +def broadcast(tensor, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): result = gen_hccl_ops.hcom_broadcast( input=tensor, + fusion=fusion, + fusion_id=fusion_id, group=group, root_rank=root_rank) return result -- Gitee From e644b00c398c0d0880dfe15a1659e296a6ceb38d Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Wed, 4 Nov 2020 16:42:05 +0800 Subject: [PATCH 3/7] update tf_adapter/interface_spec/api_hccl_ops.pyh. 
--- tf_adapter/interface_spec/api_hccl_ops.pyh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_adapter/interface_spec/api_hccl_ops.pyh b/tf_adapter/interface_spec/api_hccl_ops.pyh index aeb48e957..b854d818a 100644 --- a/tf_adapter/interface_spec/api_hccl_ops.pyh +++ b/tf_adapter/interface_spec/api_hccl_ops.pyh @@ -1,7 +1,7 @@ # source file:./python/npu_bridge/hccl/hccl_ops.py def allreduce(tensor, reduction, fusion=1, fusion_id=-1, group="hccl_world_group"): def allgather(tensor, rank_size, group="hccl_world_group"): -def broadcast(tensor, root_rank, group="hccl_world_group"): +def broadcast(tensor, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group"): def reduce_scatter(tensor, reduction, rank_size, group="hccl_world_group"): def send(tensor, sr_tag, dest_rank, group="hccl_world_group"): def receive(shape, data_type, sr_tag, src_rank, group="hccl_world_group"): -- Gitee From d7c16cc4ca9b50eb2a86104fc884bcd6f3597469 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Wed, 4 Nov 2020 17:57:44 +0800 Subject: [PATCH 4/7] update tf_adapter/ops/hccl_ops.cc. 
--- tf_adapter/ops/hccl_ops.cc | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tf_adapter/ops/hccl_ops.cc b/tf_adapter/ops/hccl_ops.cc index a17d67311..b51caf5ea 100644 --- a/tf_adapter/ops/hccl_ops.cc +++ b/tf_adapter/ops/hccl_ops.cc @@ -193,6 +193,22 @@ REGISTER_OP("HcomRemoteRead") )doc"); +REGISTER_OP("HcomRemoteRefRead") + .Input("remote: T") + .Input("cache_var: Ref(dtype)") + .Input("local_offset: T") + .Output("cache_var1:Ref(dtype)") + .Attr("T: {uint64}") + .Attr("dtype: {int8, int16, int32, float16, float32, int64, uint64}") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(1)); + return Status::OK(); + }) + .Doc(R"doc( + +)doc"); + REGISTER_OP("HcomRemoteWrite") .Input("remote: T") .Input("local: dtype") @@ -204,4 +220,15 @@ REGISTER_OP("HcomRemoteWrite") )doc"); +REGISTER_OP("HcomRemoteScatterWrite") + .Input("remote: T") + .Input("local: dtype") + .Input("local_offset: T") + .Attr("T: {int64, uint64}") + .Attr("dtype: {int8, int16, int32, float16, float32, int64, uint64}") + .SetIsStateful() + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( + +)doc"); } // namespace tensorflow -- Gitee From f785cdf2d9c99aa920b26bec6e885370e1bc50a6 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Wed, 4 Nov 2020 20:00:18 +0800 Subject: [PATCH 5/7] update tf_adapter/python/npu_bridge/hccl/hccl_ops.py. 
--- tf_adapter/python/npu_bridge/hccl/hccl_ops.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py index 0f66c5671..d226f58aa 100644 --- a/tf_adapter/python/npu_bridge/hccl/hccl_ops.py +++ b/tf_adapter/python/npu_bridge/hccl/hccl_ops.py @@ -116,6 +116,17 @@ def remote_read(tensorRemote, data_type): dtype=data_type) return result +## Provides the remote ref read operation. +# @param tensorRemote remote memory info, shape (index_num, 3): [u64 remoteId, u64 remoteAddr, u64 dataLength] +# @param cache base address of the local receive memory (passed to the op as cache_var) +# @param offset stride used for the strided read (passed to the op as local_offset) +def remote_ref_read(tensorRemote, cache, offset): + result = gen_hccl_ops.hcom_remote_ref_read( + remote=tensorRemote, + cache_var=cache, + local_offset=offset) + return result + ## 鎻愪緵remote write鍔熻兘 # @param remote 鍐欏叆杩滅鍐呭瓨淇℃伅锛宻hape(index_num, 3)锛歔u64 remoteId, u64 remoteAddr, u64 dataLength] # @param local 鏈鍙戦佸唴瀛 @@ -123,4 +134,15 @@ def remote_write(tensorRemote, tensorLocal, data_type): result = gen_hccl_ops.hcom_remote_write( remote=tensorRemote, local=tensorLocal) + return result + +## Provides the remote scatter write operation (HcomRemoteScatterWrite op). +# @param tensorRemote remote memory info to write to, shape (index_num, 3): [u64 remoteId, u64 remoteAddr, u64 dataLength] +# @param tensorLocal base address of the local send memory +# @param offset stride used for the strided write (passed to the op as local_offset) +def remote_scatter_write(tensorRemote, tensorLocal, offset): + result = gen_hccl_ops.hcom_remote_scatter_write( + remote=tensorRemote, + local=tensorLocal, + local_offset=offset) return result \ No newline at end of file -- Gitee From d88014995c395e8047b74af4b381fab26ddbbb3b Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Wed, 4 Nov 2020 20:05:30 +0800 Subject: [PATCH 6/7] update tf_adapter/kernels/hccl_ops.cc. 
--- tf_adapter/kernels/hccl_ops.cc | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tf_adapter/kernels/hccl_ops.cc b/tf_adapter/kernels/hccl_ops.cc index 30f8d9dff..2c28f2031 100644 --- a/tf_adapter/kernels/hccl_ops.cc +++ b/tf_adapter/kernels/hccl_ops.cc @@ -94,6 +94,18 @@ public: REGISTER_KERNEL_BUILDER(Name("HcomRemoteRead").Device(DEVICE_CPU), HcomRemoteReadOpKernel); +class HcomRemoteRefReadOpKernel : public OpKernel { +public: + explicit HcomRemoteRefReadOpKernel(OpKernelConstruction* context) : OpKernel(context) {} + ~HcomRemoteRefReadOpKernel() override = default; + void Compute(OpKernelContext* context) override + { + LOG(INFO) << "HcomRemoteRefRead Compute."; + } +}; + +REGISTER_KERNEL_BUILDER(Name("HcomRemoteRefRead").Device(DEVICE_CPU), HcomRemoteRefReadOpKernel); + class HcomRemoteWriteKernel : public OpKernel { public: explicit HcomRemoteWriteKernel(OpKernelConstruction* context) : OpKernel(context) {} @@ -105,4 +117,17 @@ public: }; REGISTER_KERNEL_BUILDER(Name("HcomRemoteWrite").Device(DEVICE_CPU), HcomRemoteWriteKernel); -} // namespace tensorflow + +class HcomRemoteScatterWriteOpKernel : public OpKernel { +public: + explicit HcomRemoteScatterWriteOpKernel(OpKernelConstruction* context) : OpKernel(context) {} + ~HcomRemoteScatterWriteOpKernel() override = default; + void Compute(OpKernelContext* context) override + { + LOG(INFO) << "HcomRemoteScatterWrite Compute."; + } +}; + +REGISTER_KERNEL_BUILDER(Name("HcomRemoteScatterWrite").Device(DEVICE_CPU), HcomRemoteScatterWriteOpKernel); + +} // namespace tensorflow -- Gitee From 0320ccff5feeba9e1607d79ceab14227b96f4385 Mon Sep 17 00:00:00 2001 From: cclworkaccount <8266062+cclworkaccount@user.noreply.gitee.com> Date: Wed, 4 Nov 2020 20:16:41 +0800 Subject: [PATCH 7/7] update tf_adapter/interface_spec/api_hccl_ops.pyh. 
--- tf_adapter/interface_spec/api_hccl_ops.pyh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tf_adapter/interface_spec/api_hccl_ops.pyh b/tf_adapter/interface_spec/api_hccl_ops.pyh index b854d818a..58f094595 100644 --- a/tf_adapter/interface_spec/api_hccl_ops.pyh +++ b/tf_adapter/interface_spec/api_hccl_ops.pyh @@ -5,5 +5,7 @@ def broadcast(tensor, root_rank, fusion=0, fusion_id=-1, group="hccl_world_group def reduce_scatter(tensor, reduction, rank_size, group="hccl_world_group"): def send(tensor, sr_tag, dest_rank, group="hccl_world_group"): def receive(shape, data_type, sr_tag, src_rank, group="hccl_world_group"): +def remote_ref_read(tensorRemote, cache, offset): +def remote_scatter_write(tensorRemote, tensorLocal, offset): -- Gitee