diff --git a/test/test_distributed/run_test.py b/test/test_distributed/run_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8032a89b6ec96907d69fbfaa14b668a12ac78e61 --- /dev/null +++ b/test/test_distributed/run_test.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python + +from __future__ import print_function + +import argparse +from datetime import datetime +import modulefinder +import os +import shutil +import signal +import subprocess +import sys +import tempfile +import time +import unittest + +import torch +import torch_npu +import torch._six +from torch.utils import cpp_extension +from torch.testing._internal.common_utils import TEST_WITH_ROCM, shell +import torch.distributed as dist +PY2 = sys.version_info <= (3,) +PY33 = sys.version_info >= (3, 3) +PY36 = sys.version_info >= (3, 6) + +TESTS = [ + 'test_distributed' +] + +# skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn +# skip python2 for rpc and dist_autograd tests that do not support python2 +if PY33: + TESTS.extend([ + 'distributed/rpc/test_rpc_spawn', + 'distributed/rpc/test_dist_autograd_spawn', + 'distributed/rpc/test_dist_optimizer_spawn', + 'distributed/rpc/jit/test_dist_autograd_spawn', + ]) + +# skip < 3.6 b/c fstrings added in 3.6 +if PY36: + TESTS.extend([ + 'test_jit_py3', + 'test_determination', + 'distributed/rpc/jit/test_rpc_spawn', + ]) + +WINDOWS_BLACKLIST = [ + 'distributed/test_distributed', + 'distributed/rpc/test_rpc_spawn', + 'distributed/rpc/test_dist_autograd_spawn', + 'distributed/rpc/test_dist_optimizer_spawn', + 'distributed/rpc/jit/test_rpc_spawn', + 'distributed/rpc/jit/test_dist_autograd_spawn', +] + +ROCM_BLACKLIST = [ + 'test_cpp_extensions_aot_ninja', + 'test_cpp_extensions_jit', + 'test_multiprocessing', + 'distributed/rpc/test_rpc_spawn', + 'distributed/rpc/test_dist_autograd_spawn', + 'distributed/rpc/test_dist_optimizer_spawn', + 'distributed/rpc/jit/test_rpc_spawn', + 'distributed/rpc/jit/test_dist_autograd_spawn', + 'test_determination', +] + +DISTRIBUTED_TESTS_CONFIG = {} + + +if dist.is_available(): + if dist.is_hccl_available(): + DISTRIBUTED_TESTS_CONFIG['hccl'] = { + 'WORLD_SIZE': '2' if torch.npu.device_count() == 2 else '4', + 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-hccl' + } + else: + if not TEST_WITH_ROCM and dist.is_mpi_available(): + DISTRIBUTED_TESTS_CONFIG['mpi'] = { + 'WORLD_SIZE': '3', + 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-mpi' + } + if dist.is_nccl_available(): + DISTRIBUTED_TESTS_CONFIG['nccl'] = { + 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3', + 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-nccl' + } + if not TEST_WITH_ROCM and dist.is_gloo_available(): + DISTRIBUTED_TESTS_CONFIG['gloo'] = { + 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3', + 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-gloo' + } + +# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python +SIGNALS_TO_NAMES_DICT = {getattr(signal, n): n for n in dir(signal) + if n.startswith('SIG') and '_' not in n} + + +def print_to_stderr(message): + print(message, file=sys.stderr) + + +def run_test(executable, test_module, test_directory, options, *extra_unittest_args): + unittest_args = options.additional_unittest_args + if options.verbose: + unittest_args.append('--verbose') + # Can't call `python -m unittest test_*` here because it doesn't run code + # in `if __name__ == '__main__': `. So call `python test_*.py` instead. 
+    argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)
+
+    command = executable + argv
+    return shell(command, test_directory)
+
+def test_distributed_npu(executable, test_module, test_directory, options):
+    config = DISTRIBUTED_TESTS_CONFIG
+    for backend, env_vars in config.items():
+        for with_init_file in {True, False}:
+            tmp_dir = tempfile.mkdtemp()
+            if options.verbose:
+                with_init = ' with file init_method' if with_init_file else ''
+                print_to_stderr(
+                    'Running distributed tests for the {} backend{}'.format(
+                        backend, with_init))
+            os.environ['TEMP_DIR'] = tmp_dir
+            os.environ['BACKEND'] = backend
+            os.environ['INIT_METHOD'] = 'env://'
+            os.environ.update(env_vars)
+            if with_init_file:
+                init_method = 'file://{}/shared_init_file'.format(tmp_dir)
+                os.environ['INIT_METHOD'] = init_method
+            try:
+                os.mkdir(os.path.join(tmp_dir, 'barrier'))
+                os.mkdir(os.path.join(tmp_dir, 'test_dir'))
+                return_code = run_test(executable, test_module, test_directory,
+                                       options)
+                if return_code != 0:
+                    return return_code
+            finally:
+                shutil.rmtree(tmp_dir)
+    return 0
+
+CUSTOM_HANDLERS = {
+    'test_distributed': test_distributed_npu
+}
+
+
+def parse_test_module(test):
+    return test.split('.')[0]
+
+
+class TestChoices(list):
+    def __init__(self, *args, **kwargs):
+        super(TestChoices, self).__init__(args[0])
+
+    def __contains__(self, item):
+        return list.__contains__(self, parse_test_module(item))
+
+FAILURE_FILE_NAME = 'pytorch_org_failures.txt'
+ERROR_FILE_NAME = 'pytorch_org_errors.txt'
+def htmlReport_load_failure_error_cases(file_name):
+    data = []
+    if os.path.isfile(file_name):
+        with open(file_name, 'r') as f:
+            lines = f.readlines()
+            for line in lines:
+                temp = line.strip('\n').strip('\t')
+                data.append(temp)
+    else:
+        print("Invalid filename:", file_name)
+    return data
+
+def htmlReport_analyse_failure_error_cases(result):
+    new_failures = []
+    new_errors = []
+
+    if len(result.failures) > 0:
+        print("====================================== failed cases count: ", len(result.failures))
+        for failure in result.failures:
+            print(failure[0])
+        print("============================================================\n")
+        orig_failures = htmlReport_load_failure_error_cases(FAILURE_FILE_NAME)
+        for failure in result.failures:
+            if str(failure[0]) not in orig_failures:
+                new_failures.append(str(failure[0]))
+
+    if len(result.errors) > 0:
+        print("====================================== error cases count: ", len(result.errors))
+        for error_case in result.errors:
+            print(error_case[0])
+        print("============================================================\n")
+        orig_errors = htmlReport_load_failure_error_cases(ERROR_FILE_NAME)
+        for error_case in result.errors:
+            if str(error_case[0]) not in orig_errors:
+                new_errors.append(str(error_case[0]))
+    print("====================================== new failed cases count: ", len(new_failures))
+    for case in new_failures:
+        print(case)
+    print("====================================== new error cases count: ", len(new_errors))
+    for case in new_errors:
+        print(case)
+    return new_failures, new_errors
+
+def htmlReport_RunTests(suite):
+
+    ENABLE_HTML = bool(os.environ.get('ENABLE_HTML'))
+    ENABLE_HTML_MX = bool(os.environ.get('ENABLE_HTML_MX'))
+    ENABLE_CASE_PATH = os.environ.get('ENABLE_CASE_PATH')
+    ENABLE_OUTPUT_PATH = os.environ.get('ENABLE_OUTPUT_PATH')
+    WHITE_LIST_PATH = os.environ.get('WHITE_LIST_PATH')
+
+    test_case_path = './'
+    if ENABLE_CASE_PATH is not None:
+        if not os.path.exists(ENABLE_CASE_PATH):
+            print('path does not exist: ', ENABLE_CASE_PATH)
+        else:
+            test_case_path = ENABLE_CASE_PATH
+
+    test_report_path = test_case_path+'ReportResult'
+
+    if ENABLE_OUTPUT_PATH is not None:
+        if not os.path.exists(ENABLE_OUTPUT_PATH):
+            print('path does not exist: ', ENABLE_OUTPUT_PATH)
+        else:
+            test_report_path = ENABLE_OUTPUT_PATH
+
+    if not os.path.exists(test_report_path):
+        os.mkdir(test_report_path)
+        print(test_report_path)
+
+    now = time.strftime("%Y_%m_%d_%H_%M_%S")
+    htmlFileName = os.path.join(test_report_path, 'pytorch-unittest-report-'+now+'.html')
+    txtFileName = os.path.join(test_report_path, 'pytorch-unittest-report-'+now+'.txt')
+
+    print('start pytorch HTML unittest testset...')
+    import HTMLTestRunner
+    with os.fdopen(os.open(htmlFileName, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode=0o600), "wb") as report_file:
+        runner = HTMLTestRunner.HTMLTestRunner(
+            stream=report_file, title='AllTest', description='all npu test case', verbosity=2)
+        result = runner.run(suite)
+        new_failures, new_errors = htmlReport_analyse_failure_error_cases(result)
+        if len(new_failures) + len(new_errors) > 0:
+            print(" RuntimeError: new error or failed cases found!")
+    print('report files path', htmlFileName)
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Run the PyTorch unit test suite',
+        epilog='where TESTS is any of: {}'.format(', '.join(TESTS)))
+    parser.add_argument(
+        '--error-continue',
+        action='store_true',
+        help='continue running tests after an error or failure.')
+    parser.add_argument(
+        '--html-test-runner',
+        action='store_true',
+        help='run test cases with HTMLTestRunner.')
+    parser.add_argument(
+        '-v',
+        '--verbose',
+        action='store_true',
+        help='print verbose information and test-by-test results')
+    parser.add_argument(
+        '--jit',
+        action='store_true',
+        help='run all jit tests')
+    parser.add_argument(
+        '-pt', '--pytest', action='store_true',
+        help='If true, use `pytest` to execute the tests. E.g., this runs '
+             'TestTorch with pytest in verbose and coverage mode: '
+             'python run_test.py -vci torch -pt')
+    parser.add_argument(
+        '-c', '--coverage', action='store_true', help='enable coverage')
+    parser.add_argument(
+        '-i',
+        '--include',
+        nargs='+',
+        choices=TestChoices(TESTS),
+        default=TESTS,
+        metavar='TESTS',
+        help='select a set of tests to include (defaults to ALL tests).'
+             ' tests can be specified with module name, module.TestClass'
+             ' or module.TestClass.test_method')
+    parser.add_argument(
+        '-x',
+        '--exclude',
+        nargs='+',
+        choices=TESTS,
+        metavar='TESTS',
+        default=[],
+        help='select a set of tests to exclude')
+    parser.add_argument(
+        '-f',
+        '--first',
+        choices=TESTS,
+        metavar='TESTS',
+        help='select the test to start from (excludes previous tests)')
+    parser.add_argument(
+        '-l',
+        '--last',
+        choices=TESTS,
+        metavar='TESTS',
+        help='select the last test to run (excludes following tests)')
+    parser.add_argument(
+        '--bring-to-front',
+        nargs='+',
+        choices=TestChoices(TESTS),
+        default=[],
+        metavar='TESTS',
+        help='select a set of tests to run first. This can be used in situations'
+             ' where you want to run all tests, but care more about some set, '
+             'e.g. 
after making a change to a specific component') + parser.add_argument( + '--ignore-win-blacklist', + action='store_true', + help='always run blacklisted windows tests') + parser.add_argument( + '--determine-from', + help='File of affected source filenames to determine which tests to run.') + parser.add_argument( + 'additional_unittest_args', + nargs='*', + help='additional arguments passed through to unittest, e.g., ' + 'python run_test.py -i sparse -- TestSparse.test_factory_size_check') + return parser.parse_args() + + +def get_executable_command(options): + if options.coverage: + executable = ['coverage', 'run', '--parallel-mode', '--source torch'] + else: + executable = [sys.executable] + if options.pytest: + executable += ['-m', 'pytest'] + return executable + + +def find_test_index(test, selected_tests, find_last_index=False): + """Find the index of the first or last occurrence of a given test/test module in the list of selected tests. + + This function is used to determine the indices when slicing the list of selected tests when + ``options.first``(:attr:`find_last_index`=False) and/or ``options.last``(:attr:`find_last_index`=True) are used. + + :attr:`selected_tests` can be a list that contains multiple consequent occurrences of tests + as part of the same test module, e.g.: + + ``` + selected_tests = ['autograd', 'cuda', **'torch.TestTorch.test_acos', + 'torch.TestTorch.test_tan', 'torch.TestTorch.test_add'**, 'utils'] + ``` + + If :attr:`test`='torch' and :attr:`find_last_index`=False, result should be **2**. + If :attr:`test`='torch' and :attr:`find_last_index`=True, result should be **4**. + + Arguments: + test (str): Name of test to lookup + selected_tests (list): List of tests + find_last_index (bool, optional): should we lookup the index of first or last + occurrence (first is default) + + Returns: + index of the first or last occurrence of the given test + """ + idx = 0 + found_idx = -1 + for t in selected_tests: + if t.startswith(test): + found_idx = idx + if not find_last_index: + break + idx += 1 + return found_idx + + +def exclude_tests(exclude_list, selected_tests, exclude_message=None): + for exclude_test in exclude_list: + tests_copy = selected_tests[:] + for test in tests_copy: + if test.startswith(exclude_test): + if exclude_message is not None: + print_to_stderr('Excluding {} {}'.format(test, exclude_message)) + selected_tests.remove(test) + return selected_tests + + +def get_selected_tests(options): + selected_tests = options.include + + if options.bring_to_front: + to_front = set(options.bring_to_front) + selected_tests = options.bring_to_front + list(filter(lambda name: name not in to_front, + selected_tests)) + + if options.first: + first_index = find_test_index(options.first, selected_tests) + selected_tests = selected_tests[first_index:] + + if options.last: + last_index = find_test_index(options.last, selected_tests, find_last_index=True) + selected_tests = selected_tests[:last_index + 1] + + selected_tests = exclude_tests(options.exclude, selected_tests) + + if sys.platform == 'win32' and not options.ignore_win_blacklist: + target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH') + if target_arch != 'x64': + WINDOWS_BLACKLIST.append('cpp_extensions_aot_no_ninja') + WINDOWS_BLACKLIST.append('cpp_extensions_aot_ninja') + WINDOWS_BLACKLIST.append('cpp_extensions_jit') + WINDOWS_BLACKLIST.append('jit') + WINDOWS_BLACKLIST.append('jit_fuser') + + selected_tests = exclude_tests(WINDOWS_BLACKLIST, selected_tests, 'on Windows') + + elif TEST_WITH_ROCM: + selected_tests 
= exclude_tests(ROCM_BLACKLIST, selected_tests, 'on ROCm') + + return selected_tests + +def main(): + options = parse_args() + executable = get_executable_command(options) # this is a list + print_to_stderr('Test executor: {}'.format(executable)) + test_directory = os.path.dirname(os.path.abspath(__file__)) + selected_tests = get_selected_tests(options) + + if options.verbose: + print_to_stderr('Selected tests: {}'.format(', '.join(selected_tests))) + + if options.coverage: + shell(['coverage', 'erase']) + + if options.jit: + selected_tests = filter(lambda test_name: "jit" in test_name, TESTS) + + if options.determine_from is not None and os.path.exists(options.determine_from): + pass + + htmlReport_suite = unittest.TestSuite() + htmlReport_loader = unittest.TestLoader() + + for test in selected_tests: + + test_module = parse_test_module(test) + + # Printing the date here can help diagnose which tests are slow + print_to_stderr('Running {} ... [{}]'.format(test, datetime.now())) + handler = CUSTOM_HANDLERS.get(test, run_test) + if options.html_test_runner: + testfileName = test_module + '.py' + testCase = unittest.defaultTestLoader.discover("./", pattern=testfileName) + + rtn = htmlReport_suite.addTest(testCase) + else: + return_code = handler(executable, test_module, test_directory, options) + assert isinstance(return_code, int) and not isinstance( + return_code, bool), 'Return code should be an integer' + if return_code != 0: + message = '{} failed!'.format(test) + if return_code < 0: + # subprocess.Popen returns the child process' exit signal as + # return code -N, where N is the signal number. + signal_name = SIGNALS_TO_NAMES_DICT[-return_code] + message += ' Received signal: {}'.format(signal_name) + if not options.error_continue: + raise RuntimeError(message) + if options.html_test_runner: + htmlReport_RunTests(htmlReport_suite) + if options.coverage: + shell(['coverage', 'combine']) + shell(['coverage', 'html']) + + +if __name__ == '__main__': + main() diff --git a/test/test_distributed/test_distributed.py b/test/test_distributed/test_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..2de8c5214f12c8e88d8baecc95eacf5ea3ad3d94 --- /dev/null +++ b/test/test_distributed/test_distributed.py @@ -0,0 +1,775 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import copy +import fcntl +import os +import sys +import time +import tempfile +import unittest +import logging +import traceback +from contextlib import contextmanager +from datetime import timedelta +from functools import reduce, wraps +import types +from collections import namedtuple +from multiprocessing import Manager + +import torch +import torch_npu +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.testing._internal.common_utils import TestCase, run_tests +from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR +from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT + +try: + import torchvision + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False + +skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + +BACKEND = os.environ["BACKEND"] +TEMP_DIR = os.environ["TEMP_DIR"] +INIT_METHOD = os.getenv("INIT_METHOD", "env://") + +DEFAULT_TIMEOUT = 300 +CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500} + +TestSkip = namedtuple('TestSkip', 'exit_code, message') +TEST_SKIPS = { + "multi-npu": TestSkip(75, "Need at least 2 ASCEND devices"), + "hccl": 
TestSkip(76, "c10d not compiled with HCCL support"),
+    # "backend_unavailable" is referenced by TestDistBackend._run when
+    # init_process_group fails; the exit code below is an assumption chosen to
+    # mirror the upstream PyTorch skip table, since the original dict omitted
+    # this key.
+    "backend_unavailable": TestSkip(72, "backend is unavailable"),
+    "known_issues": TestSkip(77, "Test skipped due to known issues"),
+}
+
+class _FC2(nn.Module):
+    def __init__(self):
+        super(_FC2, self).__init__()
+        self.fc = nn.Linear(10, 50, bias=True)
+        self.fc.bias.requires_grad = False
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False)
+        self.fc2 = _FC2()
+        self.fc3 = nn.Linear(50, 4, bias=False)
+        self.relu = nn.ReLU()
+        self.no_grad_param = nn.Parameter(torch.tensor([2, 2]).long(),
+                                          requires_grad=False)
+
+    def forward(self, x):
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+        return F.softmax(x, dim=1)
+
+
+class BatchNormNet(nn.Module):
+
+    def __init__(self):
+        super(BatchNormNet, self).__init__()
+        self.fc1 = nn.Linear(2, 40, bias=False)
+        self.bn = nn.BatchNorm1d(4)
+        self.fc2 = nn.Linear(40, 4, bias=False)
+
+    def forward(self, x):
+        x = torch.reshape(self.fc1(x), (-1, 4, 10))
+        x = self.bn(x)
+        x = torch.reshape(x, (-1, 40))
+        x = self.fc2(x)
+        return F.softmax(x, dim=1)
+
+DDP_NET = Net()
+BN_NET = BatchNormNet()
+ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
+
+def get_timeout(test_id):
+    test_name = test_id.split(".")[-1]
+    if test_name in CUSTOMIZED_TIMEOUT:
+        return CUSTOMIZED_TIMEOUT[test_name]
+    else:
+        return DEFAULT_TIMEOUT
+
+
+if not dist.is_available():
+    print("Distributed not available, skipping tests")
+    sys.exit(0)
+
+@contextmanager
+def _lock():
+    lockfile = os.path.join(TEMP_DIR, "lockfile")
+    with os.fdopen(os.open(lockfile, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode=0o600), "w") as lf:
+        try:
+            fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
+            yield
+        finally:
+            fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
+            lf.close()
+
+class Barrier(object):
+    barrier_id = 0
+
+    @classmethod
+    def init(cls):
+        cls.barrier_id = 0
+        barrier_dir = os.path.join(TEMP_DIR, "barrier")
+        for f_name in os.listdir(barrier_dir):
+            os.unlink(os.path.join(barrier_dir, f_name))
+
+    @classmethod
+    def sync(cls, wait_for=None, timeout=5):
+        if wait_for is None:
+            wait_for = dist.get_world_size()
+        cls.barrier_id += 1
+        barrier_dir = os.path.join(TEMP_DIR, "barrier")
+        pid = str(os.getpid())
+        barrier_file = os.path.join(barrier_dir, pid)
+        with _lock():
+            with os.fdopen(os.open(barrier_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode=0o600), "w") as f:
+                f.write(str(cls.barrier_id))
+
+        start_time = time.time()
+        while True:
+            arrived = 0
+            with _lock():
+                for f_name in os.listdir(barrier_dir):
+                    with open(os.path.join(barrier_dir, f_name), "r") as f:
+                        data = f.read()
+                        if int(data) >= cls.barrier_id:
+                            arrived += 1
+            if arrived == wait_for:
+                break
+
+            if time.time() - start_time > timeout:
+                raise RuntimeError("barrier timeout")
+            time.sleep(0.1)
+
+# [How does MultiProcessTestCase work?]
+# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
+# default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an
+# example which inherits from this class. Its `setUp()` method calls into
+# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()`
+# subprocesses. During the spawn, the main process passes the test name to
+# subprocesses, and the name is acquired from self.id(). The subprocesses
+# then use the provided test function name to retrieve the function attribute
+# from the test instance and run it. The main process simply waits for all
+# subprocesses to join.
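+#
+# A minimal usage sketch (illustrative only; the class and method names below
+# are hypothetical, not part of this patch): a subclass spawns its workers in
+# setUp(), and each test method then runs once inside every spawned rank.
+#
+#     class TestAllReduceExample(MultiProcessTestCase):
+#         def setUp(self):
+#             super().setUp()
+#             self._spawn_processes()   # one subprocess per rank in world_size
+#
+#         def test_something(self):
+#             # executes inside each worker; the worker's rank is self.rank
+#             pass
+#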
+class MultiProcessTestCase(TestCase): + MAIN_PROCESS_RANK = -1 + # This exit code is used to indicate that the test code had an error and + # exited abnormally. There are certain tests that might use sys.exit() to + # simulate failures and in those cases, we can't have an exit code of 0, + # but we still want to ensure we didn't run into any other errors. + TEST_ERROR_EXIT_CODE = 10 + + @property + def world_size(self): + return 4 + + def join_or_run(self, fn): + @wraps(fn) + def wrapper(self): + if self.rank == self.MAIN_PROCESS_RANK: + self._join_processes(fn) + else: + try: + fn() + except Exception as e: + logging.error('Caught exception: \n{}exiting process with exit code: {}' + .format(traceback.format_exc(), MultiProcessTestCase.TEST_ERROR_EXIT_CODE)) + sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE) + return types.MethodType(wrapper, self) + + # The main process spawns N subprocesses that run the test. + # Constructor patches current instance test method to + # assume the role of the main process and join its subprocesses, + # or run the underlying test function. + def __init__(self, method_name='runTest'): + super().__init__(method_name) + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) + + def setUp(self): + super().setUp() + self.skip_return_code_checks = [] + self.processes = [] + self.rank = self.MAIN_PROCESS_RANK + self.file_name = tempfile.NamedTemporaryFile(delete=False).name + global TEST_SKIPS + self.old_test_skips = TEST_SKIPS.copy() + + def tearDown(self): + super().tearDown() + for p in self.processes: + p.terminate() + # Each Process instance holds a few open file descriptors. The unittest + # runner creates a new TestCase instance for each test method and keeps + # it alive until the end of the entire suite. We must thus reset the + # processes to prevent an effective file descriptor leak. + self.processes = [] + + def _current_test_name(self): + # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' + return self.id().split(".")[-1] + + def _start_processes(self, proc): + test_skips_manager = Manager() + test_skips = test_skips_manager.dict() + global TEST_SKIPS + test_skips.update(TEST_SKIPS) + TEST_SKIPS = test_skips + + self.processes = [] + for rank in range(int(self.world_size)): + process = proc( + target=self.__class__._run, + name='process ' + str(rank), + args=(rank, self._current_test_name(), self.file_name)) + process.start() + self.processes.append(process) + + def _fork_processes(self): + proc = torch.multiprocessing.get_context("fork").Process + self._start_processes(proc) + + def _spawn_processes(self): + proc = torch.multiprocessing.get_context("spawn").Process + self._start_processes(proc) + + @classmethod + def _run(cls, rank, test_name, file_name): + self = cls(test_name) + self.rank = rank + self.file_name = file_name + + # self.id() == e.g. '__main__.TestDistributed.test_get_rank' + # We're retrieving a corresponding test and executing it. + getattr(self, test_name)() + # exit to avoid run teardown() for fork processes + sys.exit(0) + + def _join_processes(self, fn): + timeout = get_timeout(self.id()) + start_time = time.time() + subprocess_error = False + try: + while True: + # check to see if any subprocess exited with an error early. + for (i, p) in enumerate(self.processes): + # This is the exit code processes exit with if they + # encountered an exception. 
+ if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE: + print("Process {} terminated with exit code {}, terminating remaining processes.".format( + i, p.exitcode)) + active_children = torch.multiprocessing.active_children() + for ac in active_children: + ac.terminate() + subprocess_error = True + break + if subprocess_error: + break + # All processes have joined cleanly if they all a valid exitcode + if all([p.exitcode is not None for p in self.processes]): + break + # Check if we should time out the test. If so, we terminate each process. + elapsed = time.time() - start_time + if elapsed > timeout: + print( + "Timing out after {} seconds and killing subprocesses.".format( + timeout + ) + ) + for p in self.processes: + p.terminate() + break + # Sleep to avoid excessive busy polling. + time.sleep(0.1) + elapsed_time = time.time() - start_time + if fn in self.skip_return_code_checks: + self._check_no_test_errors(elapsed_time) + else: + self._check_return_codes(elapsed_time) + finally: + global TEST_SKIPS + TEST_SKIPS = self.old_test_skips + + def _check_no_test_errors(self, elapsed_time): + """ + Checks that we didn't have any errors thrown in the child processes. + """ + for i, p in enumerate(self.processes): + if p.exitcode is None: + raise RuntimeError('Process {} timed out after {} seconds'.format(i, elapsed_time)) + self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode) + + def _check_return_codes(self, elapsed_time): + """ + Checks that the return codes of all spawned processes match, and skips + tests if they returned a return code indicating a skipping condition. + """ + first_process = self.processes[0] + # first, we check if there are errors in actual processes + # (via TEST_ERROR_EXIT CODE), and raise an exception for those. + # the reason we do this is to attempt to raise a more helpful error + # message than "Process x terminated/timed out" + # TODO: we should pipe the exception of the failed subprocess here. + # Currently, the actual exception is displayed as a logging output. + errored_processes = [ + (i, p) + for i, p in enumerate(self.processes) + if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE + ] + if errored_processes: + error = "Processes {} exited with error code {}".format( + " ".join([str(i) for (i, _) in errored_processes]), + MultiProcessTestCase.TEST_ERROR_EXIT_CODE, + ) + raise RuntimeError(error) + # If no process exited uncleanly, we check for timeouts, and then ensure + # each process exited cleanly. 
+ for i, p in enumerate(self.processes): + if p.exitcode is None: + raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time)) + self.assertEqual( + p.exitcode, + first_process.exitcode + ) + for skip in TEST_SKIPS.values(): + if first_process.exitcode == skip.exit_code: + raise unittest.SkipTest(skip.message) + self.assertEqual( + first_process.exitcode, + 0 + ) + + @property + def is_master(self): + return self.rank == 0 + +class _DistTestBase(object): + def _barrier(self, *args, **kwargs): + Barrier.sync(*args, **kwargs) + + def _init_group_test(self, **kwargs): + group = [1, 2] + group_id = dist.new_group(group, **kwargs) + rank = dist.get_rank() + if rank not in group: + return ([], None, rank) + + return (group, group_id, rank) + + def _init_full_group_test(self, **kwargs): + group = list(range(0, dist.get_world_size())) + group_id = dist.new_group(**kwargs) + rank = dist.get_rank() + return (group, group_id, rank) + + def _init_global_test(self): + group = list(range(0, dist.get_world_size())) + group_id = dist.group.WORLD + rank = dist.get_rank() + return (group, group_id, rank) + + # HELPER FOR MULTINPU TESTS + def _init_multinpu_helper(self): + """Multinpu tests are designed to simulate the multi nodes with multi + NPUs on each node. Hccl backend requires one NPU device in each process. + """ + nNPUs = torch.npu.device_count() + world_size = dist.get_world_size() + visible_devices = range(min(nNPUs, world_size)) + + nNPUs_per_process = 1 + rank_to_NPU = { + i: list( + visible_devices[i * nNPUs_per_process: (i + 1) * nNPUs_per_process] + ) + for i in range(world_size) + } + return rank_to_NPU + + def _model_step(self, model): + for param in model.parameters(): + if param.grad is not None: + with torch.no_grad(): + param += param.grad + param.grad = None + + def _prepare_dummy_data(self, local_bs): + # global_bs for DDP should be divisible by WORLD_SIZE + global_bs = int(dist.get_world_size()) * local_bs + input_cpu = torch.randn(global_bs, 2) + target = torch.randn(global_bs, 4) + loss = nn.MSELoss() + return global_bs, input_cpu, target, loss + + # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL + def _test_DDP_helper(self, model, input_var, target, loss, scale_factor=1.0): + model.train() + output = model(input_var) + l = loss(output, target) * scale_factor + l.backward() + + def _assert_equal_param(self, param_npu, param_DDP): + self.assertEqual(len(param_npu), len(param_DDP)) + for p_npu, p_DDP in zip(param_npu, param_DDP): + self.assertEqual(p_npu, p_DDP) + + def _test_DDP_5iter( + self, model_base, model_DDP, input_data, target, loss, local_bs, rank, batch_size, test_save, \ + offset=None, world_size=0 + ): + for idx in range(5): + # single cpu/npu training + self._test_DDP_helper(model_base, input_data, target, loss) + + if offset is None: + offset = rank * local_bs + + # DDP training, DDP scatters subsets of input_cpu to nodes/NPUs + self._test_DDP_helper( + model_DDP, + input_data[offset: offset + local_bs], + target[offset: offset + local_bs], + loss, + world_size * local_bs / batch_size if world_size != 0 else 1, + ) + + # Update weights and run a second iteration to shake out errors + self._model_step(model_base) + self._model_step(model_DDP) + self._assert_equal_param( + list(model_base.parameters()), list(model_DDP.module.parameters()) + ) + + # Shuffle the input so that DDP input is different + input_data = input_data[torch.randperm(batch_size)] + + # save the model in the middle and reload + if test_save and idx == 2 and 
INIT_METHOD.startswith("file://"): + with tempfile.NamedTemporaryFile() as tmp: + state = {'net': model_DDP.state_dict()} + torch.save(state, tmp.name) + checkpoint = torch.load(tmp.name) + model_DDP.load_state_dict(checkpoint['net']) + + with tempfile.TemporaryFile() as tmp_file: + state = {'net': model_DDP.state_dict()} + torch.save(state, tmp_file) + tmp_file.seek(0) + checkpoint = torch.load(tmp_file) + saved_model = copy.deepcopy(model_DDP) + saved_model.load_state_dict(checkpoint['net']) + for k in model_DDP.state_dict(): + self.assertEqual(model_DDP.state_dict()[k], + saved_model.state_dict()[k]) + + def _test_DistributedDataParallel(self, npu_subset, rank, output_device=None): + # Run a simple end to end DDP model, use result of single node model + # as baseline + + # cpu training setup + model = DDP_NET + + # single npu training setup + model_npu = copy.deepcopy(model) + model_npu.npu(npu_subset[0]) + + # DDP training setup + model_DDP = copy.deepcopy(model) + model_DDP.npu(npu_subset[0]) + + model_DDP = nn.parallel.DistributedDataParallel( + model_DDP, device_ids=npu_subset + ) + + # test serializable/unserializable + with tempfile.NamedTemporaryFile() as tmp: + state = {'net': model_DDP.state_dict()} + torch.save(state, tmp.name) + checkpoint = torch.load(tmp.name) + model_DDP.load_state_dict(checkpoint['net']) + + # dummy data initialization + local_bs = len(npu_subset) + global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs) + + # check two model parameters over 5 iterations + self._test_DDP_5iter( + model_npu, + model_DDP, + input_cpu.npu(npu_subset[0]), + target.npu(npu_subset[0]), + loss, + local_bs, + rank, + global_bs, + True + ) + self._barrier() + + def test_DistributedDataParallel_requires_grad(self): + # a module without gradients shouldn't be accepted + self.assertRaises(AssertionError, lambda: nn.parallel.DistributedDataParallel(nn.Module())) + + def test_DistributedDataParallel(self): + group, group_id, rank = self._init_global_test() + rank_to_NPU = self._init_multinpu_helper() + npus = list(rank_to_NPU[rank]) + self._test_DistributedDataParallel(npu_subset=npus, rank=rank) + + def _test_DistributedDataParallel_SyncBatchNorm( + self, npu_subset, rank, local_bs, global_bs, offset, output_device=None): + # Run a simple end to end DDP model, use result of single node model + # as baseline + + # cpu training setup + model = BN_NET + + # single npu training setup + model_npu = copy.deepcopy(model) + model_npu.npu(npu_subset[0]) + + # DDP training setup + model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model)) + model_DDP.npu(npu_subset[0]) + model_DDP = nn.parallel.DistributedDataParallel( + model_DDP, device_ids=npu_subset + ) + + # test serializable/unserializable + with tempfile.NamedTemporaryFile() as tmp: + state = {'net': model_DDP.state_dict()} + torch.save(state, tmp.name) + checkpoint = torch.load(tmp.name) + model_DDP.load_state_dict(checkpoint['net']) + + # data initialization + input_cpu = torch.randn(global_bs, 2) + target = torch.randn(global_bs, 4) + loss = nn.MSELoss() + + # check two model parameters over 5 iterations + self._test_DDP_5iter( + model_npu, + model_DDP, + input_cpu.npu(npu_subset[0]), + target.npu(npu_subset[0]), + loss, + local_bs, + rank, + global_bs, + True, + offset, + int(WORLD_SIZE) + ) + self._barrier() + + def test_DistributedDataParallel_SyncBatchNorm(self): + group, group_id, rank = self._init_global_test() + # DDP does not support replicating BN layers within a process, hence + # testing with 
one module replica per process + npus = [rank] + + num_processes = int(WORLD_SIZE) + local_bs = 2 + bs_offset = int(rank * 2) + global_bs = int(num_processes * 2) + + self._test_DistributedDataParallel_SyncBatchNorm( + npu_subset=npus, + rank=rank, + local_bs=local_bs, + global_bs=global_bs, + offset=bs_offset) + + def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self): + group, group_id, rank = self._init_global_test() + # DDP does not support replicating BN layers within a process, hence + # testing with one module replica per process + npus = [rank] + + model = nn.BatchNorm1d(2) + + # single npu training setup + model_npu = copy.deepcopy(model) + model_npu.npu(npus[0]) + + # DDP training setup + model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model)) + model_DDP.npu(npus[0]) + model_DDP = nn.parallel.DistributedDataParallel( + model_DDP, device_ids=npus + ) + + local_bs = len(npus) * 2 + global_bs = int(WORLD_SIZE) * local_bs + input_cpu = torch.randn(global_bs, 2) + target = torch.randn(global_bs, 2) + loss = nn.MSELoss() + + # check two model parameters over 5 iterations + self._test_DDP_5iter( + model_npu, + model_DDP, + input_cpu.npu(npus[0]), + target.npu(npus[0]), + loss, + local_bs, + rank, + global_bs, + True + ) + self._barrier() + + def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(self): + group, group_id, rank = self._init_global_test() + model = nn.parallel.DistributedDataParallel(ONLY_SBN_NET.npu(rank), device_ids=[rank]) + + input_var = [] + for i in range(int(WORLD_SIZE)): + input_var_rank = torch.cat([ + torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)), + torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)) + ], dim=1) + input_var.append(input_var_rank) + + all_input_var = torch.cat( + [x.permute(1, 0, 2).contiguous().view(ONLY_SBN_NET.num_features, -1) for x in input_var], + dim=1 + ).npu(rank) + + for i in range(5): + y = model(input_var[rank].npu(rank)) + y.mean().backward() + + running_mean, running_var = model.module.running_mean, model.module.running_var + torch.testing.assert_allclose(running_mean, all_input_var.mean(1)) + torch.testing.assert_allclose(running_var.cpu(), all_input_var.cpu().var(1, unbiased=False)) + + def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self): + group, group_id, rank = self._init_global_test() + # only do single NPU per process + npus = [rank] + + num_processes = int(WORLD_SIZE) + local_bs = rank + 2 + bs_offset = int((rank + 3) * rank / 2) + global_bs = int((num_processes + 3) * num_processes / 2) + + self._test_DistributedDataParallel_SyncBatchNorm( + npu_subset=npus, + rank=rank, + local_bs=local_bs, + global_bs=global_bs, + offset=bs_offset) + + @skipIfNoTorchVision + def test_SyncBatchNorm_process_group(self): + # When adopting `convert_sync_batchnorm` to convert a `nn.modules`, + # it need to recursively pass the `process_group` in the module when the `SyncBatchNorm` + # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models). 
+ + process_ids = 0 + process_group = torch.distributed.new_group([process_ids]) + res50_model = torchvision.models.resnet50() + res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(res50_model), process_group) + process_group_sync = res50_model_sync.layer1[0].bn1.process_group + self.assertEqual(process_group_sync, process_group) + +FILE_SCHEMA = "file://" +tmp_dir = None +def initialize_temp_directories(init_method=None): + global tmp_dir + tmp_dir = tempfile.TemporaryDirectory() + os.environ["TEMP_DIR"] = tmp_dir.name + os.mkdir(os.path.join(tmp_dir.name, "barrier")) + os.mkdir(os.path.join(tmp_dir.name, "test_dir")) + init_dir_path = os.path.join(tmp_dir.name, "init_dir") + os.mkdir(init_dir_path) + +def cleanup_temp_dir(): + if tmp_dir is not None: + tmp_dir.cleanup() + +WORLD_SIZE = os.environ["WORLD_SIZE"] + +class TestDistBackend(MultiProcessTestCase, _DistTestBase): + @classmethod + def setUpClass(cls): + os.environ["MASTER_ADDR"] = str(MASTER_ADDR) + os.environ["MASTER_PORT"] = str(MASTER_PORT) + super().setUpClass() + + def setUp(self): + super().setUp() + # initialize temp directories + initialize_temp_directories() + # initialize Barrier + Barrier.init() + self._spawn_processes() + + def tearDown(self): + cleanup_temp_dir() + super().tearDown() + + @property + def init_method(self): + return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name) + + @classmethod + def _run(cls, rank, test_name, file_name): + self = cls(test_name) + self.rank = rank + self.file_name = file_name + + torch.npu.set_device(rank) + + if torch.npu.is_available() and torch.npu.device_count() < int(self.world_size): + sys.exit(TEST_SKIPS['multi-npu'].exit_code) + try: + timeout = timedelta(seconds=60) + dist.init_process_group( + init_method=INIT_METHOD, + backend=BACKEND, + world_size=int(self.world_size), + rank=self.rank, + timeout=timeout, + ) + except RuntimeError as e: + if "recompile" in e.args[0]: + sys.exit(TEST_SKIPS["backend_unavailable"].exit_code) + + raise + + # Execute barrier prior to running test to ensure that every process + # has finished initialization and that the following test + # immediately exiting due to a skip doesn't cause flakiness. + self._barrier() + + # self.id() == e.g. '__main__.TestDistributed.test_get_rank' + # We're retreiving a corresponding test and executing it. + getattr(self, test_name)() + self._barrier() + dist.destroy_process_group() + sys.exit(0) + + # Needed since MultiProcessTestCase assumes a world_size of 4, but we + # run these tests under other various world_sizes. 
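+    # WORLD_SIZE is exported by run_test.py (DISTRIBUTED_TESTS_CONFIG sets it
+    # to '2' or '4' for the hccl backend), so the value returned here is a
+    # string; callers such as _start_processes() and _run() cast it with int().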
+ @property + def world_size(self): + return os.environ["WORLD_SIZE"] + +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/normalization/BatchNormGatherStatsWithCountsKernelNpu.cpp b/torch_npu/csrc/aten/ops/normalization/BatchNormGatherStatsWithCountsKernelNpu.cpp index 663520155a4b80d595c15d245b6422695177c9dc..846f7190fa93f5e0b7aa2ba9033ba495bb2b3516 100644 --- a/torch_npu/csrc/aten/ops/normalization/BatchNormGatherStatsWithCountsKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/normalization/BatchNormGatherStatsWithCountsKernelNpu.cpp @@ -43,7 +43,7 @@ std::tuple batch_norm_gather_stats_with_counts_npu_imp at::IntArrayRef axes({0}); at::Tensor countsTensor; countsTensor = NPUNativeFunctions::npu_dtype_cast(counts, meanCp.scalar_type()); - at::Tensor countsTensorT = NPUNativeFunctions::npu_transpose(countsTensor.unsqueeze(-1), {0, 1}); + at::Tensor countsTensorT = countsTensor.unsqueeze(-1); at::Tensor countsTensorBroadcast = NPUNativeFunctions::npu_broadcast(countsTensorT, invstd.sizes()); at::Tensor countsAllSum = OpPreparation::ApplyTensorWithSizes({1, dimC}, meanCp.options()); OpCommand cmd1; diff --git a/torch_npu/utils/module.py b/torch_npu/utils/module.py index d853435089d92d7629812438efe6d3623cc2ebd4..063ee5f70154f818a14f1516836f3806b30c1c84 100644 --- a/torch_npu/utils/module.py +++ b/torch_npu/utils/module.py @@ -18,7 +18,7 @@ import warnings import logging import torch import torch_npu - +from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm def npu(self, device=None): r"""Moves all model parameters and buffers to the npu. @@ -270,6 +270,70 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_le return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices] return padded_output, lengths +def syncbn_forward(self, input1: torch.Tensor) -> torch.Tensor: + # currently only NPU or GPU input is supported + if (not input1.is_cuda) and (not input1.is_npu): + raise ValueError('SyncBatchNorm expected input tensor to be on NPU or GPU') + + self._check_input_dim(input1) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None + self.num_batches_tracked = self.num_batches_tracked + 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / self.num_batches_tracked.item() + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). 
+ """ + # If buffers are not to be tracked, ensure that they won't be updated + assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) + assert self.running_var is None or isinstance(self.running_var, torch.Tensor) + running_mean = self.running_mean if not self.training or self.track_running_stats else None + running_var = self.running_var if not self.training or self.track_running_stats else None + + need_sync = bn_training + if need_sync: + process_group = torch.distributed.group.WORLD + if self.process_group: + process_group = self.process_group + world_size = torch.distributed.get_world_size(process_group) + need_sync = world_size > 1 + + # fallback to framework BN when synchronization is not necessary + if not need_sync: + return F.batch_norm( + input1, running_mean, running_var, self.weight, self.bias, + bn_training, exponential_average_factor, self.eps) + else: + if not self.ddp_gpu_size: + raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel') + + assert bn_training + return sync_batch_norm.apply( + input1, self.weight, self.bias, running_mean, running_var, + self.eps, exponential_average_factor, process_group, world_size) def apply_module_patch(): torch.nn.Module.npu = npu @@ -279,3 +343,4 @@ def apply_module_patch(): torch.nn.parallel.DistributedDataParallel.forward = ddp_forward torch.nn.modules.rnn.LSTM.forward = lstm_forward torch.nn.utils.rnn.pad_packed_sequence = pad_packed_sequence + torch.nn.modules.batchnorm.SyncBatchNorm.forward = syncbn_forward
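+    # Note: syncbn_forward calls F.batch_norm for the single-process fallback,
+    # so this patch assumes torch_npu/utils/module.py already imports
+    # torch.nn.functional as F (the import is not visible in this hunk's context).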