diff --git a/test/dynamo/test_comm_converter.py b/test/dynamo/test_comm_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..65272ee483a5f4c7bf9c586a7c53120ba58844a2
--- /dev/null
+++ b/test/dynamo/test_comm_converter.py
@@ -0,0 +1,126 @@
+import os
+from copy import deepcopy
+
+import torch
+
+import torch.distributed as dist
+from torch import nn
+import torch.distributed
+import torch.multiprocessing as mp
+import torch.distributed._functional_collectives as fcol
+
+from torch._dynamo.test_case import TestCase
+from torch._dynamo.testing import normalize_gm
+
+import torch_npu
+
+
+DIM = 200
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(DIM, DIM)
+
+    def forward(self, x, group):
+        _fc1 = self.fc1(x)
+        _fc1 = fcol.all_reduce(_fc1, "sum", group=group)
+        _fc1 = fcol.reduce_scatter_tensor(_fc1, "sum", scatter_dim=0, group=group)
+        _fc1 = fcol.all_gather_tensor(_fc1, 0, group=group)
+        _fc1 = fcol.all_to_all_single(_fc1, None, None, group=[0, 1])
+        _fc1 = _fc1.reshape(2, -1)[0]
+        return _fc1
+
+
+def _test_compile(
+    rank,
+    world_size,
+):
+    backend = "hccl"
+    dist.init_process_group(
+        backend=backend,
+        rank=rank,
+        world_size=world_size
+    )
+
+    graph = None
+
+    def compiler_fn(gm):
+        def inner_compiler(gm_, example_inputs_):
+            nonlocal graph
+            if graph is not None:
+                raise AssertionError('TestCommConverter failed: graph should be None before the first compilation')
+            graph = gm_
+            graph = normalize_gm(graph.print_readable(False))
+            import torchair
+            return torchair.get_npu_backend()(gm_, example_inputs_)
+
+        return torch.compile(
+            gm, backend=inner_compiler, dynamic=False, fullgraph=True
+        )
+
+    torch_npu.npu.set_device(f"npu:{rank}")
+    device = torch.device("npu")
+    torch.manual_seed(123)
+    model = Net().to(device)
+
+    compiled_model = compiler_fn(deepcopy(model))
+    group = torch.distributed.distributed_c10d._get_default_group()
+    ret = []
+    for i in range(3):
+        torch.manual_seed(123 + rank + i)
+        input_tensor = torch.randn([DIM, DIM], device=device)
+        compiled_output = compiled_model(input_tensor, group)
+        loss_output = model(input_tensor, group)
+        expect = """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_ : torch.Tensor):
+        l_x_ = L_x_
+
+        _fc1 = self.L__self___fc1(l_x_); l_x_ = None
+
+        tensor = torch.ops.c10d_functional.all_reduce(_fc1, 'sum', 'ptd:0', [0, 1], 2); _fc1 = None
+
+        _fc1_1 = torch.ops.c10d_functional.wait_tensor(tensor); tensor = None
+
+        tensor_1 = torch.ops.c10d_functional.reduce_scatter_tensor(_fc1_1, 'sum', 'ptd:0', [0, 1], 2); _fc1_1 = None
+
+        _fc1_2 = torch.ops.c10d_functional.wait_tensor(tensor_1); tensor_1 = None
+
+        tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor(_fc1_2, 'ptd:0', [0, 1], 2); _fc1_2 = None
+
+        _fc1_3 = torch.ops.c10d_functional.wait_tensor(tensor_2); tensor_2 = None
+
+        tensor_3 = torch.ops.c10d_functional.all_to_all_single(_fc1_3, None, None, '', [0, 1], 2); _fc1_3 = None
+
+        _fc1_4 = torch.ops.c10d_functional.wait_tensor(tensor_3); tensor_3 = None
+
+        reshape = _fc1_4.reshape(2, -1); _fc1_4 = None
+        _fc1_5 = reshape[0]; reshape = None
+        return (_fc1_5,)
+"""
+        if expect != graph:
+            raise RuntimeError('TestCommConverter failed: traced fx graph does not match the expected graph')
+        if not (compiled_output == loss_output).all():
+            raise RuntimeError('TestCommConverter failed: dynamo outputs are not equal to eager outputs')
+
+
+def mp_main(rank, world_size):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+    os.environ['TORCH_DISABLE_NATIVE_FUNCOL'] = '1'
+    _test_compile(rank=rank, world_size=world_size)
+
+
+class TestCommConverter(TestCase):
+    def test_comm_converter(self):
+        world_size = 2
+        mp.spawn(mp_main, args=(world_size,), nprocs=world_size, join=True)
+
+
+if __name__ == "__main__":
+
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()