From 536b2aeb61190de7d13229b2abbcee884bf630b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5%B0%A4=E5=AE=89=E5=8D=87?=
Date: Thu, 10 Mar 2022 09:48:01 +0800
Subject: [PATCH 1/2] Add test_c10d.

---
 test/test_c10d.py | 190 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 190 insertions(+)
 create mode 100644 test/test_c10d.py

diff --git a/test/test_c10d.py b/test/test_c10d.py
new file mode 100644
index 0000000000..bf15428a5d
--- /dev/null
+++ b/test/test_c10d.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import IntEnum, unique
+
+import os
+import unittest
+import torch
+import torch_npu
+import torch.distributed as c10d
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+
+@unique
+class Format(IntEnum):
+    NCHW = 0
+    ND = 2
+    NC1HWC0 = 3
+    NZ = 29
+
+
+class ProcessGroupHCCLTest(TestCase):
+
+    world_size = 2
+
+    def setUp(self):
+        if torch_npu.npu.device_count() < 2:
+            raise unittest.SkipTest("HCCL test requires 2+ NPUs")
+
+    @classmethod
+    def _init_pg_hccl(cls, rank, world_size):
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'
+        torch_npu.npu.set_device(rank)
+        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
+        return dist.new_group([0, 1])
+
+    def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
+        ws = self.world_size
+        # spawn one worker per rank; queues carry results (c2p) and exit signals (p2c)
+        ctx = mp.get_context('spawn')
+        c2p = ctx.Queue(2)
+        p2c = ctx.Queue(2)
+        ps = []
+        for i in range(ws):
+            p = ctx.Process(
+                target=f,
+                args=(i, shared_tensors, ws, init_pg, c2p, p2c))
+
+            p.start()
+            ps.append(p)
+
+        for _ in range(ws * n_output):
+            pid, expected, result = c2p.get()
+            self.assertEqual(
+                expected,
+                result,
+                (
+                    "Expect rank {} to receive tensor {} but got {}."
+                ).format(pid, expected, result)
+            )
+
+        for _ in range(ws):
+            p2c.put(0)
+
+        for p in ps:
+            p.join(2)
+
+    # Why classmethod? multiprocessing cannot pickle TestCase subclass when in
+    # spawn mode. See https://bugs.python.org/issue33884.
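+    # Each _test_*_process classmethod below runs in a spawned worker: it
+    # joins the HCCL process group, runs one collective on its own NPU, and
+    # reports (rank, expected, actual) back to the parent through c2p, then
+    # blocks on p2c until the parent signals it to exit.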
+    @classmethod
+    def _test_broadcast_process(
+            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, world_size)
+        xs = [shared_tensors[rank].to(f"npu:{rank}")]
+        pg.broadcast(xs).wait()
+        # rank 0 is the default broadcast root, so every rank should end up
+        # with rank 0's tensor (all zeros)
+        c2p.put((rank, torch.zeros(2, 2), xs[0].to("cpu")))
+        p2c.get()
+
+    def test_shared_broadcast_hccl(self):
+        self._test_multiprocess(
+            ProcessGroupHCCLTest._test_broadcast_process,
+            [torch.ones(2, 2) * i for i in range(self.world_size)],
+            ProcessGroupHCCLTest._init_pg_hccl,
+            1)
+
+    @classmethod
+    def _test_allreduce_process(
+            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, world_size)
+        xs = [shared_tensors[rank].to(f"npu:{rank}")]
+        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
+        # both ranks contribute ones, so the SUM reduction yields 2 * ones
+        c2p.put((rank, torch.ones(2, 2) * 2, xs[0].to("cpu")))
+        p2c.get()
+
+    def test_shared_allreduce_hccl(self):
+        self._test_multiprocess(
+            ProcessGroupHCCLTest._test_allreduce_process,
+            [torch.ones(2, 2) for i in range(self.world_size)],
+            ProcessGroupHCCLTest._init_pg_hccl,
+            1)
+
+    @classmethod
+    def _test_allgather_process(
+            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, world_size)
+        xs = [shared_tensors[rank].to(f"npu:{rank}")]
+        ys = [[torch.zeros_like(xs[0]) for i in range(world_size)]]
+        pg.allgather(ys, xs).wait()
+        # slot i of the gathered list should hold rank i's tensor (ones * i)
+        for i in range(world_size):
+            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))
+
+        p2c.get()
+
+    def test_shared_allgather_hccl(self):
+        self._test_multiprocess(
+            ProcessGroupHCCLTest._test_allgather_process,
+            [torch.ones(2, 2) * i for i in range(self.world_size)],
+            ProcessGroupHCCLTest._init_pg_hccl,
+            self.world_size)
+
+
+class ComputeBucketAssignmentTest(TestCase):
+    def test_single_limit_single_dtype(self):
+        tensors = [
+            torch.empty([100], dtype=torch.float).npu().npu_format_cast(Format.NZ),
+            torch.empty([200], dtype=torch.float).npu(),
+            torch.empty([100], dtype=torch.float).npu(),
+            torch.empty([50], dtype=torch.float).npu(),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [1792 * 4 + 1])
+        self.assertEqual([[0, 1, 2, 3]], result)
+
+    def test_single_limit_multi_dtype(self):
+        tensors = [
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [400])
+        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
+
+    def test_multi_limit_single_dtype(self):
+        tensors = [
+            torch.empty([10], dtype=torch.float).npu(),
+            torch.empty([10], dtype=torch.float).npu(),
+            torch.empty([10], dtype=torch.float).npu(),
+            torch.empty([10], dtype=torch.float).npu(),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
+        self.assertEqual([[0], [1, 2], [3]], result)
+
+    def test_multi_limit_multi_dtype(self):
+        tensors = [
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+            torch.empty([50], dtype=torch.float).npu(),
+            torch.empty([25], dtype=torch.double).npu(),
+        ]
+        result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
+        self.assertEqual([[0], [1], [2, 4], [3, 5], [6, 8], [7, 9]], result)
+
+
+if __name__ == '__main__':
+    run_tests()
--
Gitee

From f803a145c77c5fa0b51906224b0d9944d85c82c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B0=A4=E5=AE=89=E5=8D=87?=
Date: Thu, 10 Mar 2022 09:48:07 +0800
Subject: [PATCH 2/2] Rename API supported list.

---
 CONTRIBUTING.zh.md                                              | 2 +-
 .../en/{PyTorch 1.5.0 API Support.md => PyTorch API Support.md} | 0
 ...orch API\346\224\257\346\214\201\346\270\205\345\215\225.md" | 0
 3 files changed, 1 insertion(+), 1 deletion(-)
 rename docs/en/{PyTorch 1.5.0 API Support.md => PyTorch API Support.md} (100%)
 rename "docs/zh/PyTorch API\346\224\257\346\214\201\346\270\205\345\215\225_1.5.0.md" => "docs/zh/PyTorch API\346\224\257\346\214\201\346\270\205\345\215\225.md" (100%)

diff --git a/CONTRIBUTING.zh.md b/CONTRIBUTING.zh.md
index 29e28e1e53..bf0136400f 100644
--- a/CONTRIBUTING.zh.md
+++ b/CONTRIBUTING.zh.md
@@ -99,7 +99,7 @@
 
 2. 设置环境变量。
 
-   进入"pytorch/src"路径，并执行env.sh脚本。
+   进入"pytorch"根目录，并执行env.sh脚本。
 
    ```
    bash env.sh
diff --git a/docs/en/PyTorch 1.5.0 API Support.md b/docs/en/PyTorch API Support.md
similarity index 100%
rename from docs/en/PyTorch 1.5.0 API Support.md
rename to docs/en/PyTorch API Support.md
diff --git "a/docs/zh/PyTorch API\346\224\257\346\214\201\346\270\205\345\215\225_1.5.0.md" "b/docs/zh/PyTorch API\346\224\257\346\214\201\346\270\205\345\215\225.md"
similarity index 100%
rename from "docs/zh/PyTorch API\346\224\257\346\214\201\346\270\205\345\215\225_1.5.0.md"
rename to "docs/zh/PyTorch API\346\224\257\346\214\201\346\270\205\345\215\225.md"
--
Gitee