From 3998eb46f1cf8a466c7ecfac658355b2571abdac Mon Sep 17 00:00:00 2001 From: wangqianren Date: Sat, 10 Jun 2023 11:00:36 +0800 Subject: [PATCH 1/3] add tensor.storage.resize_ API and fix bugs in _npu_storage_resize_only, npu_storage_resize, NpuStorage.resize_. fix bugs in _npu_storage_resize_only fix bugs in NpuStorage.resize_() fix bugs in npu_storage_resize --- torch_npu/csrc/aten/common/StorageSizeNpu.cpp | 71 +++++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 1 + torch_npu/utils/tensor_methods.py | 11 +-- 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/torch_npu/csrc/aten/common/StorageSizeNpu.cpp b/torch_npu/csrc/aten/common/StorageSizeNpu.cpp index 37462cf551..f35262fa3a 100644 --- a/torch_npu/csrc/aten/common/StorageSizeNpu.cpp +++ b/torch_npu/csrc/aten/common/StorageSizeNpu.cpp @@ -14,9 +14,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/StorageDescHelper.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "torch_npu/csrc/core/NPUBridge.h" +#include "torch_npu/csrc/core/NPUStorageImpl.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/core/npu/NPUFunctions.h" namespace at_npu { namespace native { @@ -30,5 +35,71 @@ namespace native { return n; } + static void _npu_storage_resize_only(torch_npu::NPUStorageImpl& storage, ptrdiff_t size) + { + if (!storage.resizable()) { + AT_ERROR("Trying to resize storage that is not resizable"); + return; + } + auto storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_; + size_t itemsize = storage_desc.data_type_.itemsize(); + + at::DataPtr new_data; + new_data = storage.allocator()->allocate(size); + at::DataPtr old_data = storage.set_data_ptr(std::move(new_data)); + ptrdiff_t old_size = storage.nbytes(); + storage.set_nbytes(size); + + if (itemsize == 0) { + AT_ERROR("When resizing, item size of storage cannot be zero."); + return; + } + if ((size % itemsize) != 0) { + AT_ERROR("The specified storage nbytes cannot be divided by item size.", + "Please check the input parameter size."); + return; + } + std::vector resize_shape = {size/itemsize}; + // It is necessary to properly refresh the storage according to sizes and strides, + // not just new sizes. 
+ at_npu::native::StorageDescHelper::UpdateDesc( + torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_, resize_shape, resize_shape); + + if (old_data != nullptr) { + ptrdiff_t copy_size = old_size; + if (storage.nbytes() < copy_size) { + copy_size = storage.nbytes(); + } + if (copy_size > 0) { + aclError error = at_npu::native::CalcuOpUtil::LaunchAsyncCopyTaskWithModeSwitch( + storage, + copy_size, + old_data.get(), + copy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); + if (error != ACL_ERROR_NONE) { + AT_ERROR("ACL_Memcpy device to device error."); + return; + } + } + } + } + + static void _maybe_npu_storage_resize(at::TensorImpl* self, ptrdiff_t size) + { + if (!self->storage().unsafeGetStorageImpl()){ + AT_ERROR("Try to resize a tensor with null storage"); + return; + } + _npu_storage_resize_only(*torch_npu::NPUBridge::GetNpuStorageImpl(self->storage().unsafeGetStorageImpl()), size); + } + + at::Tensor NPUNativeFunctions::npu_storage_resize(const at::Tensor& self, int64_t size){ + int64_t new_size_bytes = (size + self.storage_offset()) * self.dtype().itemsize(); + auto* self_impl = self.unsafeGetTensorImpl(); + _maybe_npu_storage_resize(self_impl, new_size_bytes); + return self; + } + } // namespace native } // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index ce4b4073e7..d3900830d1 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1407,6 +1407,7 @@ unsupported: - special_zeta.other_scalar_out custom: + - func: npu_storage_resize(Tensor self, int size) -> Tensor - func: npu_change_data_ptr(Tensor dst, Tensor src, int index) -> int - func: npu_transpose(Tensor self, int[] perm, bool require_contiguous=True) -> Tensor - func: npu_transpose.out(Tensor self, int[] perm, bool require_contiguous=True, *, Tensor(a!) out) -> Tensor(a!) 
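For reference, a minimal usage sketch of the storage-resize API registered above and wired up in torch_npu/utils/tensor_methods.py below. This is illustrative only: it assumes an NPU device is available and that torch.Tensor.storage has already been patched (via add_tensor_methods) to return NpuStorage for NPU tensors, as this patch series does.

import torch
import torch_npu

x = torch.zeros(4).npu()   # tensor backed by NPU storage
s = x.storage()            # patched storage() returns NpuStorage for NPU tensors
print(s.size())            # current storage size, via torch_npu.get_storage_size
s.resize_(8)               # calls torch_npu.npu_storage_resize(x, 8); the new byte size is
                           # derived from the element count and the dtype's itemsize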
diff --git a/torch_npu/utils/tensor_methods.py b/torch_npu/utils/tensor_methods.py index 74c1e25c79..ffaad1e1ec 100644 --- a/torch_npu/utils/tensor_methods.py +++ b/torch_npu/utils/tensor_methods.py @@ -94,18 +94,21 @@ def _record_stream(self, *args, **kwargs): class NpuStorage(object): - def __init__(self, size): - self._size = size + def __init__(self, tensor): + self.tensor = tensor def size(self): - return self._size + return torch_npu.get_storage_size(self.tensor) + + def resize_(self, new_size): + return torch_npu.npu_storage_resize(self.tensor, new_size) storage_impl = torch.Tensor.storage def _storage(self): if torch_npu._C.is_npu(self): - return NpuStorage(torch_npu.get_storage_size(self)) + return NpuStorage(self) return storage_impl(self) -- Gitee From 18095720a881983f0d20b794348395570f938b81 Mon Sep 17 00:00:00 2001 From: wangqianren Date: Sat, 10 Jun 2023 20:38:58 +0800 Subject: [PATCH 2/3] add 1.13.0 version FSDP api --- torch_npu/__init__.py | 2 + torch_npu/distributed/__init__.py | 28 + .../algorithms/_checkpoint/__init__.py | 0 .../_checkpoint/checkpoint_wrapper.py | 257 + .../algorithms/_comm_hooks/__init__.py | 12 + .../algorithms/_comm_hooks/default_hooks.py | 172 + torch_npu/distributed/fsdp/__init__.py | 30 + .../distributed/fsdp/_fsdp_extensions.py | 112 + torch_npu/distributed/fsdp/_optim_utils.py | 1306 +++++ torch_npu/distributed/fsdp/_shard_utils.py | 269 + torch_npu/distributed/fsdp/_symbolic_trace.py | 243 + torch_npu/distributed/fsdp/_utils.py | 149 + torch_npu/distributed/fsdp/flat_param.py | 1133 +++++ .../fsdp/flatten_params_wrapper.py | 175 + .../fsdp/fully_sharded_data_parallel.py | 4477 +++++++++++++++++ .../distributed/fsdp/sharded_grad_scaler.py | 355 ++ torch_npu/distributed/fsdp/utils.py | 148 + torch_npu/distributed/fsdp/wrap.py | 482 ++ torch_npu/distributed/utils.py | 192 + torch_npu/nn/modules/module.py | 162 + 20 files changed, 9704 insertions(+) create mode 100644 torch_npu/distributed/algorithms/_checkpoint/__init__.py create mode 100644 torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py create mode 100644 torch_npu/distributed/algorithms/_comm_hooks/__init__.py create mode 100644 torch_npu/distributed/algorithms/_comm_hooks/default_hooks.py create mode 100644 torch_npu/distributed/fsdp/__init__.py create mode 100644 torch_npu/distributed/fsdp/_fsdp_extensions.py create mode 100644 torch_npu/distributed/fsdp/_optim_utils.py create mode 100644 torch_npu/distributed/fsdp/_shard_utils.py create mode 100644 torch_npu/distributed/fsdp/_symbolic_trace.py create mode 100644 torch_npu/distributed/fsdp/_utils.py create mode 100644 torch_npu/distributed/fsdp/flat_param.py create mode 100644 torch_npu/distributed/fsdp/flatten_params_wrapper.py create mode 100644 torch_npu/distributed/fsdp/fully_sharded_data_parallel.py create mode 100644 torch_npu/distributed/fsdp/sharded_grad_scaler.py create mode 100644 torch_npu/distributed/fsdp/utils.py create mode 100644 torch_npu/distributed/fsdp/wrap.py create mode 100644 torch_npu/distributed/utils.py create mode 100644 torch_npu/nn/modules/module.py diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 39e854b42d..68f83a8f95 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -67,6 +67,7 @@ from torch_npu.utils import apply_module_patch, add_tensor_methods, add_torch_fu add_fx_methods, add_checkpoint_methods from torch_npu.distributed.hccl_dtype_wraper import wrap_dtype_for_hccl from torch_npu.npu.amp.autocast_mode import apply_autocast_patch +from 
torch_npu.nn.modules.module import add_nn_module_api from .version import __version__ as __version__ @@ -242,6 +243,7 @@ def apply_class_patches(): add_fx_methods() add_checkpoint_methods() apply_autocast_patch() + add_nn_module_api() # Apply monkey-patches. diff --git a/torch_npu/distributed/__init__.py b/torch_npu/distributed/__init__.py index d718e0e046..b2f49a93f7 100644 --- a/torch_npu/distributed/__init__.py +++ b/torch_npu/distributed/__init__.py @@ -51,5 +51,33 @@ from .distributed_c10d import ( _rank_not_in_group, Logger, all_gather_object, broadcast_object_list, all_gather_togather, _reduce_scatter_base ) +from .fsdp import apply_fsdp_init +from .fsdp._optim_utils import apply_fsdp_optim_utils +from .fsdp._shard_utils import apply_fsdp_shard_utils +from .fsdp.flat_param import apply_fsdp_flat_param_handle +from .fsdp.flatten_params_wrapper import apply_fsdp_flatten_params_wrapper +from .fsdp.fully_sharded_data_parallel import apply_fsdp +from .fsdp.sharded_grad_scaler import apply_fsdp_shard_grad_scaler +from .fsdp.utils import apply_fsdp_utils +from .fsdp.wrap import apply_fsdp_wrap +from .algorithms._checkpoint.checkpoint_wrapper import apply_algorithms_checkpoint_wrapper +from .algorithms._comm_hooks.default_hooks import apply_algorithms_comm_hooks_default_hooks +from .algorithms._comm_hooks import apply_algorithms_comm_hooks_init +from .utils import apply_utils_func + set_debug_level_from_env() + +apply_fsdp_init() +apply_fsdp_optim_utils() +apply_fsdp_shard_utils() +apply_fsdp_flat_param_handle() +apply_fsdp_flatten_params_wrapper() +apply_fsdp() +apply_fsdp_shard_grad_scaler() +apply_fsdp_utils() +apply_fsdp_wrap() +apply_algorithms_checkpoint_wrapper() +apply_algorithms_comm_hooks_default_hooks() +apply_algorithms_comm_hooks_init() +apply_utils_func() diff --git a/torch_npu/distributed/algorithms/_checkpoint/__init__.py b/torch_npu/distributed/algorithms/_checkpoint/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py new file mode 100644 index 0000000000..a53241f320 --- /dev/null +++ b/torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -0,0 +1,257 @@ +from enum import auto, Enum +from functools import partial +from typing import Any, Dict, Iterator, Tuple + +import torch +import torch.nn as nn +from torch.autograd.graph import save_on_cpu +from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs +from torch.utils.checkpoint import checkpoint + +_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module" + +class CheckpointImpl(Enum): + REENTRANT = auto() + NO_REENTRANT = auto() + + +class CheckpointWrapper(torch.nn.Module): + """ + An nn.Module that wraps another nn.Module with checkpointing. Note that this + module is not meant to be used directly, but instead it is to be used + through the ``checkpoint_wrapper`` function. 
+ """ + def __init__( + self, + mod: torch.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT, + offload_to_cpu: bool = False, + checkpoint_fn=None, + *checkpoint_fn_args, + **checkpoint_fn_kwargs, + ): + super().__init__() + self._checkpoint_wrapped_module = mod + self.checkpoint_impl = checkpoint_impl + self.offload_to_cpu = offload_to_cpu + if self.offload_to_cpu: + self.checkpoint_fn = None + else: + if checkpoint_fn is None: + # use torch.utils.checkpoint + self.checkpoint_fn = partial( + checkpoint, + use_reentrant=( + self.checkpoint_impl == CheckpointImpl.REENTRANT + ), + ) + else: + self.checkpoint_fn = partial( + checkpoint_fn, + *checkpoint_fn_args, + **checkpoint_fn_kwargs, + ) + # state_dict post hook to remove prefix to allow loading into a + # non-checkpoint wrapped module. + self._register_state_dict_hook(self._post_state_dict_hook) + # load_state_dict pre-hook to allow loading back into + # checkpoint-wrapped module. + self._register_load_state_dict_pre_hook( + self._pre_load_state_dict_hook, with_module=True + ) + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._checkpoint_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is a nn.Sequential.""" + return self._checkpoint_wrapped_module.__getitem__(key) # type: ignore[operator] + + def forward(self, *args, **kwargs): + if self.offload_to_cpu: + with save_on_cpu(pin_memory=True): + return self._checkpoint_wrapped_module(*args, **kwargs) + else: + # Support keyword arguments for reentrant checkpoint. Note that this + # only works if user has specified self.checkpoint_impl and is not + # using their own custom checkpoint_fn. + if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}: + # Pack the args and kwargs + flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs) + + # Function that only takes (packed) args, but can unpack them + # into the original args and kwargs for the checkpointed + # function, and runs that function. + def my_function(*inputs): + # unpack back into args and kwargs + unpacked_args, unpacked_kwargs = _unpack_kwargs( + inputs, kwarg_keys + ) + # run original module + return self._checkpoint_wrapped_module( + *unpacked_args, **unpacked_kwargs + ) + + # Pass the function that only takes packed args into reentrant + # checkpoint API. + return self.checkpoint_fn( # type: ignore[misc] + my_function, + *flat_args, + ) + else: + return self.checkpoint_fn( # type: ignore[misc] + self._checkpoint_wrapped_module, + *args, + **kwargs + ) + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """ + Overrides :meth:`named_parameters()` to intercept parameter names and + remove all occurrences of _CHECKPOINT_PREFIX. + """ + for param_name, param in super().named_parameters(*args, **kwargs): + yield param_name.replace(f"{_CHECKPOINT_PREFIX}.", ""), param + + @staticmethod + def _post_state_dict_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this + FSDP module is executed. For ``checkpoint_wrapper``, it will strip + checkpoint-wrapped module prefix so that this module can be loaded into + non-checkpointed modules. 
It would still be able to be loaded into + checkpoint-wrapped modules as this class adds the prefix back before + loading the state_dict. + """ + _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}.", prefix) + return state_dict + + @staticmethod + def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` + is called. For ``checkpoint_wrapper``, it will add back the module + prefix so that non-checkpointed modules can be loaded into + checkpoint_wrapper modules properly. + """ + _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}.") + + +def checkpoint_wrapper( + module: torch.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT, + offload_to_cpu: bool = False, + checkpoint_fn=None, + *checkpoint_fn_args, + **checkpoint_fn_kwargs, +) -> torch.nn.Module: + """ + A convenience wrapper for activation checkpointing. If the module is wrapped + with this function, all subsequent calls to the module will automatically + perform checkpointing without the user having to explicitly call ``checkpoint`` + function. + Usage:: + checkpointed_module = checkpoint_wrapper(module) + outputs = checkpointed_module(inputs) + Args: + module (nn.Module): + The module to be wrapped + checkpoint_impl (Optional[CheckpointImpl]): + The checkpointing implementation to use. Note that this will only + be passed into the ``torch.utils.checkpoint.checkpoint`` + implementation, and is ignored if a custom ``checkpoint_fn`` is + specified. Note that for implementations using reentrant checkpoint + from ``torch.utils.checkpoint``, keyword arguments will only be + supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`. + offload_to_cpu (Optional[bool]): + Whether to offload activations of this wrapped module to CPU. Note + that if this is specified, ``checkpoint_impl`` and ``checkpoint_fn`` + arguments will be ignored in favor of the activations being + offloaded to CPU. Default is ``False``. Wrappers with activation + offload can be composed with ones that do recomputation-based + checkpoint to trade off increased compute versus increased CPU + memory usage and additional H2D transfers. + checkpoint_fn (Optional[Callable]): + Functional checkpoint implementation to use. If this is specified, + it will be used over the default ``torch.utils.checkpoint.checkpoint`` + implementation and the `checkpoint_impl` argument will be ignored. + *checkpoint_fn_args: (Sequence[Any]): Arguments to pass into `checkpoint_fn`. + **checkpoint_fn_kwargs: (Dict[str, Any]): Keyword arguments to pass into `checkpoint_fn`. + + Returns: + (nn.Module): + Wrapped module + """ + + return CheckpointWrapper( + module, checkpoint_impl, offload_to_cpu, checkpoint_fn, checkpoint_fn_args, checkpoint_fn_kwargs + ) + + +def apply_activation_checkpointing( + model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=lambda _: True +): + """ + Applies :func:`checkpoint_wrapper` to modules within `model` based on a user-defined + configuration. For each module within `model`, the `check_fn` is used to decide + whether `module` should be wrapped with :func:`checkpoint_wrapper` or not. + + Note:: + This function modifies `model` in place and replaces appropriate layers with + their checkpoint-wrapped modules. + Note:: + This function will not wrap the overall root module. If this is needed, please directly use + :class:`CheckpointWrapper`. 
+ Usage:: + model = nn.Sequential( + nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10) + ) + check_fn = lambda l: isinstance(l, nn.Linear) + apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn) + Args: + model (nn.Module): + The model whose submodules should be wrapped with activation checkpointing. + checkpoint_wrapper_fn (Optional[Callable[nn.Module]]) + A ``Callable`` which will wrap modules + check_fn (Optional[Callable[nn.Module, nn.Module]]) + A lambda function which will be passed each child submoule of ``model`` and returns + ``True`` or ``False`` depending on whether the submodule should be wrapped. + Returns: None (`model` is modified inplace) + """ + # TODO: Importing inside function to avoid circular import issue between FSDP and + # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. + from torch_npu.distributed.fsdp.wrap import _recursive_wrap, lambda_auto_wrap_policy + return _recursive_wrap( + module=model, + auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=check_fn), + wrapper_cls=checkpoint_wrapper_fn, + ignored_modules=set(), + ignored_params=set(), + only_wrap_children=True + ) + + +def apply_algorithms_checkpoint_wrapper(): + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing = apply_activation_checkpointing + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper = checkpoint_wrapper + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.CheckpointWrapper = CheckpointWrapper + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.CheckpointImpl = CheckpointImpl + torch.distributed.algorithms._checkpoint.checkpoint_wrapper._CHECKPOINT_PREFIX = _CHECKPOINT_PREFIX diff --git a/torch_npu/distributed/algorithms/_comm_hooks/__init__.py b/torch_npu/distributed/algorithms/_comm_hooks/__init__.py new file mode 100644 index 0000000000..01a95c4172 --- /dev/null +++ b/torch_npu/distributed/algorithms/_comm_hooks/__init__.py @@ -0,0 +1,12 @@ +import torch +from . import default_hooks as default + +LOW_PRECISION_HOOKS = [ + default.fp16_compress_hook, + default.bf16_compress_hook, +] + + +def apply_algorithms_comm_hooks_init(): + torch.distributed.algorithms._comm_hooks.default = default + torch.distributed.algorithms._comm_hooks.LOW_PRECISION_HOOKS = LOW_PRECISION_HOOKS \ No newline at end of file diff --git a/torch_npu/distributed/algorithms/_comm_hooks/default_hooks.py b/torch_npu/distributed/algorithms/_comm_hooks/default_hooks.py new file mode 100644 index 0000000000..cbb6bf1fbc --- /dev/null +++ b/torch_npu/distributed/algorithms/_comm_hooks/default_hooks.py @@ -0,0 +1,172 @@ +import functools +import torch +import torch_npu +import torch.distributed as dist + + +class DefaultState(object): + r""" + Stores state needed to perform the default communication algorithm + within a communication hook. + + Args: + process_group (ProcessGroup): The process group to be used. 
+ """ + + __slots__ = [ + "process_group", + "world_size", + "gradient_predivide_factor", + "gradient_postdivide_factor" + ] + + def __init__( + self, + process_group: dist.ProcessGroup + ): + if process_group is None: + raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.") + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + # Setting two factors `self.gradient_predivide_factor` + # and `self.gradient_postdivide_factor` to avoid underflow and overflow + self.gradient_predivide_factor = self._get_gradient_predivide_factor( + self.world_size + ) + self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor + + def _get_gradient_predivide_factor(self, world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + +class LowPrecisionState(DefaultState): + r""" + Stores state needed to perform gradient communication in a lower precision + within a communication hook. Communication hook will cast gradients back + to the original parameter precision specified by ``parameter_type`` (default: torch.float32). + Builds on top of the :class:`DefaultState`. + + Args: + parameter_type (torch.dtype): The precision of model's parameters. + Required for a hook to cast gradients back to a parameter's precision. + """ + + __slots__ = [ + "parameter_type", + ] + + def __init__( + self, + process_group, + parameter_type=torch.float32, + ): + super().__init__(process_group) + self.parameter_type = parameter_type + + +def _decompress(state: LowPrecisionState, grad: torch.Tensor): + """ + Casts gradients back to full parameter precision so that + further computation happens in full precision. + """ + orig_grad_data = grad.data + grad.data = grad.data.to(state.parameter_type) + # Don't let this memory get reused until after the transfer. + orig_grad_data.record_stream(torch_npu.npu.current_stream()) # type: ignore[arg-type] + +def allreduce_hook(state: DefaultState, grad: torch.Tensor): + r""" + This FSDP communication hook implements ``all_reduce`` algorithm + and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks. + """ + # Average grad by pre-division factor. Together pre- and post-division factors + # lead to an overall averaging by world_size, required for consistency with PyTorch DDP. + # This is a two-step process to avoid potential underflow and overflow. + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.all_reduce(grad, group=state.process_group) + # Average grad by post-division factor. + if state.gradient_postdivide_factor > 1: + grad.div_(state.gradient_postdivide_factor) + +def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor): + r""" + This FSDP communication hook implements ``reduce_scatter`` algorithm for + sharded FSDP strategies and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (torch.Tensor): An unsharded gradient for the local batch that needs to be + communicated across ranks. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + # Average grad by pre-division factor. 
+ if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist._reduce_scatter_base( + output, grad, group=state.process_group + ) + # Average grad's shard by post-division factor. + if state.gradient_postdivide_factor > 1: + output.div_(state.gradient_postdivide_factor) + +def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor): + grad.data = grad.data.to(prec) + if output is not None: + output.data = output.data.to(prec) + reduce_scatter_hook(state, grad, output) + _decompress(state, output) + else: + allreduce_hook(state, grad) + _decompress(state, grad) + +def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None): + r""" + This FSDP communication hook implements a simple gradient compression + approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + fp16_hook = functools.partial(_low_precision_hook, torch.float16) + return fp16_hook(state, grad, output) + +def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None): + r""" + This FSDP communication hook implements a simple gradient compression + approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. 
+ """ + bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16) + return bf16_hook(state, grad, output) + + +def apply_algorithms_comm_hooks_default_hooks(): + torch.distributed.algorithms._comm_hooks.default_hooks.bf16_compress_hook = bf16_compress_hook + torch.distributed.algorithms._comm_hooks.default_hooks.fp16_compress_hook = fp16_compress_hook + torch.distributed.algorithms._comm_hooks.default_hooks._low_precision_hook = _low_precision_hook + torch.distributed.algorithms._comm_hooks.default_hooks.reduce_scatter_hook = reduce_scatter_hook + torch.distributed.algorithms._comm_hooks.default_hooks.allreduce_hook = allreduce_hook + torch.distributed.algorithms._comm_hooks.default_hooks._decompress = _decompress + torch.distributed.algorithms._comm_hooks.default_hooks.LowPrecisionState = LowPrecisionState + torch.distributed.algorithms._comm_hooks.default_hooks.DefaultState = DefaultState \ No newline at end of file diff --git a/torch_npu/distributed/fsdp/__init__.py b/torch_npu/distributed/fsdp/__init__.py new file mode 100644 index 0000000000..f7b0018cd8 --- /dev/null +++ b/torch_npu/distributed/fsdp/__init__.py @@ -0,0 +1,30 @@ +import torch +import torch_npu +from .flat_param import FlatParameter +from .fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullStateDictConfig, + FullyShardedDataParallel, + LocalStateDictConfig, + MixedPrecision, + OptimStateKeyType, + ShardingStrategy, + StateDictType, +) +from .wrap import ParamExecOrderWrapPolicy + + +def apply_fsdp_init(): + torch.distributed.fsdp.FlatParameter = FlatParameter + torch.distributed.fsdp.BackwardPrefetch = BackwardPrefetch + torch.distributed.fsdp.CPUOffload = CPUOffload + torch.distributed.fsdp.FullStateDictConfig = FullStateDictConfig + torch.distributed.fsdp.FullyShardedDataParallel = FullyShardedDataParallel + torch.distributed.fsdp.LocalStateDictConfig = LocalStateDictConfig + torch.distributed.fsdp.MixedPrecision = MixedPrecision + torch.distributed.fsdp.OptimStateKeyType = OptimStateKeyType + torch.distributed.fsdp.ShardingStrategy = ShardingStrategy + torch.distributed.fsdp.StateDictType = StateDictType + torch.distributed.fsdp.ParamExecOrderWrapPolicy = ParamExecOrderWrapPolicy + diff --git a/torch_npu/distributed/fsdp/_fsdp_extensions.py b/torch_npu/distributed/fsdp/_fsdp_extensions.py new file mode 100644 index 0000000000..079ae76c3d --- /dev/null +++ b/torch_npu/distributed/fsdp/_fsdp_extensions.py @@ -0,0 +1,112 @@ +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Tuple + +import torch +import torch.distributed as dist + +from ._shard_utils import _create_chunk_sharded_tensor + + +class FSDPExtensions(ABC): + """ + This enables some customizable hooks to enable composability with tensor + parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to + set a custom :class:`FSDPExtensions` that implements the hooks. + """ + + @abstractmethod + def pre_flatten_transform( + self, + tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[Any]]: + """E.g. converting ``DistributedTensor`` to local tensor.""" + ... + + @abstractmethod + def post_unflatten_transform( + self, + tensor: torch.Tensor, + param_extension: Any, + ) -> torch.Tensor: + """E.g. converting local tensor to ``DistributedTensor``.""" + ... 
+ + @abstractmethod + def chunk_tensor( + self, + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + ) -> torch.Tensor: + """Shards a tensor to chunks and returns the local chunk.""" + ... + + @abstractmethod + def pre_load_state_dict_transform( + self, + tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + This is to be called before loading a *sharded* model state dict and + should return the tensor and list of shards from which to load data. + """ + ... + + +_extensions: Optional[FSDPExtensions] = None + + +def _set_fsdp_extensions(flattener: FSDPExtensions) -> None: + global _extensions + _extensions = flattener + + +def _ext_pre_flatten_transform( + tensor: torch.Tensor, +) -> Tuple[torch.Tensor, Optional[Any]]: + if _extensions is not None: + new_tensor, extension = _extensions.pre_flatten_transform(tensor) + if extension is not None: + return new_tensor, extension + return tensor, None + + +def _ext_post_unflatten_transform( + tensor: torch.Tensor, + param_extension: Any, +) -> torch.Tensor: + if _extensions is not None and param_extension is not None: + return _extensions.post_unflatten_transform(tensor, param_extension) + return tensor + + +def _ext_chunk_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, +) -> torch.Tensor: + chunk_tensor_fn = ( + _extensions.chunk_tensor + if _extensions is not None + else _create_chunk_sharded_tensor + ) + return chunk_tensor_fn( + tensor, + rank, + world_size, + num_devices_per_node, + pg, + ) + + +def _ext_pre_load_state_dict_transform( + tensor: torch.Tensor, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + if _extensions is not None: + return _extensions.pre_load_state_dict_transform(tensor) + shards = tensor.local_shards() # type: ignore[attr-defined] + return (tensor, shards) diff --git a/torch_npu/distributed/fsdp/_optim_utils.py b/torch_npu/distributed/fsdp/_optim_utils.py new file mode 100644 index 0000000000..1451ca2456 --- /dev/null +++ b/torch_npu/distributed/fsdp/_optim_utils.py @@ -0,0 +1,1306 @@ +import collections +import copy +import functools +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +import torch +import torch_npu +import torch.distributed as dist +# Import the entire FSDP file to avoid circular imports +import .fully_sharded_data_parallel as FSDP +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor +from ._shard_utils import _gather_state_dict +from .flat_param import FlatParameter, FlatParamHandle +from ._fsdp_extensions import _ext_chunk_tensor + + +def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]: + keys = sorted(dictionary.keys()) + for k in keys: + yield k, dictionary[k] + + +class _ConsolidatedOptimState: + """ + This holds the consolidated optimizer state on the target rank. Positive- + dimension tensor state is communicated across ranks, while zero-dimension + tensor state and non-tensor state is taken directly from the target rank. + + PyTorch version 1.12 moved to using zero-dimension tensors for scalar + values, but user implemented optimizers may still use float (i.e. a + non-tensor). Thus, we support both and handle them identically. + + Attributes: + tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension + tensor state name to the unsharded flattened tensor representing + the state. 
+ zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero- + dimension tensor state name to its value. + non_tensor_state (Dict[str, Any]): Mapping from non-tensor state + name to its value. + """ + + tensor_state: Dict[str, torch.Tensor] = {} + zero_dim_tensor_state: Dict[str, torch.Tensor] = {} + non_tensor_state: Dict[str, Any] = {} + + +class _PosDimTensorInfo(NamedTuple): + """ + Meatadata for positive-dimension tensors used internally for + :meth:`scatter_full_optim_state_dict`. + + Attributes: + shape (torch.Size): Sharded tensor shape (which is equal to the + unsharded tensor shape if the tensor is optimizer state for a + non-FSDP parameter and is hence not sharded). + dtype (torch.dtype): Data type of the tensor. + """ + + shape: torch.Size + dtype: torch.dtype + + +class _OptimStateKey(NamedTuple): + """ + This represents an optimizer state key that may be used commonly across + ranks. It is based on the unflattened parameter names rather than parameter + IDs to make it indepenendent of each rank's own optimizer construction. + """ + + unflat_param_names: Tuple[str, ...] + is_flat_param: bool + + +def _unflatten_optim_state( + flat_param: FlatParameter, + flat_param_state: Dict[str, Any], + fsdp_module, + to_save: bool, + shard_state: bool, +) -> List[Dict[str, Any]]: + """ + Unflattens the optimizer state, consisting of the "state" part and the + "param_groups" part. Unflattening the "state" part involves consolidating + the state on the target rank and remapping from flattened to unflattened + parameter IDs, and the "param_groups" part only involves remapping from + flattened to unflattened parameter IDs. + + Args: + flat_param (FlatParameter): The flattened parameter. + flat_param_state (Dict[str, Any]): Entry for the flattened parameter + in the "state" part of the optimizer state dict. + fsdp_module (FullyShardedDataParallel): FSDP module that owns + ``flat_param``, i.e. holds it in ``self.params``. + to_save (bool): Whether to save the state on this rank. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flattened parameter + ``flat_param`` if on the target rank or an empty :class:`list` + otherwise. The final optimizer state dict will need to map these + entries using the proper unflattened parameter IDs. + """ + consolidated_state = _communicate_optim_state( + flat_param, + flat_param_state, + fsdp_module, + to_save, + ) + unflat_param_state = ( + _unflatten_communicated_optim_state( + fsdp_module, + flat_param, + consolidated_state, + shard_state, + ) + if to_save or shard_state + else [] + ) + if to_save: + for optim_state in unflat_param_state: + for key in list(optim_state.keys()): + state = optim_state[key] + if isinstance(state, torch.Tensor): + optim_state[key] = state.cpu() + return unflat_param_state + + +def _communicate_optim_state( + flat_param: FlatParameter, + flat_param_state: Dict[str, Any], + fsdp_module, + to_save: bool, +) -> _ConsolidatedOptimState: + """ + Communicates the optimizer state for a flattened parameter ``flat_param`` + across ranks so that the target rank holds the entire non-sharded optimizer + state. + + If ``N`` is the number of tensor optimizer states in the optimizer state + dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1`` + otherwise (where the plus 1 comes from all-gathering the padding per rank). + + Args: + flat_param (FlatParameter): The flattened parameter. 
+ flat_param_state (Dict[str, Any]): The entry in the "state" part of the + optimizer state dict corresponding to the flattened parameter. + fsdp_module (FullyShardedDataParallel): FSDP module that owns + ``flat_param``, i.e. holds it in ``self.params``. + to_save (bool): Whether to save the state on this rank. + + Returns: + ConsolidatedOptimState: Consolidated optimizer state for + ``flat_param``; the state is not populated for non-target ranks. + """ + state = _ConsolidatedOptimState() + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + group = fsdp_module.process_group + + for state_name, value in sorted_items(flat_param_state): + # Positive-dimension tensor state: communicate across ranks + if torch.is_tensor(value) and value.dim() > 0: + # If the parameter is not sharded, then neither is the + # positive-dimension tensor state, so no need to communicate it -- + # we take the target rank's value + if ( + fsdp_module.world_size == 1 + or fsdp_module.sharding_strategy == FSDP.ShardingStrategy.NO_SHARD + ): + tensor_state[state_name] = value + continue + if not value.is_npu: + value = value.to(fsdp_module.compute_device) + # Assume that positive-dimension tensor optimizer state + # has the same shape as the sharded flattened parameter + buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined] + tensor_buffer = value.new_zeros(*buffer_size) + dist._all_gather_base(tensor_buffer, value, group=group) + torch_npu.npu.synchronize() + if to_save: + unpadded_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] + tensor_state[state_name] = tensor_buffer[:unpadded_numel] + # Zero-dimension tensor state and non-tensor state: take this rank's + # value directly + elif to_save: + if _is_zero_dim_tensor(value): + zero_dim_tensor_state[state_name] = value + else: + non_tensor_state[state_name] = value + return state + + +def _unflatten_communicated_optim_state( + fsdp_module, + flat_param: FlatParameter, + state: _ConsolidatedOptimState, + shard_state: bool, +) -> List[Dict[str, Any]]: + """ + Unflattens the communicated optimizer state (given by ``tensor_state``, + ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flattened + parameter ``flat_param``. This should only be called on the target rank. + + Args: + flat_param (FlatParameter): The flattened parameter. + state (_ConsolidatedOptimState): Consolidated optimizer state. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flattened parameter + ``flat_param``. The final optimizer state dict will need to map these + entries using the proper unflattened parameter IDs. 
+ """ + unflat_param_state: List[Dict[str, Any]] = [] + flat_param_views: Dict[str, Iterator] = {} + num_unflat_params = flat_param._num_params + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + + for _ in range(num_unflat_params): + unflat_state_param = {} + # Add positive-dimension tensor state: unflatten with views + for state_name, flat_tensor in sorted_items(tensor_state): + views_generated = state_name in flat_param_views + if not views_generated: + views = FlatParamHandle._get_unflat_views(flat_param, flat_tensor) + flat_param_views[state_name] = views + else: + views = flat_param_views[state_name] + optim_state: Union[torch.Tensor, ShardedTensor] = next(views) + if shard_state: + optim_state = _ext_chunk_tensor( + optim_state, + fsdp_module.rank, + fsdp_module.world_size, + torch_npu.npu.device_count(), + fsdp_module.process_group, + ) + unflat_state_param[state_name] = optim_state + + # Add zero-dimension tensor state: take the target rank's value + for state_name, zero_dim_tensor in sorted_items(zero_dim_tensor_state): + unflat_state_param[state_name] = zero_dim_tensor + # Add non-tensor state: take the target rank's value + for state_name, non_tensor in sorted_items(non_tensor_state): + unflat_state_param[state_name] = non_tensor + unflat_param_state.append(unflat_state_param) + return unflat_param_state + + +def _flatten_optim_state_dict( + optim_state_dict: Dict[str, Any], + model: torch.nn.Module, + shard_state: bool, +) -> Dict[str, Any]: + """ + Flattens the full optimizer state dict, still keying by unflattened + parameter names. If ``shard_state=True``, then FSDP-managed + ``FlatParameter`` 's optimizer states are sharded, and otherwise, they are + kept unsharded. + + Returns: + Dict[str, Any]: The flattened optimizer state dict. + """ + unflat_osd = optim_state_dict + if "state" not in unflat_osd or "param_groups" not in unflat_osd: + raise ValueError( + '`optim_state_dict` must have the keys "state" and ' + '"param_groups" to be a valid optimizer state dict' + ) + flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model) + param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model) + + # Construct the "state" part + flat_osd_state: Dict[_OptimStateKey, Any] = {} + unflat_osd_state = unflat_osd["state"] + for param, unflat_param_names in param_to_unflat_param_names.items(): + if isinstance(param, FlatParameter): # flatten FSDP parameters' states + assert param in flat_param_to_fsdp_module, ( + "Check the `flat_param_to_fsdp_module` construction\n" f"param: {param}" + ) + fsdp_module = flat_param_to_fsdp_module[param] + flat_state = _flatten_optim_state( + unflat_osd_state, + unflat_param_names, + fsdp_module, + param, + shard_state, + ) + key = _OptimStateKey(tuple(unflat_param_names), True) + flat_osd_state[key] = flat_state + else: # do not flatten non-FSDP parameters' states + assert len(unflat_param_names) == 1 + unflat_param_name = unflat_param_names[0] + if unflat_param_name not in unflat_osd_state: + # The state dict may not have an entry for a parameter if it + # was not passed into the optimizer (e.g. 
if it is not an + # FSDP-managed parameter) + continue + key = _OptimStateKey(tuple(unflat_param_names), False) + flat_osd_state[key] = copy.copy(unflat_osd_state[unflat_param_name]) + + # Construct the "param_groups" part -- copy as is since it will be + # rekeyed later according to the target rank's optimizer + flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) + return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} + + +def _flatten_optim_state( + unflat_osd_state: Dict[str, Dict[str, Any]], + unflat_param_names: List[str], + fsdp_module, + flat_param: FlatParameter, + shard_state: bool, +) -> Dict[str, Any]: + """ + Flattens the optimizer state in ``full_optim_state_dict`` for a single + flattened parameter ``flat_param`` in ``fsdp_module`` corresponding to + the unflattened parameter names in ``unflat_param_names``. + + Args: + unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the + optimizer state dict corresponding to the unflattened parameters. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the flattened parameter + ``flat_param``. + fsdp_module (FullyShardedDataParallel): FSDP module owning the + flattened parameter. + flat_param (FlatParameter): The flattened parameter. + shard_state (bool): Whether to shard flattened positive-dimension + tensor state; if ``False``, then the full flattened tensor is + kept in the returned :class:`dict. + + Returns: + Dict[str, Any]: A :class:`dict` mapping state names to their values for + a particular flattened parameter. The sharded optimizer state dict's + "state" part will map a key to this returned value. + """ + num_unflat_params = len(unflat_param_names) + assert num_unflat_params > 0, ( + "Expects at least one unflattened parameter corresponding to the " + "flattened parameter" + ) + unflat_param_shapes = flat_param._shapes + num_unflat_param_shapes = len(unflat_param_shapes) + assert ( + num_unflat_params == num_unflat_param_shapes + ), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" + + # Check if these unflattened parameters have any optimizer state + has_state = [ + bool(unflat_param_name in unflat_osd_state) + for unflat_param_name in unflat_param_names + ] + # If none of the unflattened parameters comprising this flattened parameter + # have any state, then we do not want an entry in the optimizer state dict + if not any(has_state): + return {} # no need to flatten any state + # There may still be some unflattened parameters with state and some + # without + unflat_param_states = [ + _gather_state_dict( + unflat_osd_state[unflat_param_name], pg=fsdp_module.process_group + ) + if unflat_param_name in unflat_osd_state + else None + for unflat_param_name in unflat_param_names + ] + # Check that the unflattened parameters have the same state names + state_names = None + for unflat_param_state in unflat_param_states: + if unflat_param_state is None: + continue + if state_names is None: + state_names = set(unflat_param_state.keys()) + else: + if state_names != set(unflat_param_state.keys()): + raise ValueError( + "Differing optimizer state names for the unflattened " + f"parameters: {unflat_param_names}" + ) + assert state_names is not None + + # Flatten the state + flat_state: Dict[str, Any] = {} + for state_name in state_names: + state_values = [ + unflat_param_state[state_name] if unflat_param_state is not None else None + for unflat_param_state in unflat_param_states + ] + non_none_state_values = [v for v in 
state_values if v is not None] + are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True + for v in non_none_state_values: + are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0 + are_zero_dim_tensors &= _is_zero_dim_tensor(v) + are_non_tensors &= not torch.is_tensor(v) + types = set(type(v) for v in non_none_state_values) + if len(types) != 1 or not ( + are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors + ): + raise ValueError( + f"Differing optimizer state types for state {state_name}, " + f"values {non_none_state_values}, and unflattened parameter " + f"names {unflat_param_names}" + ) + if are_pos_dim_tensors: + flat_tensor = _flatten_tensor_optim_state( + state_name, + state_values, + unflat_param_names, + unflat_param_shapes, + flat_param, + ) + if shard_state: + # Shard the flattened tensor immediately to minimize max memory + # usage + sharded_flat_tensor, _ = FlatParamHandle._get_shard( + flat_tensor, + fsdp_module.rank, + fsdp_module.world_size, + ) + flat_state[state_name] = sharded_flat_tensor + else: + flat_state[state_name] = flat_tensor + elif are_zero_dim_tensors: + flat_state[state_name] = _flatten_zero_dim_tensor_optim_state( + state_name, + state_values, + unflat_param_names, + ) + else: + assert are_non_tensors + flat_state[state_name] = _flatten_non_tensor_optim_state( + state_name, + state_values, + unflat_param_names, + ) + + return flat_state + + +def _flatten_tensor_optim_state( + state_name: str, + pos_dim_tensors: List[torch.Tensor], + unflat_param_names: List[str], + unflat_param_shapes: Sequence[torch.Size], + flat_param: FlatParameter, +) -> torch.Tensor: + """ + Flattens the positive-dimension tensor optimizer state given by the values + ``tensors`` for the state ``state_name`` for a single flattened parameter + ``flat_param`` corresponding to the unflattened parameter names + ``unflat_param_names`` and unflatted parameter shapes + ``unflat_param_shapes``. This flattens each unflattened parameter's tensor + state into one tensor. + + NOTE: We use zero tensors for any unflattened parameters without state + since some value is required to fill those entries. This assumes that the + zero tensor is mathematically equivalent to having no state, which is true + for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all + optimizers. + + Args: + state_name (str): Optimizer state name. + pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor + optimizer state values for the unflattened parameters corresponding + to the single flattened parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flattened parameter. + unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes + corresponding to the single flattened parameter. + flat_param (FlatParameter): The flattened parameter. + + Returns: + torch.Tensor: A flattened tensor containing the optimizer state + corresponding to ``state_name`` constructed by concatenating the + unflattened parameter tensor states in ``pos_dim_tensors`` (using zero + tensors for any unflattened parameters without the state). 
+ """ + non_none_tensors = [t for t in pos_dim_tensors if t is not None] + # Check that all are tensors with the same dtype + dtypes = set(t.dtype for t in non_none_tensors) + if len(dtypes) != 1: + raise ValueError( + "All unflattened parameters comprising a single flattened " + "parameter must have positive-dimension tensor state with the " + f"same dtype but got dtypes {dtypes} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + dtype = next(iter(dtypes)) + # Check that each tensor state matches its parameter's shape + for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes): + if tensor is None and len(shape) == 0: + raise ValueError("Flattening a zero-dimension parameter is not supported") + elif tensor is not None and tensor.shape != shape: + raise ValueError( + "Tensor optimizer state does not have same shape as its " + f"parameter: {tensor.shape} {shape}" + ) + # Flatten the tensor states: we do not need to add any padding since the + # flattened optimizer state tensor sharded via `_get_shard()`, which pads + # the shard as needed (just like for the flattened parameter) + cpu_device = torch.device("cpu") + tensors = [ + torch.flatten(state_value.to(cpu_device)) + if state_value is not None + else torch.flatten( + torch.zeros( + size=shape, + dtype=dtype, + device=cpu_device, + ) + ) + for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes) + ] + flat_tensor = torch.cat(tensors) + flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined] + assert flat_tensor.shape == flat_param_shape, ( + f"tensor optim state: {flat_tensor.shape} " + f"flattened parameter: {flat_param_shape}" + ) + return flat_tensor + + +def _flatten_zero_dim_tensor_optim_state( + state_name: str, + zero_dim_tensors: List[torch.Tensor], + unflat_param_names: List[str], +) -> torch.Tensor: + """ + Flattens the zero-dimension tensor optimizer state given by the values + ``zero_dim_tensors`` for the state ``state_name`` for a single flattened + parameter corresponding to the unflattened parameter names + ``unflat_param_names`` by enforcing that all tensors are the same and using + that common value. + + NOTE: The requirement that the tensors are the same across all unflattened + parameters comprising the flattened parameter is needed to maintain the + invariant that FSDP performs the same computation as its non-sharded + equivalent. This means that none of the unflattened parameters can be + missing this state since imposing a value may differ from having no value. + For example, for Adam's "step", no value means maximum bias correction, + while having some positive value means less bias correction. + + Args: + state_name (str): Optimizer state name. + zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state + for the unflattened parameters corresponding to the single + flattened parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flattened parameter. + + Returns: + torch.Tensor: A zero-dimensional tensor giving the value of the state + ``state_name`` for all unflattened parameters corresponding to the + names ``unflat_param_names``. 
+ """ + non_none_tensors = [t for t in zero_dim_tensors if t is not None] + # Enforce that all have the same value and dtype + values_set = set(t.item() if t is not None else None for t in zero_dim_tensors) + dtypes = set(t.dtype if t is not None else None for t in zero_dim_tensors) + if ( + len(non_none_tensors) != len(zero_dim_tensors) + or len(values_set) != 1 + or len(dtypes) != 1 + ): + raise ValueError( + "All unflattened parameters comprising a single flattened " + "parameter must have scalar state with the same value and dtype " + f"but got values {values_set} and dtypes {dtypes} for state " + f"{state_name} and unflattened parameter names " + f"{unflat_param_names}" + ) + value = next(iter(values_set)) + dtype = next(iter(dtypes)) + return torch.tensor(value, dtype=dtype, device=torch.device("cpu")) + + +def _flatten_non_tensor_optim_state( + state_name: str, + non_tensors: List[Any], + unflat_param_names: List[str], +) -> Any: + """ + Flattens the non-tensor optimizer state given by the values ``non_tensors`` + for the state ``state_name`` for a single flattened parameter corresponding + to the unflattened parameter names ``unflat_param_names`` by enforcing that + all values are the same and using that common value. + + See the note in :func:`_flatten_zero_dim_tensor_optim_state`. + + Args: + state_name (str): Optimizer state name. + non_tensors (List[Any]): Non-tensor optimizer state for the unflattened + parameters corresponding to the single flattened parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flattened parameter. + + Returns: + Any: A non-tensor giving the value of the state ``state_name`` for all + unflattened parameters corresponding to the names + ``unflat_param_names``. + """ + non_none_non_tensors = [nt for nt in non_tensors if nt is not None] + # Enforce that all have the same value (same type already checked) + non_tensor_set = set(non_tensors) + if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1: + raise ValueError( + "All unflattened parameters comprising a single flattened " + "parameter must have scalar state with the same value and dtype " + f"but got values {non_tensor_set} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + non_tensor = next(iter(non_tensor_set)) + return non_tensor + + +def _process_pos_dim_tensor_state( + flat_optim_state_dict: Dict[str, Any], + world_size: int, +) -> Dict[str, Any]: + """ + Processes positive-dimension tensor states in ``flat_optim_state_dict`` by + replacing them with metadata. This is done so the processed optimizer state + dict can be broadcast from rank 0 to all ranks without copying those tensor + states, and thus, this is meant to only be called on rank 0. + + Args: + flat_optim_state_dict (Dict[str, Any]): Flattened optimizer state dict + with the positive-dimension tensor states unsharded. + + Returns: + Dict[str, Any]: The flattened optimizer state dict with positive- + dimension tensor states replaced by metadata. 
+ """ + flat_osd = flat_optim_state_dict # alias + no_tensor_osd: Dict[str, Any] = {"state": {}} + for key, param_state in flat_osd["state"].items(): + no_tensor_osd["state"][key] = {} + for state_name, value in sorted_items(param_state): + is_pos_dim_tensor_state = torch.is_tensor(value) and value.dim() > 0 + if not is_pos_dim_tensor_state: + no_tensor_osd["state"][key][state_name] = value + continue + if key.is_flat_param: # FSDP parameter + sharded_size = FlatParamHandle._get_sharded_size( + value, rank=0, world_size=world_size + ) + assert len(sharded_size) == 1, f"{sharded_size}" + info = _PosDimTensorInfo(sharded_size, value.dtype) + else: # non-FSDP parameter + info = _PosDimTensorInfo(value.shape, value.dtype) + no_tensor_osd["state"][key][state_name] = info + no_tensor_osd["param_groups"] = flat_osd["param_groups"] + return no_tensor_osd + + +def _broadcast_processed_optim_state_dict( + processed_optim_state_dict: Optional[Dict[str, Any]], + rank: int, + group, +) -> Dict[str, Any]: + """ + Broadcasts the processed optimizer state dict from rank 0 to all ranks. + + Args: + processed_optim_state_dict (Optional[Dict[str, Any]]): The flattened + optimizer state dict with positive-dimension tensor states replaced + with metadata if on rank 0; ignored otherwise. + + Returns: + Dict[str, Any]: The processed optimizer state dict. + """ + # Broadcast the two data structures rank 0 to all ranks + obj_list = [processed_optim_state_dict] if rank == 0 else [None] + dist.broadcast_object_list(obj_list, src=0, group=group) + processed_optim_state_dict = obj_list[0] # type: ignore[assignment] + assert processed_optim_state_dict is not None + # Keep zero-dimension tensors on CPU + return processed_optim_state_dict + + +def _broadcast_pos_dim_tensor_states( + processed_optim_state_dict: Dict[str, Any], + flat_optim_state_dict: Optional[Dict[str, Any]], + rank: int, + world_size: int, + group, + broadcast_device: torch.device, +) -> Dict[str, Any]: + """ + Takes ``processed_optim_state_dict``, which has metadata in place of + positive-dimension tensor states, and broadcasts those tensor states from + rank 0 to all ranks. For tensor states corresponding to FSDP parameters, + rank 0 shards the tensor and broadcasts shard-by-shard, and for tensor + states corresponding to non-FSDP parameters, rank 0 broadcasts the full + tensor. + + Args: + processed_optim_state_dict (Dict[str, Any]): The flattened optimizer + state dict with positive-dimension tensor states replaced with + metadata; this should be returned by + :meth:`_process_pos_dim_tensor_state` and non-empty on all ranks. + flat_optim_state_dict (Optional[Dict[str, Any]]): The flattened + unsharded optimizer state dict with the actual positive-dimension + tensor states if on rank 0; ignored on nonzero ranks. + + Returns: + Dict[str, Any]: The optimizer state dict with the positive-dimension + tensor state correctly populated via ``broadcast()`` s from rank 0. 
+ """ + assert ( + rank != 0 or flat_optim_state_dict is not None + ), "Expects rank 0 to pass in the flattened optimizer state dict" + no_tensor_osd = processed_optim_state_dict # alias + flat_osd = flat_optim_state_dict # alias + for key, param_state in no_tensor_osd["state"].items(): + for state_name, value in sorted_items(param_state): + is_pos_dim_tensor_state = isinstance(value, _PosDimTensorInfo) + if not is_pos_dim_tensor_state: + continue + if rank == 0: + assert flat_osd is not None + unsharded_tensor = flat_osd["state"][key][state_name] + else: + unsharded_tensor = None + shape, dtype = value.shape, value.dtype + if key.is_flat_param: # FSDP parameter + _broadcast_sharded_pos_dim_tensor_state( + unsharded_tensor, + param_state, + state_name, + shape, + dtype, + broadcast_device, + rank, + world_size, + group, + ) # modify `param_state` destructively + else: # non-FSDP parameter + _broadcast_unsharded_pos_dim_tensor_state( + unsharded_tensor, + param_state, + state_name, + shape, + dtype, + broadcast_device, + rank, + group, + ) # modify `param_state` destructively + return no_tensor_osd + + +def _broadcast_sharded_pos_dim_tensor_state( + unsharded_tensor: Optional[torch.Tensor], + param_state: Dict[str, Any], + state_name: str, + shape: torch.Size, + dtype: torch.dtype, + broadcast_device: torch.device, + rank: int, + world_size: int, + group, +) -> None: + """ + Broadcasts positive-dimension tensor state for the state ``state_name`` + corresponding to an FSDP parameter shard-by-shard, only to be saved on the + relevant rank. This modifies ``param_state`` destructively. + + Args: + unsharded_tensor (Optional[torch.Tensor]): Unsharded tensor from which + to broadcast shards if on rank 0; ignored otherwise. + shape (torch.Size): Shape of the sharded tensor; same on all ranks. + """ + get_shard: Optional[functools.partial[Tuple[torch.Tensor, int]]] = None + if rank == 0: + assert ( + unsharded_tensor is not None + ), "Expects rank 0 to pass in the unsharded tensor" + get_shard = functools.partial( + FlatParamHandle._get_shard, + unsharded_tensor, + ) + for target_rank in range(1, world_size): + if rank == 0: + assert get_shard is not None + sharded_tensor = get_shard(target_rank, world_size)[0].to(broadcast_device) + else: + sharded_tensor = torch.zeros( + shape, + requires_grad=False, + dtype=dtype, + device=broadcast_device, + ) + dist.broadcast(sharded_tensor, src=0, group=group) + # Only keep the shard on the target rank and keep it on the broadcast + # device, which is typically GPU + if rank == target_rank: + param_state[state_name] = sharded_tensor + else: + del sharded_tensor + # Lastly, shard on rank 0 + if rank != 0: + return + param_state[state_name] = get_shard(0, world_size)[0].to(broadcast_device) # type: ignore[misc] + + +def _broadcast_unsharded_pos_dim_tensor_state( + unsharded_tensor: Optional[torch.Tensor], + param_state: Dict[str, Any], + state_name: str, + shape: torch.Size, + dtype: torch.dtype, + broadcast_device: torch.device, + rank: int, + group, +) -> None: + """ + Broadcasts positive-dimension tensor state for the state ``state_name`` + corresponding to an unsharded non-FSDP parameter from rank 0 to all ranks. + This modifies ``param_state`` destructively. + + Args: + unsharded_tensor (Optional[torch.Tensor]): Unsharded tensor to + broadcast if on rank 0; ignored otherwise. 
+ """ + if rank == 0: + assert ( + unsharded_tensor is not None + ), "Expects rank 0 to pass in the unsharded tensor" + assert ( + shape == unsharded_tensor.shape + ), f"Shape mismatch: {shape} {unsharded_tensor.shape}" + assert ( + dtype == unsharded_tensor.dtype + ), f"dtype mismatch: {dtype} {unsharded_tensor.dtype}" + unsharded_tensor = unsharded_tensor.to(broadcast_device) + else: + unsharded_tensor = torch.zeros( + shape, + requires_grad=False, + dtype=dtype, + device=broadcast_device, + ) + dist.broadcast(unsharded_tensor, src=0, group=group) + # Keep the tensor on the broadcast device, which is typically GPU + param_state[state_name] = unsharded_tensor + + +def _rekey_sharded_optim_state_dict( + sharded_osd: Dict[str, Any], + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ], + using_optim_input: bool, +) -> Dict[str, Any]: + """ + Rekeys the optimizer state dict from unflattened parameter names to + flattened parameter IDs according to the calling rank's ``optim``, which + may be different across ranks. In particular, the unflattened parameter + names are represented as :class:`_OptimStateKey` s. + """ + param_to_flat_param_id = ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_id(optim) + ) + param_to_unflat_param_names = FSDP._get_param_to_unflat_param_names(model) + # All parameter keys in `param_to_flat_param_id` should be in + # `param_to_unflat_param_names` -- strict inequality follows when not all + # parameters are passed to the optimizer + assert len(param_to_flat_param_id) <= len(param_to_unflat_param_names) + + unflat_param_names_to_flat_param_id: Dict[Tuple[str, ...], int] = {} # for "state" + unflat_param_name_to_flat_param_id: Dict[str, int] = {} # for "param_groups" + for param, unflat_param_names in param_to_unflat_param_names.items(): + if param not in param_to_flat_param_id: + # This parameter was not passed to the optimizer + continue + flat_param_id = param_to_flat_param_id[param] + unflat_param_names_to_flat_param_id[tuple(unflat_param_names)] = flat_param_id + for unflat_param_name in unflat_param_names: + unflat_param_name_to_flat_param_id[unflat_param_name] = flat_param_id + + sharded_osd_state = sharded_osd["state"] + rekeyed_osd_state = {} + for key, param_state in sharded_osd_state.items(): + flat_param_id = unflat_param_names_to_flat_param_id[key.unflat_param_names] + rekeyed_osd_state[flat_param_id] = param_state + + rekeyed_osd_param_groups: List[Dict[str, Any]] = [] + for unflat_param_group in sharded_osd["param_groups"]: + flat_param_group = copy.deepcopy(unflat_param_group) + flat_param_ids = sorted( + set( + unflat_param_name_to_flat_param_id[unflat_param_name] + for unflat_param_name in unflat_param_group["params"] + ) + ) + flat_param_group["params"] = flat_param_ids + rekeyed_osd_param_groups.append(flat_param_group) + + return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups} + + +def _get_flat_param_to_fsdp_module(model: torch.nn.Module): + """ + Constructs a mapping from FSDP flattened parameters to their owning FSDP + modules and ensures that all FSDP modules are initialized. + + Args: + model (torch.nn.model): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance). + + Returns: + Dict[FlatParameter, FullyShardedDataParallel]: Mapping from FSDP + flattened parameters to their owning FSDP modules. 
+ """ + flat_param_to_fsdp_module = {} + for module in model.modules(): + if isinstance(module, FSDP.FullyShardedDataParallel): + module._lazy_init() + for param in module.params: # may have none + flat_param_to_fsdp_module[param] = module + return flat_param_to_fsdp_module + + +def _get_param_id_to_param( + optim: torch.optim.Optimizer, +): + """ + Constructs a mapping from parameter IDs to parameters. This may be used + both for models with ``FlatParameter`` s and without. + """ + param_id_to_param: List[nn.Parameter] = [] + for param_group in optim.param_groups: + for param in param_group["params"]: + param_id_to_param.append(param) + return param_id_to_param + + +def _get_param_id_to_param_from_optim_input( + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, +) -> List[torch.nn.Parameter]: + """ + Constructs a mapping from parameter IDs to parameters. This may be used + both for models with ``FlatParameter`` s and without. + + NOTE: This method is only preserved for backward compatibility. The method + :meth:`_get_param_id_to_param` is the preferred code path that does not + rely on ``optim_input``. + + NOTE: We critically assume that, whether the optimizer input is a list of + parameters or a list of parameter groups, :class:`torch.optim.Optimizer` + enumerates the parameter IDs in order. In other words, for a parameter list + input, the parameter IDs should be in that list order, and for a parameter + groups input, the parameter IDs should be in order within each parameter + group and in order across parameter groups. + + Args: + model (torch.nn.Module): Model whose parameters are passed into the + optimizer. + optim_input (Optional[Union[List[Dict[str, Any]], + Iterable[torch.nn.Parameter]]]): Input passed into the optimizer + representing either a :class:`list` of parameter groups or an + iterable of parameters; if ``None``, then this method assumes the + input was ``model.parameters()``. (Default: ``None``) + + Returns: + List[torch.nn.Parameter]: Mapping from parameter IDs to parameters, + where the parameter ID is implicitly the index in the :class:`list`. 
+ """ + # Assume the standard case of passing `model.parameters()` to the optimizer + # if `optim_input` is not specified + if optim_input is None: + return list(model.parameters()) + try: + params = list(optim_input) + except TypeError: + raise TypeError( + "Optimizer input should be an iterable of Tensors or dicts, " + f"but got {optim_input}" + ) + if len(params) == 0: + raise ValueError("Optimizer input should not be empty") + + # Check if the optimizer input represents tensors or parameter groups + all_tensors = True + all_dicts = True + for param in params: + all_tensors &= isinstance(param, torch.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError("Optimizer input should be an iterable of Tensors or dicts") + if all_tensors: + return params # type: ignore[return-value] + assert all_dicts + param_id_to_param = [] + for param_group in params: + has_params_key = "params" in param_group # type: ignore[operator] + assert has_params_key, ( + 'A parameter group should map "params" to a list of the ' + "parameters in the group" + ) + for param in param_group["params"]: # type: ignore[index] + # Implicitly map `flat_param_id` (current length of the list) to + # `param` + param_id_to_param.append(param) + return param_id_to_param # type: ignore[return-value] + + +def _get_param_to_param_id( + optim: torch.optim.Optimizer, +) -> Dict[torch.nn.Parameter, int]: + """Constructs the inverse mapping of :func:`_get_param_id_to_param`.""" + param_id_to_param = _get_param_id_to_param(optim) + return {param: param_id for param_id, param in enumerate(param_id_to_param)} + + +def _get_param_to_param_id_from_optim_input( + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, +) -> Dict[torch.nn.Parameter, int]: + """Constructs the inverse mapping of :func:`_get_param_id_to_param`.""" + param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input) + return {param: param_id for param_id, param in enumerate(param_id_to_param)} + + +def _get_unflat_to_flat_param_ids( + flat_to_unflat_param_ids: Dict[int, List[int]], +) -> List[int]: + """ + Inverts the mapping ``flat_to_unflat_param_ids`` to be from unflattened + parameter ID to flattened parameter ID, where the unflattened parameter ID + is the index in the returned :class:`list`. There may be multiple + unflattened parameter IDs mapping to the same flattened parameter ID. + + Args: + flat_to_unflat_param_ids (Dict[int, List[int]]): A mapping from + flattened parameter ID to a :class:`list` of corresponding + unflattened parameter IDs. + + Returns: + List[int]: A mapping from unflattened parameter ID to flattened + parameter ID, where the unflattened parameter ID is the index in the + :class:`list`. 
+ """ + # Construct as a dict and then convert to list + unflat_to_flat_param_ids = {} + for flat_param_id, unflat_param_ids in flat_to_unflat_param_ids.items(): + for unflat_param_id in unflat_param_ids: + assert unflat_param_id not in unflat_to_flat_param_ids, ( + "`flat_to_unflat_param_ids` has the unflattened parameter " + f"ID {unflat_param_id} mapped to multiple flattened " + "parameter IDs" + ) + unflat_to_flat_param_ids[unflat_param_id] = flat_param_id + num_unflat_param_ids = len(unflat_to_flat_param_ids) + unflat_param_ids_set = set(unflat_to_flat_param_ids.keys()) + assert unflat_param_ids_set == set(range(num_unflat_param_ids)), ( + "The set of unflattened parameter IDs should be {0, ..., " + + str(num_unflat_param_ids - 1) + + "} but got " + + f"{unflat_param_ids_set}" + ) + return [ + unflat_to_flat_param_ids[unflat_param_id] + for unflat_param_id in range(num_unflat_param_ids) + ] + + +def _is_zero_dim_tensor(x: Any) -> bool: + return torch.is_tensor(x) and x.dim() == 0 + + +def _optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ], + rank0_only: bool, + shard_state: bool, + group: Optional[dist.ProcessGroup], + using_optim_input: bool, +) -> Dict[str, Any]: + """ + Consolidates the optimizer state and returns it as a :class:`dict` + following the convention of :meth:`torch.optim.Optimizer.state_dict`, + i.e. with keys ``"state"`` and ``"param_groups"``. + The flattened parameters in ``FSDP`` modules contained in ``model`` + are mapped back to their unflattened parameters. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + shard_state (bool): If ``True``, shard and distribute all + non-zero-dimension states. + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=False``, + then nonzero ranks return an empty :class:`dict`. 
+ """ + osd = optim.state_dict() + osd_state, osd_param_groups = osd["state"], osd["param_groups"] + rank = dist.get_rank(group) + to_save = not rank0_only or (rank == 0 or shard_state) + fsdp_osd: Dict = {"state": {}, "param_groups": []} if to_save else {} + fsdp_osd_state = fsdp_osd["state"] if to_save else None + + # Construct the local mapping between unflattened parameter names + # (`_OptimStateKey`s) and parameter IDs and broadcast rank 0's mapping + param_to_unflat_param_names: Dict[ + torch.nn.Parameter, List[str] + ] = FSDP._get_param_to_unflat_param_names(model) + flat_param_id_to_param: List[torch.nn.Parameter] = ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_id_to_param(optim) + ) + optim_state_key_to_flat_param_id: Dict[_OptimStateKey, int] = {} # local + r0_flat_param_id_to_optim_state_key: Dict[ + int, _OptimStateKey + ] = collections.OrderedDict() # rank 0 + for flat_param_id, param in enumerate(flat_param_id_to_param): + # Do not include parameters without state to avoid empty mappings + # just like in normal `torch.optim.Optimizer.state_dict()` + if flat_param_id not in osd_state: + continue + optim_state_key = _OptimStateKey( + unflat_param_names=tuple(param_to_unflat_param_names[param]), + is_flat_param=isinstance(param, FlatParameter), + ) + if rank == 0: + r0_flat_param_id_to_optim_state_key[flat_param_id] = optim_state_key + optim_state_key_to_flat_param_id[optim_state_key] = flat_param_id + key_obj_list: List[Optional[Dict[int, _OptimStateKey]]] = ( + [r0_flat_param_id_to_optim_state_key] if rank == 0 else [None] + ) + dist.broadcast_object_list(key_obj_list, src=0, group=group) + assert key_obj_list[0] is not None + r0_flat_param_id_to_optim_state_key = key_obj_list[0] + + # Ensure that all ranks have at least the optimizer states needed by + # rank 0's optimizer + missing_keys: List[_OptimStateKey] = [] + for r0_optim_state_key in r0_flat_param_id_to_optim_state_key.values(): + if r0_optim_state_key not in optim_state_key_to_flat_param_id: + # A parameter from rank 0's optimizer does not exist for this + # rank's optimizer + missing_keys.append(r0_optim_state_key) + continue + flat_param_id = optim_state_key_to_flat_param_id[r0_optim_state_key] + assert flat_param_id >= 0 and flat_param_id < len( + flat_param_id_to_param + ), "Check the `flat_param_id_to_param` construction" + device = torch.device("npu", torch_npu.npu.current_device()) + num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) + dist.all_reduce(num_missing, group=group) + if num_missing.item() > 0: + obj_list = [None for _ in range(dist.get_world_size(group))] + dist.all_gather_object(obj_list, missing_keys, group=group) + error_msg = ( + "FSDP currently requires each rank to have at least the " + "optimizer states needed by rank 0's optimizer but some ranks " + "are missing some of those states" + ) + for rank, keys in enumerate(obj_list): + keys = cast(List[_OptimStateKey], keys) + if len(keys) > 0: + error_msg += ( + f"\nRank {rank} is missing states for the parameters: " + f"{[key.unflat_param_names for key in keys]}" + ) + raise RuntimeError(error_msg) + + # Iterate in rank 0's flattened parameter ID order to ensure aligned + # all-gathers across ranks + flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model) + for r0_optim_state_key in r0_flat_param_id_to_optim_state_key.values(): + flat_param_id = optim_state_key_to_flat_param_id[r0_optim_state_key] + param = flat_param_id_to_param[flat_param_id] + 
if r0_optim_state_key.is_flat_param: + fsdp_module = flat_param_to_fsdp_module[param] + unflat_state = _unflatten_optim_state( + cast(FlatParameter, param), + osd_state[flat_param_id], + fsdp_module, + to_save, + shard_state, + ) + if to_save: + assert len(unflat_state) == len(r0_optim_state_key.unflat_param_names) + for unflat_param_name, unflat_param_state in zip( + r0_optim_state_key.unflat_param_names, + unflat_state, + ): + fsdp_osd_state[unflat_param_name] = unflat_param_state + elif to_save: + assert len(r0_optim_state_key.unflat_param_names) == 1 + unflat_param_name = r0_optim_state_key.unflat_param_names[0] + fsdp_osd_state[unflat_param_name] = copy.copy(osd_state[flat_param_id]) + for state_name, value in sorted_items(fsdp_osd_state[unflat_param_name]): + if torch.is_tensor(value): + fsdp_osd_state[unflat_param_name][state_name] = value.cpu() + + if not to_save: + return {} + + # Handle the "param_groups" part of the optimizer state dict + fsdp_osd_param_groups = fsdp_osd["param_groups"] # alias + for flat_param_group in osd_param_groups: + unflat_param_group = copy.deepcopy(flat_param_group) + param_group_params = [ + flat_param_id_to_param[flat_param_id] + for flat_param_id in flat_param_group["params"] + ] + nested_unflat_param_names = [ + param_to_unflat_param_names[param] for param in param_group_params + ] + unflat_param_group["params"] = [ + unflat_param_name + for unflat_param_names in nested_unflat_param_names + for unflat_param_name in unflat_param_names + ] # flatten the list of lists + fsdp_osd_param_groups.append(unflat_param_group) + return fsdp_osd + + +def apply_fsdp_optim_utils(): + torch.distributed.fsdp._optim_utils.sorted_items = sorted_items + torch.distributed.fsdp._optim_utils._ConsolidatedOptimState = _ConsolidatedOptimState + torch.distributed.fsdp._optim_utils._PosDimTensorInfo = _PosDimTensorInfo + torch.distributed.fsdp._optim_utils._OptimStateKey = _OptimStateKey + torch.distributed.fsdp._optim_utils._unflatten_optim_state = _unflatten_optim_state + torch.distributed.fsdp._optim_utils._communicate_optim_state = _communicate_optim_state + torch.distributed.fsdp._optim_utils._unflatten_communicated_optim_state = _unflatten_communicated_optim_state + torch.distributed.fsdp._optim_utils._flatten_optim_state_dict = _flatten_optim_state_dict + torch.distributed.fsdp._optim_utils._flatten_optim_state = _flatten_optim_state + torch.distributed.fsdp._optim_utils._flatten_tensor_optim_state = _flatten_tensor_optim_state + torch.distributed.fsdp._optim_utils._flatten_zero_dim_tensor_optim_state = _flatten_zero_dim_tensor_optim_state + torch.distributed.fsdp._optim_utils._flatten_non_tensor_optim_state = _flatten_non_tensor_optim_state + torch.distributed.fsdp._optim_utils._process_pos_dim_tensor_state = _process_pos_dim_tensor_state + torch.distributed.fsdp._optim_utils._broadcast_processed_optim_state_dict = _broadcast_processed_optim_state_dict + torch.distributed.fsdp._optim_utils._broadcast_pos_dim_tensor_states = _broadcast_pos_dim_tensor_states + torch.distributed.fsdp._optim_utils._broadcast_sharded_pos_dim_tensor_state = _broadcast_sharded_pos_dim_tensor_state + torch.distributed.fsdp._optim_utils._broadcast_unsharded_pos_dim_tensor_state = _broadcast_unsharded_pos_dim_tensor_state + torch.distributed.fsdp._optim_utils._rekey_sharded_optim_state_dict = _rekey_sharded_optim_state_dict + torch.distributed.fsdp._optim_utils._get_flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module + torch.distributed.fsdp._optim_utils._get_param_id_to_param = 
_get_param_id_to_param + torch.distributed.fsdp._optim_utils._get_param_id_to_param_from_optim_input = _get_param_id_to_param_from_optim_input + torch.distributed.fsdp._optim_utils._get_param_to_param_id = _get_param_to_param_id + torch.distributed.fsdp._optim_utils._get_param_to_param_id_from_optim_input = _get_param_to_param_id_from_optim_input + torch.distributed.fsdp._optim_utils._get_unflat_to_flat_param_ids = _get_unflat_to_flat_param_ids + torch.distributed.fsdp._optim_utils._is_zero_dim_tensor = _is_zero_dim_tensor + torch.distributed.fsdp._optim_utils._optim_state_dict = _optim_state_dict diff --git a/torch_npu/distributed/fsdp/_shard_utils.py b/torch_npu/distributed/fsdp/_shard_utils.py new file mode 100644 index 0000000000..9372d4ed02 --- /dev/null +++ b/torch_npu/distributed/fsdp/_shard_utils.py @@ -0,0 +1,269 @@ +import bisect +import itertools +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch_npu +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed import distributed_c10d +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) +from torch.distributed._shard.sharding_spec import ( + ChunkShardingSpec, + EnumerableShardingSpec, + ShardingSpec, + ShardMetadata, +) + + +def _sharding_spec_to_offsets( + sharding_spec: ShardingSpec, tensor_numel: int, world_size: int +) -> List[int]: + r""" + Translates the sharding spec to a list of offsets along dim 0. If the + sharding spec is ChunkShardingSpec, only the ``dim`` is used and the + placement is not used. + """ + offsets: List[int] = [] + if isinstance(sharding_spec, EnumerableShardingSpec): + for shard in sharding_spec.shards: + offsets.append(shard.shard_offsets[0]) + elif isinstance(sharding_spec, ChunkShardingSpec): + assert sharding_spec.dim == 0 + chunk_size = math.ceil(tensor_numel / world_size) + if chunk_size == 1: + offsets = [ + rank if rank < tensor_numel else tensor_numel + for rank in range(world_size) + ] + else: + offsets = [chunk_size if rank > 0 else 0 for rank in range(world_size)] + offsets = list(itertools.accumulate(offsets)) + else: + raise ValueError(f"Un-recognized sharding spec type {type(sharding_spec)}.") + + return offsets + + +def _offsets_to_split_sizes( + input_offsets: List[int], + output_offsets: List[int], + tensor_numel: int, + world_size: int, + my_rank: int, +) -> Tuple[List[int], List[int]]: + r""" + Given the shard offsets for each rank of the input tensor and output tensor, + this API returns the corresponding split sizes that can be passed to + all_to_all_single(). 
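[Editorial aside] A hand-computed example of the translation described above, for two ranks resharding a 10-element flattened tensor from shard offsets [0, 5] to shard offsets [0, 3], shown from rank 0's perspective; the import path is the module added by this patch, and the expected values follow from the implementation below.

    from torch_npu.distributed.fsdp._shard_utils import _offsets_to_split_sizes

    # Rank 0 currently owns elements [0..4] and will own [0..2] after resharding.
    input_split_sizes, output_split_sizes = _offsets_to_split_sizes(
        input_offsets=[0, 5],    # current shard starts per rank
        output_offsets=[0, 3],   # target shard starts per rank
        tensor_numel=10,
        world_size=2,
        my_rank=0,
    )
    # input_split_sizes  == [3, 2]  -> send 3 elements to rank 0, 2 to rank 1
    # output_split_sizes == [3, 0]  -> receive all 3 new elements from rank 0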
+ """ + + def _get_interval(offsets): + if my_rank != world_size - 1: + return offsets[my_rank], offsets[my_rank + 1] - 1 + else: + return offsets[my_rank], tensor_numel - 1 + + def _offsets_to_sizes(offsets, begin, end): + sizes = [] + for i, offset in enumerate(offsets): + next_offset = offsets[i + 1] if i < len(offsets) - 1 else end + 1 + sizes.append( + (next_offset - offset) + - max(begin - offset, 0) + - max(next_offset - end - 1, 0) + ) + return sizes + + def _convert(from_offsets, to_offsets, split_sizes): + begin, end = _get_interval(from_offsets) + to_begin_rank = bisect.bisect(to_offsets, begin) - 1 + to_end_rank = bisect.bisect(to_offsets, end) - 1 + _split_sizes = _offsets_to_sizes( + to_offsets[to_begin_rank : to_end_rank + 1], begin, end + ) + split_sizes[to_begin_rank : to_end_rank + 1] = _split_sizes + + input_split_sizes = [0 for _ in range(world_size)] + output_split_sizes = [0 for _ in range(world_size)] + _convert(input_offsets, output_offsets, input_split_sizes) + _convert(output_offsets, input_offsets, output_split_sizes) + + return input_split_sizes, output_split_sizes + + +def _reshard_flatten_tensor( + input_tensor: ShardedTensor, + output_spec: ShardingSpec, + world_size: int, + my_rank: int, + device: torch.device, + process_group: Optional[dist.ProcessGroup], +) -> torch.Tensor: + """ + Resharded a sharded flatten tensor, this is used by FSDP to do sharded + state_dict. But the functionaility is not supported by ShardedTensor. + This API is designed to be used for FSDP; therefore this API supports only + 1-D ShardedTensor (hence the naming, reshard_flatten_tensor). + + This API uses the ChunkShardingSpec and EnumerableShardingSpec from + torch.distributed.sharding_spec but ignores the placement field in + ChunkShardingSpec, as the placement requires the callees understand the + number of GPUs per node. The API simply uses the semantics of the sharding + specs. + + Args: + input_tensor (ShardedTensor): the original ShardedTensor. Must be 1D. + output_spec (ShardingSpec): the sharding spect for the output tensor. + world_size (int): total trainer count. + my_rank (int): the rank for this trainer. + + Returns: + The local shard for the new ShardedTensor. 
+ """ + + input_spec = input_tensor.sharding_spec() + size = input_tensor.size() + if isinstance(size, int): + raise ValueError("The input tensor has no dimensions.") + tensor_numel = size.numel() + input_offsets = _sharding_spec_to_offsets(input_spec, tensor_numel, world_size) + output_offsets = _sharding_spec_to_offsets(output_spec, tensor_numel, world_size) + input_split_sizes, output_split_sizes = _offsets_to_split_sizes( + input_offsets, output_offsets, tensor_numel, world_size, my_rank + ) + output_size = sum(output_split_sizes) + local_shard = torch.empty(output_size, dtype=input_tensor.dtype, device=device) + dist.all_to_all_single( + local_shard, + input_tensor.local_shards()[0].tensor, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes, + group=process_group, + ) + return local_shard + + +def _all_gather_sharded_tensor( + sharded_tensor: ShardedTensor, pg: Optional[dist.ProcessGroup] = None +) -> torch.Tensor: + if pg is None: + pg = distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + shards = sharded_tensor.local_shards() + dim_0_size = sharded_tensor.size()[0] # type: ignore[index] + tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr] + chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size + npu_device = torch.device("npu", torch_npu.npu.current_device()) + if shards: + local_tensor = shards[0].tensor.flatten() + if not local_tensor.is_npu: + move_to_cpu = torch.ones(1, device=npu_device) + local_tensor = local_tensor.npu() + else: + move_to_cpu = torch.zeros(1, device=npu_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros( + chunk_size, dtype=sharded_tensor.dtype, device=npu_device + ) + move_to_cpu = torch.zeros(1, device=npu_device) + + tensor = torch.empty( + chunk_size * world_size, + dtype=local_tensor.dtype, + device=npu_device, + ) + dist._all_gather_base(tensor, local_tensor, group=pg) + + tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) + return tensor + + +def _gather_state_dict( + state_dict: Dict[str, Any], + pg: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + Given a state_dict, this API gathers all the ShardedTensors in the state_dict. + """ + new_state_dict = {} + for key, tensor in state_dict.items(): + if isinstance(tensor, ShardedTensor): + output_tensor = _all_gather_sharded_tensor(tensor, pg) + if tensor.local_shards() and tensor.local_shards()[0].tensor.is_npu: + tensor = output_tensor + else: + tensor = output_tensor.cpu() + new_state_dict[key] = tensor + return new_state_dict + + +def _create_chunk_sharded_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, +) -> ShardedTensor: + """ + Shard a tensor to chunks along the first dimension. The local rank will gets its + corresponding chunk as the local shard to create a ShardedTensor. + """ + chunks = tensor.chunk(world_size, dim=0) + if len(chunks) > rank: + local_shard = chunks[rank].clone() + offsets = [0 for _ in tensor.size()] + offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank + local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)] + else: + local_shards = [] + + # Create a ShardedTensor without invoking communication. 
+ chunk_sizes = [list(chunk.size()) for chunk in chunks] + dim0_offsets = [0] + list( + itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes]) + )[:-1] + offsets = [0] * (len(chunk_sizes[0]) - 1) + chunk_offsets = [[d0] + offsets for d0 in dim0_offsets] + placements = [ + f"rank:{r}/npu:{r % num_devices_per_node}" for r in range(len(chunk_sizes)) + ] + assert len(chunk_sizes) == len(chunk_offsets) == len(placements) + shard_metadata = [ + ShardMetadata(offset, size, placement) + for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements) + ] + sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=shard_metadata, + size=tensor.size(), + tensor_properties=TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, + sharded_tensor_metadata=sharded_tensor_metadata, + process_group=pg + ) + + +def apply_fsdp_shard_utils(): + torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor = _create_chunk_sharded_tensor + torch.distributed.fsdp._shard_utils._gather_state_dict = _gather_state_dict + torch.distributed.fsdp._shard_utils._all_gather_sharded_tensor = _all_gather_sharded_tensor + torch.distributed.fsdp._shard_utils._reshard_flatten_tensor = _reshard_flatten_tensor + torch.distributed.fsdp._shard_utils._offsets_to_split_sizes = _offsets_to_split_sizes + torch.distributed.fsdp._shard_utils._sharding_spec_to_offsets = _sharding_spec_to_offsets diff --git a/torch_npu/distributed/fsdp/_symbolic_trace.py b/torch_npu/distributed/fsdp/_symbolic_trace.py new file mode 100644 index 0000000000..026595fd7d --- /dev/null +++ b/torch_npu/distributed/fsdp/_symbolic_trace.py @@ -0,0 +1,243 @@ +import contextlib +import functools +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import torch + + +__all__ = ["TracingConfig"] + + +@dataclass +class TracingConfig: + """ + Configurations used in ``ParamExecOrderWrapPolicy`` for symbolic tracing of + a model. + + Args: + tracer (torch.fx.Tracer): An instance of ``torch.fx.Tracer`` that will + be used to perform symbolic tracing. ``tracer`` is default to be + ``torch.fx.Tracer()``, but can also be instance of some child class + of ``torch.fx.Tracer``. For example, one may want to use + ``HFTracer`` for models in Transformers: .. _Transformers: + https://huggingface.co/docs/transformers/index + concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should + not be treated as ``torch.fx.Proxy`` when tracing the forward + function. ``concrete_args`` allows one to partially specialize the + forward function, including removing control flow or data + structures. ``concrete_args`` is also the argument used in + :meth:`~torch.fx.Tracer.trace`. + """ + + tracer: torch.fx.Tracer = torch.fx.Tracer() + concrete_args: Optional[Dict[str, Any]] = None + + +@dataclass +class _ExecutionInfo: + """ + Contains the execution order information in the model forward pass. + + Attributes: + current_module: record the module that is currently being traced. + + module_forward_order: a list of modules, where the ordering is based on + when their forward function is called. ``module_forward_order`` + includes the info of how many times a module is called + used to + check the forward order in different iterations. 
+ + param_exec_order: a list of parameters ordered based on their execution + order. + + module_to_execution_infos: a dict that maps each module to a list of + tuples each containing a module and a list of named parameters. + ``module_execution_info_dict`` is used as the parameter execution + order info. For a given module, each tuple: 1. either contains this + module and part of its ``named_parameters`` that will be executed + together, 2. or contains one of its child modules and all of the + child module's ``named_parameters``. The list of tuples is ordered + based on the parameter execution order. + """ + + current_module: torch.nn.Module + module_forward_order: List[torch.nn.Module] + module_to_execution_infos: Dict[ + torch.nn.Module, + List[Tuple[torch.nn.Module, List[Tuple[str, torch.nn.Parameter]]]], + ] + param_exec_order: List[torch.nn.Parameter] = field(default_factory=list) + + +def _init_execution_info(root_module: torch.nn.Module) -> _ExecutionInfo: + """ + Create an instance of _ExecutionInfo with initialization based on + ``root_module``. + + Args: + root_module (torch.nn.Module): the module to get the execution + information via ``tracer.trace()`` inside ``_patch_tracer``. + """ + return _ExecutionInfo( + current_module=root_module, + module_forward_order=[root_module], + module_to_execution_infos={root_module: []}, + ) + + +def _patched_create_proxy( + create_proxy: Callable, + execution_info: _ExecutionInfo, + prefixed_param_name_to_param: Dict[str, torch.nn.Parameter], + kind: str, + target: torch.fx.node.Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None, +) -> torch.fx.Proxy: + """ + Override of :meth:`~torch.fx.Tracer.create_proxy`. ``Tracer.create_proxy`` + is called in symbolic tracing for each leaf function/method/module. This + override intercepts the recording of each of these operations to update + ``execution_info.module_to_execution_infos``. + + Args: + create_proxy (Callable): + The ``create_proxy`` function to be patched. + execution_info (_ExecutionInfo): + Used to record the execution information. + prefixed_param_name_to_param (Dict[str, torch.nn.Parameter]): + A dict that maps each prefixed parameter name to the parameter. + kind (str): + The type of the target method. One of 'call_function', + 'call_method', 'get_attr', 'call_module', 'placeholder', or + 'output'. The semantics of these opcodes are described in the + ``torch.fx.Graph`` docstring. This is the input to ``create_proxy``. + target (torch.fx.node.Target): + Contains the string name of the method. This is the input to + ``create_proxy``. + args (Tuple[Any, ...]): + Arguments of the method. This is the input to ``create_proxy``. + kwargs (Dict[str, Any]): + Keyword arguments of the method. This is the input to + ``create_proxy``. + name (Optional[str]): + An optional string name for the ``Node`` created in + ``create_proxy``. This is the input to ``create_proxy``. + type_expr (Optional[Any]): + An optional type annotation representing the Python type the output + of a node will have. This is the input to ``create_proxy``. + proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]): + An alternative proxy constructor used in ``create_proxy``. This is + the input to ``create_proxy``. 
+ """ + proxy = create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + + module = execution_info.current_module + if kind in ["call_function", "call_method"]: + if args is not None: + named_params: List[Tuple[str, torch.nn.Parameter]] = [] + for arg in args: + if isinstance(arg, torch.fx.Proxy) and arg.node.target in prefixed_param_name_to_param: + param = prefixed_param_name_to_param[arg.node.target] + named_params.append((arg.node.target, param)) + if param not in set(execution_info.param_exec_order): + execution_info.param_exec_order.append(param) + if named_params: + execution_info.module_to_execution_infos[module].append((module, named_params)) + elif kind == "call_module": + named_params = list(module.named_parameters()) + if named_params: + execution_info.module_to_execution_infos[module].append( + (module, named_params) + ) + for (_, p) in named_params: + if p not in set(execution_info.param_exec_order): + execution_info.param_exec_order.append(p) + return proxy + + +def _patched_call_module( + call_module: Callable, + execution_info: _ExecutionInfo, + module: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + """ + Override of :meth:`~torch.fx.Tracer.call_module`. ``Tracer.call_module`` is + called in symbolic tracing for each non-root module. This override + intercepts the recording of each operation to update + ``execution_info.module_forward_order`` and + ``execution_info.module_to_execution_infos``. + + Args: + call_module (Callable): + The ``call_module`` function to be patched. + execution_info (_ExecutionInfo): + Used to repord the execution information. + module (torch.nn.Module): + The module for which a call is being emitted. + forward (Callable[..., Any]): + The ``forward()`` method of the ``torch.nn.Module`` to be invoked. + args (Tuple[Any, ...]): + ``args`` of the module callsite. + kwargs (Dict[str, Any]): + ``kwargs`` of the module callsite. + """ + execution_info.module_forward_order.append(module) + named_params = list(module.named_parameters()) + if named_params: + execution_info.module_to_execution_infos[execution_info.current_module].append( + (module, list(module.named_parameters())) + ) + # Stores away current_module for restoration later + prev_current_module = execution_info.current_module + execution_info.current_module = module + # Note that if the forward of module is called multiple times, this will record + # the execution info of the last forward pass. + execution_info.module_to_execution_infos[module] = [] + output = call_module(module, forward, args, kwargs) + execution_info.current_module = prev_current_module + return output + + +@contextlib.contextmanager +def _patch_tracer( + tracer: torch.fx.Tracer, + root_module: torch.nn.Module, + execution_info: _ExecutionInfo, +) -> Generator: + """ + Within the context manager, patches the input tracer so that during + ``tracer.trace()``, the forward order of all modules and the parameter + execution information are recorded. The patches of the input tracer will be + removed after the context manager exits. + + Args: + tracer (torch.fx.Tracer): the input ``tracer`` whose member functions + will be patched within the context manager. + root_module (torch.nn.Module): the top-level module to be traced + and should not contain any FSDP modules. + execution_info (_ExecutionInfo): used to record the execution order + information when performing ``tracer.trace()`` within the context + manager. 
+ """ + original_call_module = tracer.call_module + original_create_proxy = tracer.create_proxy + + tracer.call_module = functools.partial( + _patched_call_module, original_call_module, execution_info + ) + prefixed_param_name_to_param = dict(root_module.named_parameters()) + tracer.create_proxy = functools.partial( + _patched_create_proxy, original_create_proxy, execution_info, prefixed_param_name_to_param + ) + try: + yield + finally: + tracer.call_module = original_call_module + tracer.create_proxy = original_create_proxy diff --git a/torch_npu/distributed/fsdp/_utils.py b/torch_npu/distributed/fsdp/_utils.py new file mode 100644 index 0000000000..80688e5dec --- /dev/null +++ b/torch_npu/distributed/fsdp/_utils.py @@ -0,0 +1,149 @@ +import dataclasses +import traceback +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Set, Tuple, Union + +import torch +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined] + _is_namedtuple, +) +from torch.nn.utils.rnn import PackedSequence + + +FSDP_FLATTENED = "_fsdp_flattened" + + +def _contains_batchnorm(module): + return any( + isinstance(mod, _BatchNorm) for mod in module.modules() + ) + + +def _override_batchnorm_mixed_precision(module): + for mod in module.modules(): + if isinstance(mod, _BatchNorm): + mod._wrap_overrides = {"mixed_precision": None} # type: ignore[assignment] + + +def _apply_to_tensors( + fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence] +) -> Any: + """Recursively apply to all tensor in different kinds of container types.""" + + def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]) -> Any: + if torch.is_tensor(x): + return fn(x) + elif hasattr(x, "__dataclass_fields__"): + dc = dataclasses.replace(x) + for f in dataclasses.fields(dc): + name = f.name + setattr(dc, name, apply(getattr(dc, name))) + return dc + elif isinstance(x, OrderedDict): + od = x.__class__() + for key, value in x.items(): + od[key] = apply(value) + return od + elif isinstance(x, PackedSequence): + apply(x.data) + return x + elif isinstance(x, dict): + return {key: apply(value) for key, value in x.items()} + elif _is_namedtuple(x): + res = (apply(el) for el in x) + return type(x)(*res) + elif isinstance(x, (list, tuple, set)): + return type(x)(apply(el) for el in x) + else: + return x + + return apply(container) + + +def _apply_to_modules( + root_module: torch.nn.Module, + module_fn: Callable, + return_fn: Callable, + *args, + **kwargs, +): + """ + Performs a pre-order traversal of the modules in the hierarchy rooted at + ``root_module``, applying ``module_fn`` at each module and finally + returning a value using ``return_fn``. The traversal constructs the full + module prefix name (e.g. "module.submodule." just like in model state dict) + and makes that available to ``module_fn``. + """ + def f(module: torch.nn.Module, prefix: str, *args, **kwargs): + # Call the module function before recursing over children (pre-order) + module_fn(module, prefix, *args, **kwargs) + for submodule_name, submodule in module.named_children(): + if submodule is not None: + new_prefix = prefix + submodule_name + "." + f(submodule, new_prefix, *args, **kwargs) + + f(root_module, "", *args, **kwargs) + return return_fn(*args, **kwargs) + + +@torch.no_grad() +def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool: + """ + Allocate storage for ``tensor`` with the given size. 
+ + Returns: + bool: ``True`` if this method allocated storage and ``False`` if the + storage was already allocated. + """ + already_allocated = tensor.storage().size() == size.numel() + if not already_allocated: + tensor_storage_size = tensor.storage().size() + p_assert( + tensor_storage_size == 0, + f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}", + ) + tensor.storage().resize_(size.numel()) + return not already_allocated + + +@torch.no_grad() +def _free_storage(tensor: torch.Tensor) -> bool: + """ + Frees the underlying storage of ``tensor``. + + Returns: + bool: ``True`` if the method freed the storage and ``False`` if the + storage was already freed. + """ + already_freed = tensor.storage().size() == 0 + if not already_freed: + p_assert( + tensor.storage_offset() == 0, + "Freeing a tensor's storage is unsafe when it is not the sole occupant", + ) + tensor.storage().resize_(0) + return not already_freed + + +def _set_fsdp_flattened(tensor: torch.Tensor) -> None: + """ + Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to + avoid re-flattening it during nested construction. + """ + setattr(tensor, FSDP_FLATTENED, True) + + +def _is_fsdp_flattened(tensor: torch.Tensor) -> bool: + """Returns if ``tensor`` has been marked as flattened by FSDP.""" + return getattr(tensor, FSDP_FLATTENED, False) + + +def p_assert(cond: Any, s: Any, raise_assertion_error: bool = True) -> None: + """This is used as an alternate to ``assert`` when in the backward context + to print the error message ``s`` since otherwise, it is swallowed.""" + if not cond: + print(s) + traceback.print_stack() + if raise_assertion_error: + raise AssertionError diff --git a/torch_npu/distributed/fsdp/flat_param.py b/torch_npu/distributed/fsdp/flat_param.py new file mode 100644 index 0000000000..fcdc9c0248 --- /dev/null +++ b/torch_npu/distributed/fsdp/flat_param.py @@ -0,0 +1,1133 @@ +import contextlib +from dataclasses import dataclass +from enum import auto, Enum +from itertools import accumulate, chain +from typing import ( + Any, + cast, + Dict, + Generator, + Iterator, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +import torch_npu +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform +from ._utils import _alloc_storage, _free_storage, _set_fsdp_flattened, p_assert + +__all__ = [ + "FlatParameter", + "FlatParamHandle", + "FlatParamShardMetadata", + "ParamInfo", + "SharedParamInfo", + "HandleConfig", + "HandleShardingStrategy", + "HandleTrainingState", +] + + +class ParamInfo(NamedTuple): + """Information for an original module parameter.""" + + param_name: str # unprefixed + module: nn.Module + module_name: str + + +class SharedParamInfo(NamedTuple): + """ + Additional information for a shared parameter. + + For each shared parameter, we designate one module and its parameter + variable to be the primary owner, determined as the first one encountered + in the parameter walk. These are prefixed with "prim". The primary module + and parameter do not have their own :class:`SharedParamInfo` instance. 
+ """ + + param_name: str # unprefixed + module: nn.Module + module_name: str + prim_param_name: str # unprefixed + prim_module: nn.Module + prim_module_name: str + + +class FlatParamShardMetadata(NamedTuple): + """ + This holds metadata specific to this rank's shard of the flattened + parameter. + + Attributes: + param_names (Tuple[str, ...]): Prefixed parameter names of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_numels (Tuple[int, ...]): Parameter numels of this rank's shard + of the parameters; see :class:`FlatParameter`. + param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in + units of numels) giving this rank's part of each flattened + original module parameter. + """ + + param_names: Tuple[str, ...] + param_shapes: Tuple[torch.Size, ...] + param_numels: Tuple[int, ...] + param_offsets: Tuple[Tuple[int, int], ...] + + +# TODO (awgu): Prefix these with "Handle" for now to avoid circular imports and +# inadvertent misuses; coalesce with those in fully_sharded_data_parallel.py +# later +class HandleShardingStrategy(Enum): + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + + +class HandleTrainingState(Enum): + IDLE = auto() + FORWARD = auto() + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + SUMMON_FULL_PARAMS = auto() + + +@dataclass +class HandleConfig: + sharding_strategy: HandleShardingStrategy + offload_params: bool + param_dtype: Optional[torch.dtype] + reduce_dtype: Optional[torch.dtype] + keep_low_precision_grads: Optional[bool] = False + + +class FlatParameter(nn.Parameter): + """ + This is the flattened parameter used by :class:`FullyShardedDataParallel`. + It is comprised of one or more original parameters, which are flattened + and concatenated to construct the flattened parameter. + + Under the current design, this parameter logically represents both the + unsharded and sharded flattened parameter, and its data changes storages + dynamically. + - In the :class:`FullyShardedDataParallel` constructor, the parameter + is initialized as unsharded and then sharded in-place. + - At runtime, the parameter is lazily (re)-initialized. The sharded + parameter data is saved in ``self._local_shard``, and a new ``Tensor`` + ``self._full_param_padded`` is created, which is the all-gather + destination and owns the unsharded parameter storage thereafter. (See + :meth:`FullyShardedDataParallel._init_param_attributes`.) + - Throughout runtime, the parameter data changes storages as needed, + e.g. to the sharded flattened parameter, reduced-precision sharded + flattened parameter, or the unsharded flattened parameter. + + Attributes: + _unpadded_unsharded_size (torch.Size): Unsharded flattened parameter's + size without padding. + _padded_unsharded_size (torch.Size): Unsharded flattened parameter's + size with padding. This is only set for sharded strategies since + they require padding for the all-gather. + + _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info + entry; see :class:`ParamInfo`. + _numels (Tuple[int, ...]): Each parameter's numel. + _shapes (Tuple[torch.Size, ...]): Each parameter's shape. 
+ _prefixed_param_names (Tuple[str, ...]): Each parameter's name prefixed + with the parent module names starting from the module passed to + construct this flattened parameter via :class:`FlatParamHandle`; + the prefixed names are guaranteed to be unique within the subtree + rooted in that module. + _num_params (int): Number of original parameters flattened into this + flattened parameter; this is the length of ``_param_infos``, + ``_numels``, ``_shapes``, and ``_prefixed_param_names``. + _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter + info entries; see :class:`SharedParamInfo`. + _param_extensions (Tuple[Optional[Any], ...]): Parameter extensions + (i.e. some per-parameter state) used to customize pre-flatten and + post-unflatten behavior. This is experimental, and users should not + depend on its existence in the future. + + _shard_param_offsets (List[Tuple[int, int])): [start, end] offsets (in + units of numel) giving this rank's part of each flattened original + module parameter; for any parameter ``p`` that is not sharded + across ranks, this will be [0, ``p.numel()``-1]. + _shard_indices (Tuple[int, int]): [start, end] indices (in units of + parameters) for this rank's shard of the original model parameters, + where the parameters follow the order in which they were originally + flattened; this indexes appropriately into any data structure that + follows the flattening order (e.g. ``_param_infos``, ``_numels``, + etc.). + _shard_numel_padded (int): Numel padded for this rank's sharded + flattened parameter. + + _local_shard (Tensor): Sharded flattened parameter with padding if + using a sharded strategy. If using ``NO_SHARD``, then this is the + unpadded unsharded flattened parameter, and there is no notion of a + sharded flattened parameter or padded unsharded flattened + parameter. + _full_param_padded (Tensor): Unsharded flattened parameter with + padding. This is not defined for ``NO_SHARD``. When using mixed + precision for parameters, this has the low precision. + _full_prec_full_param_padded (Tensor): Full precision unsharded + flattened parameter with padding. This is used for unsharding + outside of computation when using mixed precision for parameters. + This is never defined for ``NO_SHARD``. + _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]): + Flattened parameter's :class:`AccumulateGrad` object and + post-backward hook handle. + _mp_shard (Tensor): Low precision sharded flattened parameter with + padding. This is only defined when parameter mixed precision is + enabled. For ``NO_SHARD``, this is used for computation. + _cpu_grad (Tensor): Sharded gradient with padding stored on CPU. + This is only defined when offloading parameters is enabled. + _saved_grad_shard (Tensor): Sharded gradient with padding from previous + iterations for gradient accumulation without :meth:`no_sync`. + """ + + def _init_metadata( + self, + param_infos: List[ParamInfo], + numels: List[int], + shapes: List[torch.Size], + prefixed_param_names: List[str], + shared_param_infos: List[SharedParamInfo], + param_extensions: List[Any], + ) -> None: + """ + Initializes attributes holding metadata about the original parameters + comprising the flattened parameter. + + We expose this method separate from the constructor to keep the + constructor only responsible for the flattened parameter's tensor data. + This method should only be called once per model, while the constructor + may be called multiple times, e.g. 
when reloading from a checkpoint, in + which case only the tensor data needs to be passed to the constructor. + Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the + metadata is correctly assumed to be unchanged. + + Args: + See the Attributes in the class docstring. + """ + assert len(param_infos) == len(numels) + assert len(param_infos) == len(shapes) + assert len(param_infos) == len(prefixed_param_names) + assert len(param_infos) == len(param_extensions) + self._num_params = len(param_infos) + self._param_infos = tuple(param_infos) + self._numels = tuple(numels) + self._shapes = tuple(shapes) + self._prefixed_param_names = tuple(prefixed_param_names) + self._shared_param_infos = tuple(shared_param_infos) + self._param_extensions = tuple(param_extensions) + self._unpadded_unsharded_size = self.size() + _set_fsdp_flattened(self) + + +class FlatParamHandle: + """ + This handle manages a flattened parameter (:class:`FlatParameter`). This + includes sharding and view management. + + Args: + params (Sequence[nn.Parameter]): The parameters to use for the + flattened parameter. + module (nn.Module): A module that is the root of the subtree containing + all parameters in ``params``; for non-recursive wrapping, this must + be the top-level module, while for recursive wrapping, this may not + necessarily be the top-level module. + device (torch.device): The compute and communication device, which + should be a non-CPU device. We refer to it as the compute device. + config (HandleConfig): A config customizing the handle based on FSDP's + available features. + """ + + ################## + # INITIALIZATION # + ################## + def __init__( + self, + params: Sequence[nn.Parameter], + module: nn.Module, + device: torch.device, + config: HandleConfig, + ) -> None: + super().__init__() + self.device = device + self._config = config + self._training_state = HandleTrainingState.IDLE + self._init_flat_param(params, module) + self._unflatten(as_params=False) + + def _init_flat_param( + self, + params: Sequence[Optional[nn.Parameter]], + module: nn.Module, + ) -> None: + """ + Initializes the flattened parameter ``self.flat_param`` by flattening + the parameters in ``params`` into a single :class:`FlatParameter` and + saves relevant metadata. Shared parameters are only included in the + flattened parameter once. + + This checks that all comprising parameters have the same dtype and + ``requires_grad`` and does not support nested construction of + :class:`FlatParameter` s. + + Args: + See the Args in the class docstring. 
+ """ + params_set = set(params) + params_set.discard(None) + assert ( + len(params_set) > 0 + ), "Cannot initialize a `FlatParameter` from an empty parameter list" + param_infos: List[ParamInfo] = [] + numels: List[int] = [] + shapes: List[torch.Size] = [] + prefixed_param_names: List[str] = [] + shared_param_infos: List[SharedParamInfo] = [] + shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str, str]] = {} + params_to_flatten: List[nn.Parameter] = [] + param_extensions: List[Any] = [] + dtype: Optional[torch.dtype] = None + requires_grad: Optional[bool] = None + for submodule_name, submodule in module.named_modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if param not in params_set: + continue + if param in shared_param_memo: + prim_module, prim_module_name, prim_param_name = shared_param_memo[ + param + ] + shared_param_infos.append( + SharedParamInfo( + param_name, + submodule, + submodule_name, + prim_param_name, + prim_module, + prim_module_name, + ) + ) + else: + if type(param) is FlatParameter: + raise ValueError("`FlatParameter` does not support nesting") + if dtype is not None and param.dtype != dtype: + raise ValueError( + "`FlatParameter` requires uniform dtype but got " + f"{dtype} and {param.dtype}" + ) + if dtype is None and not param.is_floating_point(): + raise ValueError("Integer parameters are unsupported") + if ( + requires_grad is not None + and param.requires_grad != requires_grad + ): + raise ValueError( + "`FlatParameter` requires uniform `requires_grad`" + ) + param, extension = _ext_pre_flatten_transform(param) + param_extensions.append(extension) + dtype = param.dtype + requires_grad = param.requires_grad + shared_param_memo[param] = (submodule, submodule_name, param_name) + params_to_flatten.append(param) + param_infos.append(ParamInfo(param_name, submodule, submodule_name)) + numels.append(param.numel()) + shapes.append(param.shape) + prefixed_param_name = ( + submodule_name + "." + param_name + if submodule_name + else param_name + ) + prefixed_param_names.append(prefixed_param_name) + assert requires_grad is not None + self.flat_param = FlatParamHandle.flatten_params( + params_to_flatten, requires_grad + ) + self.flat_param._init_metadata( + param_infos, + numels, + shapes, + prefixed_param_names, + shared_param_infos, + param_extensions, + ) + + @staticmethod + def flatten_params( + params: Sequence[torch.Tensor], + requires_grad: bool, + ) -> FlatParameter: + """ + Flattens the parameters in ``params`` into a single + :class:`FlatParameter`. This should be the only way used to construct + :class:`FlatParameter` s. + + We expose this factory method for checkpointing (e.g. sharded state + dict). The flattened parameter's metadata should only be initialized + once (see :meth:`_init_metadata`), but its tensor data may be reloaded. + """ + with torch.no_grad(): + flat_params = [ + p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1) + for p in params + ] + flat_param_data = torch.cat(flat_params, dim=0) + flat_param = FlatParameter(flat_param_data, requires_grad=requires_grad) + return flat_param + + ################################### + # SHARD INITIALIZATION & METADATA # + ################################### + @torch.no_grad() + def shard(self, process_group: dist.ProcessGroup): + """ + Shards the handle's ``FlatParameter``. In terms of memory, this + allocates new memory for the sharded flattened parameter and frees the + unsharded flattened parameter's storage. 
+ + Postcondition: ``self.flat_param`` is the sharded flattened parameter. + ``process_group``, ``rank``, and ``world_size`` attributes are set. + + TODO (awgu): Once we retire ``FlattenParamsWrapper``, we should pass + the process group directly to the ``FlatParamHandle`` constructor. For + now, we decouple ``FlattenParamsWrapper` from a process group, but this + makes the process-group-related attributes not necessarily defined. + """ + if not self.uses_sharded_strategy: + return + flat_param = self.flat_param + self.process_group = process_group + self.rank = process_group.rank() + self.world_size = process_group.size() + assert ( + flat_param.storage_offset() == 0 + ), "The `FlatParameter` is not the sole occupant of its storage" + orig_storage = flat_param.storage() + local_shard, numel_padded = FlatParamHandle._get_shard( + flat_param, self.rank, self.world_size + ) + + if flat_param.device == torch.device("cpu"): + flat_param.set_(local_shard) # type: ignore[call-overload] + if orig_storage.size() > 0: + orig_storage.resize_(0) + else: + if orig_storage.size() > 0: + orig_storage.resize_(0) + flat_param.set_(local_shard) # type: ignore[call-overload] + + self._init_shard_metadata(local_shard.numel(), numel_padded, self.rank) + + + def _init_shard_metadata( + self, + sharded_flat_param_numel: int, + numel_padded: int, + rank: int, + ) -> None: + """ + Initializes shard-related metadata for this rank's shard of the + flattened parameter: ``_shard_param_offsets``, ``_shard_indices``, and + ``_shard_numel_padded``. + + Args: + sharded_flat_param_numel (int): Numel of each rank's sharded + flattened parameter with padding (i.e. including + ``numel_padded``). + numel_padded (int): Numel padded for this rank's sharded flattened + parameter. + rank (int): Caller's rank. + """ + if numel_padded > sharded_flat_param_numel: + raise ValueError( + f"Sharded flattened parameter with {sharded_flat_param_numel} " + f"numel cannot have {numel_padded} numel padded" + ) + start = sharded_flat_param_numel * rank + end = sharded_flat_param_numel * (rank + 1) - 1 # inclusive + ( + self.flat_param._shard_param_offsets, # type: ignore[attr-defined] + self.flat_param._shard_indices, # type: ignore[attr-defined] + ) = self._get_shard_metadata(start, end) + self.flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] + + def _get_shard_metadata( + self, + start: int, + end: int, + ) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: + """ + Computes the shard metadata based on ``start`` and ``end``, which give + the closed interval of the unsharded flattened parameter specifying the + shard. + + Args: + start (int): Start index (in units of numel) of this rank's shard + of the flattened parameter. + end (int): End index (in units of numel and inclusive) of this + rank's shard of the flattened parameter. + + Return: + Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: See + ``_shard_param_offsets`` and ``_shard_indices`` in + :class:`FlatParameter` 's docstring. 
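
# A minimal illustrative sketch of the interval intersection that
# `_get_shard_metadata` performs: given each original parameter's inclusive
# [start, end] span in the unsharded flat parameter and this rank's inclusive
# [start, end] slice, compute which parameters the shard touches and the
# intra-parameter offsets. Plain ints only; no FlatParameter involved.
def sketch_shard_metadata(param_offsets, start, end):
    indices, offsets = [], []
    for i, (p_start, p_end) in enumerate(param_offsets):
        if start > p_end or end < p_start:
            continue  # this parameter lies entirely outside the shard
        intra_start = max(start - p_start, 0)
        intra_end = min(p_end, end) - p_start
        indices.append(i)
        offsets.append((intra_start, intra_end))  # both inclusive
    shard_indices = (indices[0], indices[-1]) if indices else (0, 0)
    return tuple(offsets), shard_indices

# Two parameters of numel 4 and 6 occupy [0, 3] and [4, 9]; the rank owning
# flat elements [5, 9] touches only the second parameter at local offsets (1, 5).
assert sketch_shard_metadata([(0, 3), (4, 9)], 5, 9) == (((1, 5),), (1, 1))
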
+ """ + flat_param_offsets = self._get_flat_param_offsets() + # Indices of the original parameters in this rank's sharded flattened + # parameter + shard_param_indices_range = [] # elements will be consecutive + # [start, end] offsets giving this rank's part of the flattened + # original module parameter (which will be [0, `p.numel()`-1] for any + # parameter that is not sharded across ranks) + shard_param_offsets = [] + for i, (param_start, param_end) in enumerate(flat_param_offsets): + if start > param_end or end < param_start: + continue + if start <= param_start: + intra_param_start = 0 + else: + intra_param_start = start - param_start + intra_param_end = min(param_end, end) - param_start + shard_param_indices_range.append(i) + shard_param_offsets.append( + (intra_param_start, intra_param_end) + ) # both inclusive + if len(shard_param_indices_range) == 0: + shard_param_indices = (0, 0) + assert len(shard_param_offsets) == 0 + else: + shard_param_indices = ( + shard_param_indices_range[0], + shard_param_indices_range[-1], + ) + assert ( + len(shard_param_offsets) + == shard_param_indices[-1] - shard_param_indices[0] + 1 + ) + return tuple(shard_param_offsets), shard_param_indices + + @staticmethod + def _get_unpadded_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> Tuple[Tensor, int]: + """ + Returns the shard of ``tensor`` without any padding for the given + ``rank`` and ``world_size`` and the numel to pad for that shard. + + If ``tensor`` is already flattened or may be viewed in the flattened + shape (which is true in the expected usage), then this method does not + allocate any new tensor memory. + """ + chunks = torch.flatten(tensor).chunk(world_size) + if len(chunks) < (rank + 1): + # This rank gets an empty chunk fully padded with zeros since there + # are not enough chunks across ranks + chunk = chunks[0].new_empty(0) + else: + chunk = chunks[rank] + numel_to_pad = chunks[0].numel() - chunk.numel() + assert ( + numel_to_pad >= 0 + ), "Chunk's size should be at most the first chunk's size" + return chunk, numel_to_pad + + @staticmethod + def _get_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> Tuple[Tensor, int]: + """ + Returns the shard of ``tensor`` with padding for the given ``rank`` and + ``world_size`` and the numel padded for that shard. + + This method allocates new memory (via :meth:`clone`) since the + unsharded ``tensor`` may be deallocated after this method returns. + """ + chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + shard = chunk.clone() + if numel_to_pad > 0: + shard = F.pad(shard, [0, numel_to_pad]) + return shard, numel_to_pad + + @staticmethod + def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size: + """ + Returns the shape of ``tensor`` after sharding including padding. This + requires ``tensor`` to have 1D shape and ensures that the returned + shape is 1D. 
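
# A minimal illustrative sketch of the chunk-and-pad behavior described above
# for `_get_unpadded_shard`/`_get_shard`: every rank ends up with a shard of
# the first chunk's numel, zero-padded when its own chunk is short or missing.
import torch
import torch.nn.functional as F

def sketch_get_shard(flat, rank, world_size):
    chunks = torch.flatten(flat).chunk(world_size)
    chunk = chunks[rank] if rank < len(chunks) else chunks[0].new_empty(0)
    numel_to_pad = chunks[0].numel() - chunk.numel()
    shard = chunk.clone()  # new memory, so the unsharded tensor can be freed
    if numel_to_pad > 0:
        shard = F.pad(shard, [0, numel_to_pad])
    return shard, numel_to_pad

flat = torch.arange(5.0)                      # numel 5 split across 2 ranks
shard0, pad0 = sketch_get_shard(flat, 0, 2)   # tensor([0., 1., 2.]), 0 padded
shard1, pad1 = sketch_get_shard(flat, 1, 2)   # tensor([3., 4., 0.]), 1 padded
assert shard0.numel() == shard1.numel() == 3 and (pad0, pad1) == (0, 1)
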
+ """ + assert len(tensor.shape) == 1, f"{tensor.shape}" + unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + unpadded_sharded_size = unpadded_sharded_tensor.size() + assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}" + return torch.Size([unpadded_sharded_size[0] + numel_to_pad]) + + def _get_flat_param_offsets(self) -> List[Tuple[int, int]]: + """Returns [start, end] offsets of each original parameter's flattened + data in the unsharded flattened parameter (without padding).""" + cumulative_sum = list(accumulate(self.flat_param._numels)) + starts = [0] + cumulative_sum[:-1] + ends = [end - 1 for end in cumulative_sum] # inclusive + param_offsets = list(zip(starts, ends)) + return param_offsets + + def shard_metadata( + self, + ) -> FlatParamShardMetadata: + """Returns shard-related metadata specific to this rank's shard of the + flattened parameter.""" + assert hasattr(self.flat_param, "_shard_indices") and hasattr( + self.flat_param, "_shard_param_offsets" + ), "Shard metadata has not been initialized" + shard_param_start_index = self.flat_param._shard_indices[0] # type: ignore[attr-defined] + shard_param_end_index = self.flat_param._shard_indices[1] # type: ignore[attr-defined] + sl = ( + slice(shard_param_start_index, shard_param_end_index + 1) + if shard_param_start_index <= shard_param_end_index + else slice(0, 0) + ) + return FlatParamShardMetadata( + self.flat_param._prefixed_param_names[sl], + self.flat_param._shapes[sl], + self.flat_param._numels[sl], + self.flat_param._shard_param_offsets[:], # type: ignore[attr-defined] + ) + + ################### + # UNSHARD/RESHARD # + ################### + def pre_unshard(self) -> bool: + """ + Returns: ``False`` if this is a no-op and ``True`` otherwise. + + Postcondition: ``self.flat_param`` 's data is on the device for + communication and is what should be all-gathered. This means that it + matches the dtype of the expected unsharded parameter. + """ + ret = False + if ( + self.uses_sharded_strategy + and not self._config.offload_params + and not self.needs_unshard() + ): + pass # no-op + elif self._uses_param_mixed_precision and not self._force_full_precision: + self._use_low_precision_shard() + ret = True + elif self._config.offload_params and self.flat_param.device != self.device: + # NOTE: This creates a new tensor distinct from any attributes. + self._flat_param_to(self.device, non_blocking=True) + ret = True + self._check_on_compute_device(self.flat_param) + return ret + + def _use_low_precision_shard(self): + """ + Allocates the low precision shard directly on the compute device and + switches to using the low precision sharded flattened parameter. + """ + self._check_low_precision_shard() + flat_param = self.flat_param + _alloc_storage( + flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined] + ) + # `copy_()` implicitly casts to the low precision + flat_param._mp_shard.copy_( # type: ignore[attr-defined] + flat_param._local_shard.to( # type: ignore[attr-defined] + self.device, non_blocking=True + ) + ) + # Invariant: `_mp_shard` is always on the compute device. + flat_param.data = flat_param._mp_shard # type: ignore[attr-defined] + + def unshard(self): + """ + Runs the unshard logic. This includes all-gathering the flattened + parameter and switching to using the unsharded flattened parameter. If + the handle does not need unsharding, then this only switches to using + the unsharded flattened parameter. 
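
# A small worked example of `_get_flat_param_offsets` above: the inclusive
# [start, end] span of each original parameter follows from a running sum of
# the per-parameter numels.
from itertools import accumulate

def sketch_flat_param_offsets(numels):
    cumulative = list(accumulate(numels))
    starts = [0] + cumulative[:-1]
    ends = [c - 1 for c in cumulative]  # inclusive
    return list(zip(starts, ends))

# numels (4, 6, 2) -> spans [(0, 3), (4, 9), (10, 11)] in the flat parameter
assert sketch_flat_param_offsets([4, 6, 2]) == [(0, 3), (4, 9), (10, 11)]
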
For ``NO_SHARD``, this is a no-op. + + If FSDP is in :meth:`summon_full_params` and the handle uses parameter + mixed precision, then the parameter is forced to full precision. + """ + if not self.needs_unshard(): + if self.uses_sharded_strategy: + # The handle may have been resharded without freeing the padded + # unsharded flattened parameter, in which case we need to + # switch to using the unsharded parameter + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._use_unsharded_flat_param(unsharded_flat_param) + return + unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + self._all_gather_flat_param(unsharded_flat_param) + + def needs_unshard(self) -> bool: + """Returns if the handle's flattened parameter needs to be unsharded.""" + if not self.uses_sharded_strategy: + return False + unsharded_flat_param = self._get_padded_unsharded_flat_param() + already_unsharded = ( + unsharded_flat_param.storage().size() == unsharded_flat_param.numel() + ) + return not already_unsharded + + def _alloc_padded_unsharded_flat_param(self): + """ + Allocates the *padded* unsharded flattened parameter. The unpadded + unsharded flattened parameter is always a view into the padded one. + This padded parameter is saved to a different attribute on the + ``FlatParameter`` depending on if we force full precision. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_storage_freed(unsharded_flat_param) + _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined] + return unsharded_flat_param + + def _get_padded_unsharded_flat_param(self) -> torch.Tensor: + """ + Returns a reference to the padded unsharded flattened parameter + depending on the calling context. This should only be called if using a + sharded strategy. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + if self._force_full_precision: + # When parameter mixed precision is enabled, we use a different + # tensor as the all-gather destination to preserve the invariant + # that `_full_param_padded` is in the low precision + unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined] + p_assert( + unsharded_flat_param.dtype != self._config.param_dtype, + f"Expects full precision but got {self._config.param_dtype}", + ) + else: + unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined] + return unsharded_flat_param + + def _all_gather_flat_param( + self, + padded_unsharded_flat_param: Tensor, + ) -> None: + """ + All-gathers the handle's flattened parameter to the destination + ``padded_unsharded_flat_param``, and switches to using the all-gathered + tensor. 
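
# A minimal sketch of the all-gather-then-view pattern implemented below,
# assuming an initialized process group with one NPU per rank. The destination
# buffer must hold `world_size` padded shards; the usable flat parameter is
# then the leading unpadded slice viewed with the unpadded unsharded size.
import torch
import torch.distributed as dist

def sketch_all_gather_and_view(sharded_flat_param, unpadded_unsharded_size,
                               world_size, process_group):
    padded_unsharded = torch.empty(
        sharded_flat_param.numel() * world_size,
        dtype=sharded_flat_param.dtype,
        device=sharded_flat_param.device,
    )
    # Same collective used below; newer PyTorch names it `all_gather_into_tensor`.
    dist._all_gather_base(padded_unsharded, sharded_flat_param, group=process_group)
    # Mirror `_use_unsharded_flat_param`: slice off the padding and restore the
    # original 1-D size without copying.
    return padded_unsharded[: unpadded_unsharded_size.numel()].view(
        unpadded_unsharded_size
    )
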
+ """ + p_assert( + hasattr(self, "process_group") and hasattr(self, "world_size"), + "Expects a process group and world size to have been set via `shard()`", + ) + sharded_flat_param = self.flat_param.data + expected_numel = sharded_flat_param.numel() * self.world_size + p_assert( + padded_unsharded_flat_param.numel() == expected_numel, + f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", + ) + dist._all_gather_base( + padded_unsharded_flat_param, + sharded_flat_param, + self.process_group, + ) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def _use_unsharded_flat_param( + self, + padded_unsharded_flat_param: torch.Tensor, + ) -> None: + """ + Switches to using the *unpadded* unsharded flattened parameter, which + is a view into the *padded* unsharded flattened parameter. + """ + unsharded_size = self.flat_param._unpadded_unsharded_size + self.flat_param.data = padded_unsharded_flat_param[ + : unsharded_size.numel() + ].view(unsharded_size) + + def post_unshard(self): + """ + Runs the post-unshard logic. This includes freeing the low precision + shard if needed. + """ + if self._uses_param_mixed_precision and self.uses_sharded_strategy: + self._free_low_precision_sharded_param() + self._check_on_compute_device(self.flat_param) + + def _free_low_precision_sharded_param(self): + """Frees the low precision sharded flattened parameter.""" + self._check_low_precision_shard() + _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined] + + def prepare_gradient(self): + """ + Prepares the gradient for the backward computation by saving and + clearing any existing sharded gradient in ``.grad`` to enable computing + a new unsharded gradient. + """ + p_assert( + self._training_state + in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE), + "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)", + ) + flat_param = self.flat_param + if flat_param.grad is not None and ( + flat_param.grad.size() != flat_param._unpadded_unsharded_size + or flat_param.grad.device != flat_param.device # grad on CPU + ): + self._check_on_compute_device(self.flat_param) + grad_offloaded = flat_param.grad.device != self.device + p_assert( + not grad_offloaded or self._config.offload_params, + f"Expects the sharded gradient to be on {self.device} " + f"but got {flat_param.grad.device}", + ) + prev_iter_synced_gradients = ( + flat_param.grad.size() + == flat_param._local_shard.size() # type: ignore[attr-defined] + ) + if prev_iter_synced_gradients: + # TODO (awgu): Gradient accumulation outside `no_sync()` + # does not work with CPU offloading. The issue should be + # that, in the post-backward hook, we cannot do an addition + # between a CPU tensor (the existing sharded gradient) and + # a GPU tensor (the new sharded gradient). + if not grad_offloaded: + flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined] + # If we're using mixed precision with keeping grads + # casted, gradient here might still be of the reduced + # dtype if we didn't clear / set the gradients to None + # after previous backward. In that case, make sure + # p._saved_grad_shard is cast to the full precision type + # so that we can accumulate in full precision in + # _post_backward_hook and assign back in full precision + # in _wait_for_post_backward. 
+ if ( + self._config.keep_low_precision_grads + and flat_param._saved_grad_shard.dtype # type: ignore[attr-defined] + != flat_param._local_shard.dtype # type: ignore[attr-defined] + ): + flat_param._saved_grad_shard = flat_param._saved_grad_shard.to( # type: ignore[attr-defined] + flat_param._local_shard.dtype # type: ignore[attr-defined] + ) + else: + padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined] + p_assert( + flat_param.grad.size() == padded_unsharded_size, + "Expects `.grad` to be the unsharded gradient in " + f"`no_sync()` with size {padded_unsharded_size} " + f"but got size {flat_param.grad.size()}", + ) + flat_param.grad = None + + @contextlib.contextmanager + def to_cpu(self): + """ + Moves the unpadded unsharded flattened parameter to CPU while in the + context and moves it back to the previous device upon exit. For now, + this assumes the ``FlatParameter`` is the unpadded unsharded flattened + parameter since (1) there is no reason to include the padding in the + copy and (2) there is no use case for the sharded flattened parameter. + + Precondition: ``self.flat_param`` 's data is the unpadded unsharded + flattened parameter on the compute device, and the handle uses a + sharded strategy. + Postcondition: Same as the precondition. + """ + self._check_sharded_strategy() + p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + self._check_on_compute_device(self.flat_param) + # Check that the unpadded unsharded flattened parameter is a view into + # the padded unsharded flattened parameter as expected + # NOTE: This check is not strictly needed for correctness but is a + # useful sanity check since the tensor should only be used internally. + unpadded_storage_ptr = self.flat_param.storage().data_ptr() + padded_storage_ptr = ( + self._get_padded_unsharded_flat_param().storage().data_ptr() + ) + p_assert( + unpadded_storage_ptr == padded_storage_ptr, + "Expects the unpadded parameter to be a view into the padded parameter", + ) + self._flat_param_to(torch.device("cpu")) + self._free_unsharded_flat_param() + try: + yield + finally: + p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + # Copy from CPU to the compute device + padded_unsharded_flat_param[: self.flat_param.numel()].copy_( + self.flat_param + ) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def reshard(self, free_unsharded_flat_param: bool): + """ + Runs the reshard logic. This includes freeing the unsharded flattened + parameter if ``free_unsharded_flat_param`` and switching to using the + sharded flattened parameter. + """ + if free_unsharded_flat_param: + self._free_unsharded_flat_param() + self._use_sharded_flat_param() + + def post_reshard(self): + """ + Runs the post-reshard logic. This includes freeing any memory that + can now be freed given that the ``FlatParameter`` points to the full + precision sharded flattened parameter. + + Precondition: ``self.flat_param`` 's data points to the full precision + sharded flattened parameter. + """ + # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since + # it is also the low precision *unsharded* flattened parameter. Hence, + # we delay the free until the reshard. 
+ if ( + self._uses_param_mixed_precision + and not self.uses_sharded_strategy + and not self._force_full_precision # did not use the low precision shard + ): + self._free_low_precision_sharded_param() + + def _free_unsharded_flat_param(self): + """ + Frees the padded unsharded flattened parameter. The tensor to free + depends on the calling context since the unshard may have forced full + precision, in which case a different tensor is used. + """ + self._check_sharded_strategy() + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_storage_allocated(unsharded_flat_param) + self._check_on_compute_device(unsharded_flat_param) + # Do not free the memory until all ops in the current stream finish + unsharded_flat_param.record_stream( + cast(torch._C.Stream, torch_npu.npu.current_stream()) + ) + _free_storage(unsharded_flat_param) + + def _use_sharded_flat_param(self) -> None: + """Switches to using the sharded flattened parameter.""" + flat_param = self.flat_param + if self._config.offload_params: + device = flat_param._local_shard.device # type: ignore[attr-defined] + p_assert( + device == torch.device("cpu"), + f"Expects the local shard to be on CPU but got {device}", + ) + flat_param.data = flat_param._local_shard # type: ignore[attr-defined] + + ######### + # VIEWS # + ######### + @staticmethod + def _get_unflat_views( + flat_param: FlatParameter, + tensor: Optional[torch.Tensor] = None, + ) -> Iterator[Tensor]: + """ + Returns unflattened ``Tensor`` views into ``tensor`` if it is not + ``None`` or ``flat_param`` otherwise, where the unflattening is based + on ``flat_param`` 's metadata. + + In other words, to get views into the unsharded flattened parameter, + pass ``tensor`` as ``None``, but to get views into tensor optimizer + state, pass ``tensor`` as the optimizer state tensor. + """ + if tensor is None: + tensor = flat_param + p_assert( + tensor.data.numel() == flat_param._unpadded_unsharded_size.numel(), + f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got " + f"{tensor.data.numel()} numel", + ) + views = ( + _ext_post_unflatten_transform(subtensor.view(shape), param_extension) + for (subtensor, shape, param_extension) in zip( + torch.split(tensor, flat_param._numels, dim=0), # type: ignore[arg-type] + flat_param._shapes, flat_param._param_extensions, + ) + ) + return views + + def _unflatten(self, as_params: bool) -> None: + """ + Unflattens the unsharded flattened parameter by setting the original + module parameter variables to be views into it. + + Args: + as_params (bool): If ``True``, then registers the original + parameters as ``nn.Parameter`` s; if ``False``, then registers + the original parameters only as ``Tensor`` s. ``False`` should + be used during forward/backward computation and when hiding the + original parameters from :meth:`nn.Module.named_parameters`. 
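
# A minimal sketch of what the `_free_storage`/`_alloc_storage` helpers
# imported from `._utils` are assumed to do (they are not defined in this
# hunk): release and re-allocate a tensor's backing storage in place via
# `resize_`, which is the operation the `npu_storage_resize` support added in
# this patch exists to serve.
import torch

def sketch_free_storage(tensor: torch.Tensor) -> None:
    # Keep the tensor object and its metadata; drop the bytes behind it.
    if tensor.storage().size() > 0:
        tensor.storage().resize_(0)

def sketch_alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    # Re-allocate enough elements for `size` before data is copied back in.
    if tensor.storage().size() == 0:
        tensor.storage().resize_(size.numel())
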
+ """ + views = self._get_unflat_views(self.flat_param) + for view, (param_name, module, _) in zip(views, self.flat_param._param_infos): + if hasattr(module, param_name): + delattr(module, param_name) + if as_params: + module.register_parameter(param_name, nn.Parameter(view)) + else: + setattr(module, param_name, view) + for ( + param_name, + module, + _, + prim_param_name, + prim_module, + _, + ) in self.flat_param._shared_param_infos: + if hasattr(module, param_name): + delattr(module, param_name) + assert hasattr(prim_module, prim_param_name) + param: Union[Tensor, nn.Parameter] = getattr(prim_module, prim_param_name) + if as_params: + assert isinstance(param, nn.Parameter) + module.register_parameter(param_name, param) + else: + setattr(module, param_name, param) + + @contextlib.contextmanager + def unflatten_as_params(self) -> Generator: + """ + Assumes the flattened parameter is unsharded. When in the context, + unflattens the original parameters as ``nn.Parameter`` views into the + flattened parameter, and after the context, restores the original + parameters as ``Tensor`` views into the flattened parameter. + """ + self._unflatten(as_params=True) + try: + yield + finally: + self._unflatten(as_params=False) + + ########### + # HELPERS # + ########### + def _flat_param_to(self, *args, **kwargs): + """Wraps an in-place call to ``.to()`` for ``self.flat_param``.""" + self.flat_param.data = self.flat_param.to(*args, **kwargs) + + def _get_modules(self) -> Set[nn.Module]: + """Returns a :class:`set` of the modules whose parameters are included + in this handle's flattened parameter.""" + return set(pi.module for pi in self.flat_param._param_infos).union( + set(spi.module for spi in self.flat_param._shared_param_infos) + ) + + def parameter_module_names(self) -> Iterator[Tuple[str, str]]: + shared_param_infos = [ + ParamInfo(param_name, module, module_name) + for ( + param_name, + module, + module_name, + _, + _, + _, + ) in self.flat_param._shared_param_infos + ] + for param_name, _, module_name in chain( + self.flat_param._param_infos, shared_param_infos + ): + yield (param_name, module_name) + + ####################### + # CHECKS & INVARIANTS # + ####################### + def _check_sharded_strategy(self): + p_assert(self.uses_sharded_strategy, "Expects sharded strategy") + + def _check_on_compute_device(self, tensor: Tensor): + p_assert( + tensor.device == self.device, + f"Expects tensor to be on the compute device {self.device}", + ) + + @staticmethod + def _check_storage_freed(tensor: Tensor): + storage_size: int = tensor.storage().size() + p_assert( + storage_size == 0, + f"Expects storage to be freed but got storage with size {storage_size}", + ) + + @staticmethod + def _check_storage_allocated(tensor: Tensor): + storage_size: int = tensor.storage().size() + p_assert(storage_size > 0, "Expects storage to be allocated") + + def _check_low_precision_shard(self): + p_assert( + self._uses_param_mixed_precision, + "Not using low precision for parameters", + ) + p_assert( + getattr(self.flat_param, "_mp_shard", None) is not None, + "Expects `_mp_shard` to exist", + ) + device = self.flat_param._mp_shard.device # type: ignore[attr-defined] + p_assert( + device == self.device, + f"Expects the low precision shard to be on {self.device} but got {device}", + ) + + ############## + # PROPERTIES # + ############## + @property + def uses_sharded_strategy(self) -> bool: + return self._config.sharding_strategy != HandleShardingStrategy.NO_SHARD + + @property + def 
_uses_param_mixed_precision(self) -> bool: + return self._config.param_dtype is not None + + @property + def _force_full_precision(self) -> bool: + return ( + self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS + and self._uses_param_mixed_precision + ) + + +def apply_fsdp_flat_param_handle(): + torch.distributed.fsdp.flat_param.FlatParamHandle = FlatParamHandle + torch.distributed.fsdp.flat_param.FlatParameter = FlatParameter + torch.distributed.fsdp.flat_param.HandleConfig = HandleConfig + torch.distributed.fsdp.flat_param.HandleTrainingState = HandleTrainingState + torch.distributed.fsdp.flat_param.HandleShardingStrategy = HandleShardingStrategy + torch.distributed.fsdp.flat_param.FlatParamShardMetadata = FlatParamShardMetadata + torch.distributed.fsdp.flat_param.SharedParamInfo = SharedParamInfo + torch.distributed.fsdp.flat_param.ParamInfo = ParamInfo + torch.distributed.fsdp.flat_param.__all__ = __all__ \ No newline at end of file diff --git a/torch_npu/distributed/fsdp/flatten_params_wrapper.py b/torch_npu/distributed/fsdp/flatten_params_wrapper.py new file mode 100644 index 0000000000..c4872ebd00 --- /dev/null +++ b/torch_npu/distributed/fsdp/flatten_params_wrapper.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Tongzhou Wang +# Licensed under the MIT License. + +import contextlib +from typing import Any, Dict, Generator, List + +import torch +import torch_npu +import torch.nn as nn +from torch.distributed.utils import _replace_by_prefix + +from .flat_param import FlatParamHandle, HandleConfig + +FLAT_PARAM = "flat_param" +FPW_MODULE = "_fpw_module" + +__all__ = ["FlattenParamsWrapper"] + + +def _post_state_dict_hook( + module: nn.Module, state_dict: Dict[str, Any], prefix: str, *args: Any +) -> Dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() is executed + and before returning the state_dict to the users. + This API post-processes the keys of the state_dict to remove the + FlattenParamsWrapper internal prefix. + """ + # Move everything from FPW_MODULE up one level. + _replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix) + return state_dict + + +def _pre_load_state_dict_hook( + state_dict: Dict[str, Any], + prefix: str, + *args: Any, +) -> None: + """ + _pre_load_state_dict_hook() is called before the _load_from_state_dict() is + executed. This API pre-processes the keys of the state_dict to add the + FlattenParamsWrapper internal prefix. + """ + # Push everything down to FPW_MODULE level. + _replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.") + # The flat_param_* keys actually needs to move one level up. + flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}" + for k in list(state_dict.keys()): + if k.startswith(flat_param_key): + last_part = k.split(".")[-1] + assert last_part.startswith( + FLAT_PARAM + ), f"Expected key to contain flat_param, but key name is {k}" + _replace_by_prefix(state_dict, k, prefix + last_part) + + +class FlattenParamsWrapper(nn.Module): + """ + This is a wrapper for flattening parameters in a ``nn.Module`` 's subtree + into a single flattened parameter and is based on [1]. This is used for + :class:`FullyShardedDataParallel` 's recursive wrapping. + [1] https://github.com/SsnL/PyTorch-Reparam-Module + + Args: + module (nn.Module): Module to wrap. 
+ params (List[nn.Parameter]): Parameters in ``module`` 's subtree to + flatten into a single flattened parameter. + device (torch.device): The compute and communication device for this + wrapper's handle. + config (HandleConfig): A config customizing this wrapper's handle based + on FSDP's available features. + + Attributes: + flat_param (Optional[FlatParameter]): The flattened parameter. + ``flat_param`` is ``None`` either when (1) this wrapper manages no + parameters or (2) the wrapped module's parameters are unflattened. + _fpw_module (nn.Module): The wrapped module. + _flat_param_handle (FlatParamHandle): A handle for the flattened + parameter; only present if this wrapper manages parameters. + """ + + def __init__( + self, + module: nn.Module, + params: List[nn.Parameter], + device: torch.device, + config: HandleConfig, + ) -> None: + super().__init__() + self._fpw_module = module + self.flat_param = None + # Register hooks to clean parameter names for state dict (even if this + # wrapper itself manages no parameters since it must clean names from + # submodules) + self._register_state_dict_hook(_post_state_dict_hook) + self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook) + if len(params) == 0: + return + self._flat_param_handle = FlatParamHandle(params, module, device, config) + # Defining `self.flat_param` registers the `FlatParameter` and makes it + # visible to `named_parameters()` + self.flat_param = self._flat_param_handle.flat_param + assert getattr(self, FPW_MODULE) is self._fpw_module + assert getattr(self, FLAT_PARAM) is self.flat_param + + @property + def has_params(self) -> bool: + """Returns whether this wrapper manages any parameters.""" + return hasattr(self, "_flat_param_handle") + + @property + def handle(self) -> FlatParamHandle: + assert hasattr(self, "_flat_param_handle"), ( + "Accessing the handle of a `FlattenParamsWrapper` that does not " + "manage any parameters" + ) + return self._flat_param_handle + + @property + def module(self) -> Any: + """Returns the wrapped module (like DDP).""" + return self._fpw_module + + @contextlib.contextmanager + def unflatten_as_params(self) -> Generator: + """ + Assumes that the flattened parameter is unsharded. When in the context, + unflattens the original parameters as ``nn.Parameter`` views into the + flattened parameter and de-registers the flattened parameter. After the + context, restores the original parameters as ``Tensor`` views into the + flattened parameter and re-registers the flattened parameter. 
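
# A minimal sketch of the key rewriting performed by the state-dict hooks
# registered above. The helper below is a simplified local stand-in for
# `torch.distributed.utils._replace_by_prefix`, so the effect on the keys is
# visible without building any modules.
def sketch_replace_by_prefix(state_dict, old_prefix, new_prefix):
    for key in list(state_dict.keys()):
        if key.startswith(old_prefix):
            state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)

sd = {"layer._fpw_module.lin.weight": 0, "layer._fpw_module.lin.bias": 1}
# What `_post_state_dict_hook` does: hide the internal `_fpw_module.` level.
sketch_replace_by_prefix(sd, "layer._fpw_module.", "layer.")
assert set(sd) == {"layer.lin.weight", "layer.lin.bias"}
# `_pre_load_state_dict_hook` applies the inverse mapping before loading.
sketch_replace_by_prefix(sd, "layer.", "layer._fpw_module.")
assert set(sd) == {"layer._fpw_module.lin.weight", "layer._fpw_module.lin.bias"}
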
+ """ + if getattr(self, "flat_param", None) is None: + yield + else: + # De-register the `FlatParameter` from this wrapper to hide it from + # `named_parameters()` (though it still exists in memory) + del self.flat_param + try: + with self._flat_param_handle.unflatten_as_params(): + yield + finally: + # Re-register the `FlatParameter` + self.flat_param = self._flat_param_handle.flat_param + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes of this wrapper to the wrapped module.""" + try: + return super().__getattr__(name) # defer to `nn.Module`'s logic + except AttributeError: + return getattr(self.module, name) # fall back to the wrapped module + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls to the wrapped module in case the wrapped + module is an ``nn.Sequential``.""" + return self.module.__getitem__(key) + + def forward(self, *inputs: Any, **kwinputs: Any) -> Any: + if self.flat_param is not None: + self._flat_param_handle._unflatten(as_params=False) + return self.module(*inputs, **kwinputs) + + +def apply_fsdp_flatten_params_wrapper(): + torch.distributed.fsdp.flatten_params_wrapper.FlattenParamsWrapper = FlattenParamsWrapper + torch.distributed.fsdp.flatten_params_wrapper._pre_load_state_dict_hook = _pre_load_state_dict_hook + torch.distributed.fsdp.flatten_params_wrapper._post_state_dict_hook = _post_state_dict_hook + torch.distributed.fsdp.flatten_params_wrapper.__all__ = __all__ + torch.distributed.fsdp.flatten_params_wrapper.FPW_MODULE = FPW_MODULE + torch.distributed.fsdp.flatten_params_wrapper.FLAT_PARAM = FLAT_PARAM diff --git a/torch_npu/distributed/fsdp/fully_sharded_data_parallel.py b/torch_npu/distributed/fsdp/fully_sharded_data_parallel.py new file mode 100644 index 0000000000..c88f2d3747 --- /dev/null +++ b/torch_npu/distributed/fsdp/fully_sharded_data_parallel.py @@ -0,0 +1,4477 @@ +import collections +import contextlib +import copy +import functools +import itertools +import math +import traceback +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum, auto +from typing import ( + Any, + Callable, + Deque, + Dict, + Generator, + Iterable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Set, + Tuple, + Union, + cast, +) + +import torch +import torch_npu +import torch.distributed as dist +import torch_npu.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.distributed import ProcessGroup +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + init_from_local_shards, +) +from torch_npu.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch_npu.distributed.algorithms._comm_hooks import ( + LOW_PRECISION_HOOKS, + default_hooks, +) +from torch.distributed.distributed_c10d import _get_default_group +from torch_npu.distributed.utils import ( + _replace_by_prefix, + _sync_params_and_buffers, + _to_kwargs, +) +from torch.nn.parameter import Parameter + +from ._optim_utils import ( + _broadcast_pos_dim_tensor_states, + _broadcast_processed_optim_state_dict, + _flatten_optim_state_dict, + _get_param_id_to_param, + _get_param_id_to_param_from_optim_input, + _get_param_to_param_id, + _get_param_to_param_id_from_optim_input, + _optim_state_dict, + _process_pos_dim_tensor_state, + _rekey_sharded_optim_state_dict, +) +from ._fsdp_extensions import _ext_chunk_tensor, 
_ext_pre_load_state_dict_transform +from ._utils import ( + _apply_to_modules, + _apply_to_tensors, + _contains_batchnorm, + _free_storage, + _is_fsdp_flattened, + _override_batchnorm_mixed_precision, + p_assert, +) +from .flat_param import ( + FlatParameter, + FlatParamHandle, + HandleConfig, + HandleShardingStrategy, + HandleTrainingState, +) +from .flatten_params_wrapper import ( + FLAT_PARAM, + FPW_MODULE, + FlattenParamsWrapper, +) +from .wrap import ( + ParamExecOrderWrapPolicy, + _or_policy, + _recursive_wrap, + _wrap_batchnorm_individually, +) + +_TORCHDISTX_AVAIL = True +try: + from torchdistx import deferred_init, fake +except ImportError: + _TORCHDISTX_AVAIL = False + +_TORCH_FX_AVAIL = True +if not hasattr(torch, "fx"): + _TORCH_FX_AVAIL = False +if _TORCH_FX_AVAIL: + from ._symbolic_trace import ( + TracingConfig, + _init_execution_info, + _patch_tracer, + ) + + +__all__ = [ + "FullyShardedDataParallel", "ShardingStrategy", "MixedPrecision", + "CPUOffload", "BackwardPrefetch", "StateDictType", "StateDictConfig", + "FullStateDictConfig", "LocalStateDictConfig", "ShardedStateDictConfig", + "OptimStateKeyType", "TrainingState_", "clean_tensor_name", +] + + +FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" +FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." + FPW_MODULE + "." + +_PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024) + + +class ShardingStrategy(Enum): + """ + This specifies the sharding strategy to be used for distributed training by + :class:`FullyShardedDataParallel`. + FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For + the parameters, this algorithm all-gathers before the forward, + reshards after the forward, all-gathers before the backward + computation, and reshards after the backward computation. The + gradients are synchronized and sharded via reduce-scatter after + the backward computation. The sharded optimizer states are + updated locally. + SHARD_GRAD_OP: Gradients and optimizer states are sharded during + computation, and additionally parameters are sharded outside + computation. For the parameters, this algorithm all-gathers + before the forward, does not reshard after the forward, and + only reshards after the backward computation. The gradients + are synchronized and sharded via reduce-scatter after the + backward computation. The sharded optimizer states are + updated locally. Inside ``no_sync()``, the parameters are + not resharded after the backward computation. + NO_SHARD: Parameters, gradients, and optimizer states are not sharded but + instead replicated across ranks, similar to PyTorch's + ``DistributedDataParallel`` API. The gradients are synchronized + via all-reduce after the backward computation. The unsharded + optimizer states are updated locally. + HYBRID_SHARD(future support): Apply ``FULL_SHARD`` intra-node and + ``NO_SHARD`` inter-node. + + """ + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + # TODO + # HYBRID_SHARD = auto() + + +@dataclass +class MixedPrecision: + """ + A config to enable mixed precision training with FullyShardedDataParallel. + This class can be constructed with several flags: + ``param_dtype`` controls the precision of model parameters, inputs, and + therefore the precision under which computation happens. After forward + and backward passes, FSDP parameters point to full precision shards + that are kept in memory. Full precision parameters are always + checkpointed. 
+ ``reduce_dtype`` controls the precision under which gradient reduction + would occur, which can potentially be different than ``param_dtype`` + for use cases such as communication efficiency. + ``buffer_dtype`` controls the precision that buffers are cast to. Note + that buffers are unsharded and are cast in the first forward pass, and + remain in their reduced precision state even after forward/backward + passes. However, when taking checkpoints with ``state_dict``, buffers + are checkpointed in their full precision (and then restored back to + to their reduced precision) as expected. Note that this checkpoint + support is currently limited to ``StateDictType.FULL_STATE_DICT``. + ``keep_low_precision_grads``: Whether to upcast gradients back to the + full parameter precision after backwards or not. This can be disabled + to keep the gradients in the lower precision, which can potentially + save memory if custom Optimizers are able to perform parameter updates + effectively with lower precision grads. + + .. note:: In ``summon_full_params``, parameters are summoned in full + precision but buffers are not. + + .. note:: Parameters and buffers are checkpointed in full precision. For + buffers, this is only guaranteed to work for ``StateDictType.FULL_STATE_DICT``. + + .. note:: This API is experimental and subject to change. + + .. note:: Specification of reduced precision types must be explicit, in that + if, for example, ``param_dtype`` is not specified, it will not be cast by + FSDP. Thus, a config such as ``MixedPrecision(reduce_dtype=torch.float16)`` + will not cast buffers or parameters. Note that if a ``MixedPrecision`` + config is specified without a ``reduce_dtype``, gradient communication + would occur in the `param_dtype` precision, if given, otherwise, in the + original parameter precision. + """ + # maintain a tensor of this dtype that the fp32 param shard will be cast to. + # Will control the precision of model params, inputs, and thus compute as + # well. + param_dtype: Optional[torch.dtype] = None + # Gradient communication precision. + reduce_dtype: Optional[torch.dtype] = None + # Buffer precision. + # TODO: buffer + param are usually of the same type, if user specifies + # param but not buffer, should we automatically make buffer be the same? + buffer_dtype: Optional[torch.dtype] = None + keep_low_precision_grads: Optional[bool] = False + + +@dataclass +class CPUOffload: + """ + CPU offloading config. Currently, only parameter and gradient CPU + offload are supported. + offload_params: Offloading parameters to CPUs when these parameters are + not used for computation on NPU. This implicitly enables + gradient offloading to CPUs in order for parameters and + gradients to be on the same device to work with optimizer. + """ + + offload_params: bool = False + + +class BackwardPrefetch(Enum): + """ + Specify where to prefetch next layer's full parameters + during backward pass. + BACKWARD_PRE: prefetch right before current layer's backward computation + starts, this approach will increase backward communication + and computation overalpping and potentialy improve training + performance, but it may increase the peak memory usage as + the prefetched full parameters will be kept in the NPU memory + until next layer's backward computation is done. + BACKWARD_POST: prefetch right after current layer's backward computation finishes, + this approach will not increase peak memory as prefetching happens + after current layer's full parameters are freed. 
+ It could potentially improve backward communication and computation + overlapping as it avoids all_gather and reduce_scatter are blocked + each other in the single NCCL stream. However, based on our experiments, + for some models, the backward post backward hook fire order is not always + the reversed forward computation order, so this + approach may prefetch full parameters for layers ahead of next layer, + this 'ahead' all_gather could delay next layer's all_gather in the + single NCCL stream and cause the next layer's computation delay. So it may + cause some performance regession for some models. + """ + + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + # TODO, BACKWARD_PRE_CPU, prefetch full parameters and keep them in the CPU memory + + +class TrainingState_(Enum): + """ + Simple enum to indicate what state FSDP is in. Used for asserting + to make sure APIs are called in the correct state. + ..note:: + ``BACKWARD_PRE`` and ``BACKWARD_POST`` states are used to ensure we + receives backward hooks in the correct order. It is used to catch + unexpected order of hooks being called (likely due to our + hook registration logic or autograd engine logic changes). + """ + + IDLE = auto() + FORWARD = auto() + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + SUMMON_FULL_PARAMS = auto() + + +class StateDictType(Enum): + """ + This enum indicates that which type of ``state_dict`` the FSDP module is + currently processing (returning or loading). + The default value is FULL_STATE_DICT to comply the PyTorch convention. + ..note:: + FSDP currently supports three types of ``state_dict``: + 1. ``state_dict/load_state_dict`: this pair of APIs return and load + the non-sharded, unflattened parameters. The semantics is the + same as using DDP. + 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return + and load local sharded, flattened parameters. The values returned + by ``_local_state_dict`` can be directly used by FSDP and is only + meaningful to FSDP (because parameters are flattened). Note that + these APIs are meant for use via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): + ... state = fsdp.state_dict() # loads local state dict + 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs + return and load sharded, unflattened parameters. The ``state_dict`` + return by ``sharded_state_dict`` can be used by all other parallel + schemes (resharding may be required). + """ + + FULL_STATE_DICT = auto() + LOCAL_STATE_DICT = auto() + SHARDED_STATE_DICT = auto() + +@dataclass +class StateDictConfig: + """ + ``StateDictConfig`` is the base class for all state_dict configuration classes. + Users should instantiate a child version (i.e. ``FullStateDictConfig``) in + order to configure settings for the particular type of ``state_dict`` + implementation FSDP will use. + """ + pass + +@dataclass +class FullStateDictConfig(StateDictConfig): + """ + ``FullStateDictConfig`` is a config class meant to be used with + ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters, + ``offload_to_cpu`` and ``rank0_only`` which can be configured to offload + the full ``state_dict`` to CPU and to materialize the ``state_dict`` on + rank 0 only. When used, it is recommended to enable both of these flags + together to optimize memory savings when taking checkpoints. 
Note that + this config class is meant for user via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> fsdp = FSDP(model, auto_wrap_policy=...) + >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): + >>> state = fsdp.state_dict() + >>> # state will be empty on non rank 0 and contain CPU tensors on rank 0. + >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: + >>> model = model_fn() # Initialize model on CPU in preparation for wrapping with FSDP + >>> if dist.get_rank() == 0: + >>> # Load checkpoint only on rank 0 to avoid memory redundancy + >>> state_dict = torch.load("my_checkpoint.pt") + >>> model.load_state_dict(state_dict) + >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument + >>> # communicates loaded checkpoint states from rank 0 to rest of the world. + >>> fsdp = FSDP(model, device_id=torch.npu.current_device(), auto_wrap_policy=..., sync_module_states=True) + >>> # After this point, all ranks have FSDP model with loaded checkpoint. + """ + offload_to_cpu: bool = False + rank0_only: bool = False + +@dataclass +class LocalStateDictConfig(StateDictConfig): + pass + +@dataclass +class ShardedStateDictConfig(StateDictConfig): + pass + +_state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, +} + +class OptimStateKeyType(Enum): + PARAM_NAME = auto() + PARAM_ID = auto() + + +# A handles key represents the group of `FlatParamHandle`s involved in a given +# module's forward. These will be all-gathered together in the pre-forward and +# pre-backward. +_HandlesKey = Tuple[FlatParamHandle, ...] + + +class _ExecOrderWarnStatus(Enum): + """Used internally for execution order validation.""" + NONE = auto() # no deviation yet + WARNING = auto() # deviated this iteration; currently issuing warnings + WARNED = auto() # deviated in a previous iteration + + +class _ExecOrderData: + """ + This contains the data structures to track the execution order. We track + the pre-forward order on the *first* iteration for forward prefetching + (which thus assumes static graph) and the post-forward order on *every* + iteration for backward prefetching (which thus does not assume static + graph but may be provide an incorrect order). 
+ """ + + def __init__( + self, + debug_level: dist.DebugLevel, + backward_prefetch_limit: int, + forward_prefetch_limit: int, + ) -> None: + # Tracks the (static) pre-forward order for execution order validation + # and forward prefetching + self.handles_pre_forward_order: List[int] = [] + # Maps each handles key to its index in `handles_pre_forward_order` + self.handles_to_pre_forward_order_index: Dict[_HandlesKey, int] = {} + # Tracks the post-forward order for pre-backward prefetching + self.handles_post_forward_order: List[int] = [] + # Maps each handles key to its index in `handles_post_forward_order` + self.handles_to_post_forward_order_index: Dict[_HandlesKey, int] = {} + self.is_first_iter = True + + # Gives the max number of backward/forward prefetched all-gathers by a + # single module + self._backward_prefetch_limit = backward_prefetch_limit + self._forward_prefetch_limit = forward_prefetch_limit + + # Data structures for execution order validation + self._checking_order: bool = ( + debug_level in [dist.DebugLevel.INFO, dist.DebugLevel.DETAIL] + ) + self.process_group: Optional[dist.ProcessGroup] = None + self.world_size: Optional[int] = None + self.all_handles: List[FlatParamHandle] = [] + # Maps each handle to its index in `all_handles`, which must be the + # same across ranks for the execution order validation to work + self.handle_to_handle_index: Dict[FlatParamHandle, int] = {} + # Names are prefixed from the root module + self.flat_param_to_prefixed_param_names: Dict[FlatParameter, List[str]] = {} + # Current index in the pre-forward execution order + self.current_order_index = 0 + self.warn_status = _ExecOrderWarnStatus.NONE + + def init( + self, + fsdp_root: "FullyShardedDataParallel", + process_group: dist.ProcessGroup, + ) -> None: + """ + Initializes the data structures needed for checking the forward order. + This should be called after a root FSDP instance has been set during + lazy initialization. + """ + self.process_group = process_group + self.rank = process_group.rank() + self.world_size = process_group.size() + # Fix an order over the handles, which should be the same across ranks + for fsdp_module in fsdp_root.fsdp_modules(fsdp_root): + for handle in fsdp_module._handles: + index = len(self.all_handles) + self.all_handles.append(handle) + self.handle_to_handle_index[handle] = index + self.flat_param_to_prefixed_param_names = cast( + Dict[FlatParameter, List[str]], + _get_param_to_unflat_param_names(fsdp_root), + ) + # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles` + # to check that all ranks have the same handles in the same order. + # https://github.com/pytorch/pytorch/issues/79620 + + def get_handles_to_backward_prefetch( + self, + current_handles_key: _HandlesKey, + ) -> List[_HandlesKey]: + """ + Returns a :class:`list` of the handles keys of the handles to backward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. 
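
# A minimal sketch of the order bookkeeping described above, with strings
# standing in for handles keys. Backward prefetching walks *backwards* through
# the recorded post-forward order, up to the prefetch limit.
post_forward_order = []   # handles keys in the order their post-forward ran
order_index = {}          # handles key -> index into `post_forward_order`

def sketch_record_post_forward(key):
    if key not in order_index:  # only the first usage of a key is recorded
        order_index[key] = len(post_forward_order)
        post_forward_order.append(key)

for key in ("block0", "block1", "block2"):
    sketch_record_post_forward(key)

def sketch_backward_prefetch_targets(current_key, limit=1):
    index = order_index[current_key] - 1
    targets = []
    while index >= 0 and len(targets) < limit:
        targets.append(post_forward_order[index])
        index -= 1
    return targets

# The module whose post-forward ran last is reached first in backward, so from
# "block2" the next all-gather to prefetch is "block1".
assert sketch_backward_prefetch_targets("block2") == ["block1"]
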
+ """ + current_index = self.handles_to_post_forward_order_index.get(current_handles_key, None) + if current_index is None: + return None + target_index = current_index - 1 + target_handles_keys: List[_HandlesKey] = [] + for _ in range(self._backward_prefetch_limit): + if target_index < 0: + break + target_handles_keys.append( + self.handles_post_forward_order[target_index] + ) + target_index -= 1 + return target_handles_keys + + def get_handles_to_forward_prefetch( + self, + current_handles_key: _HandlesKey, + ) -> List[_HandlesKey]: + """ + Returns a :class:`list` of the handles keys of the handles to forward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. + """ + current_index = self.handles_to_pre_forward_order_index.get(current_handles_key, None) + if current_index is None: + return None + target_index = current_index + 1 + target_handles_keys: List[_HandlesKey] = [] + for _ in range(self._forward_prefetch_limit): + if target_index >= len(self.handles_pre_forward_order): + break + target_handles_keys.append( + self.handles_pre_forward_order[target_index] + ) + target_index += 1 + return target_handles_keys + + def record_post_forward(self, handles: List[FlatParamHandle]) -> None: + """ + Records ``handles`` in the post-forward order, where ``handles`` should + be a group of handles used in the same module's forward. If ``handles`` + is empty, then it is omitted. + + Unlike :meth:`record_pre_forward`, this records the order *every* + iteration with the expectation that the recorded order is reset in + :meth:`next_iter`. + """ + if not handles: + return + handles_key = tuple(handles) + # Only record the first usage of a handles key + if handles_key in self.handles_to_post_forward_order_index: + return + index = len(self.handles_post_forward_order) + self.handles_to_post_forward_order_index[handles_key] = index + self.handles_post_forward_order.append(handles_key) + + def record_pre_forward(self, handles: List[FlatParamHandle], is_training: bool) -> None: + """ + Records ``handles`` in the pre-forward order on the first iteration, + where ``handles`` should be a group of handles used in the same + module's forward. If ``handles`` is empty, then it is omitted. + + On the first iteration, this checks the execution order across ranks. + See :meth:`_check_order` for details. + """ + if not handles: + return + handles_key = tuple(handles) + self._check_order(handles_key, is_training) + # Fix the order after the first iteration and only record the first + # usage of a handles key + if ( + not self.is_first_iter + or handles_key in self.handles_to_pre_forward_order_index + ): + return + index = len(self.handles_pre_forward_order) + self.handles_to_pre_forward_order_index[handles_key] = index + self.handles_pre_forward_order.append(handles_key) + + def _check_order(self, handles_key: _HandlesKey, is_training: bool) -> None: + """ + Checks the forward execution order as long as ``is_training`` is + ``True`` since checking in eval mode is not supported. + + - On the first iteration, this uses all-gathers to check that all ranks + are all-gathering the same handles and hence ``FlatParameter`` s, + raising an error if not. + - On subsequent iterations, if the distributed debug level is at least + INFO, then this checks that each rank is locally consistent with its + own forward order from the first iteration, issuing a warning if not. 
+ This issues a warning on the first deviating iteration and stops + warning thereafter. + """ + # Do not check order in eval mode since the post-backward callback does + # not run so it cannot be used to mark the end of an iteration + if not is_training: + return + if self.is_first_iter: + msg_prefix = "Forward order differs across ranks:" + local_indices: Optional[Tuple[int, ...]] = self._get_handle_indices( + handles_key + ) + device = handles_key[0].device # guaranteed to be non-CPU + num_valid_indices = sum((index is not None) for index in local_indices) + tensor_kwargs = {"dtype": torch.int32, "device": device} + world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs) + local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs) + dist._all_gather_base( + world_num_valid_indices, + local_num_valid_indices, + group=self.process_group, + ) + # Check that all ranks plan to all-gather the same number of + # parameters + # TODO (awgu): Since every module has at most one handle in the + # current implementation, this should never raise the error. + for (r1, n1), (r2, n2) in itertools.combinations( + ( + (rank, world_num_valid_indices[rank]) + for rank in range(self.world_size) + ), + 2, + ): + if n1 != n2: + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering {n1} parameters " + f"while rank {r2} is all-gathering {n2} parameters" + ) + world_indices = torch.zeros( + self.world_size * num_valid_indices, **tensor_kwargs + ) + local_indices = torch.tensor(local_indices, **tensor_kwargs) + dist._all_gather_base( + world_indices, local_indices, group=self.process_group + ) + # Check that all ranks plan to all-gather the same index parameters + for (r1, i1), (r2, i2) in itertools.combinations( + ( + ( + rank, + world_indices[ + rank * num_valid_indices : (rank + 1) * num_valid_indices + ], + ) + for rank in range(self.world_size) + ), + 2, + ): + if i1 != i2: + r1_param_names = self._get_names_from_handle_indices(i1) + r2_param_names = self._get_names_from_handle_indices(i2) + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering parameters " + f"for {r1_param_names} while rank {r2} is all-gathering " + f"parameters for {r2_param_names}" + ) + elif self._checking_order: + # Only issue warnings on the first deviating iteration and stop + # checking thereafter to avoid flooding the console + if self.warn_status == _ExecOrderWarnStatus.WARNED: + return + msg_prefix = None # non-`None` means we should warn + if self.current_order_index >= len(self.handles_pre_forward_order): + # This iteration sees extra all-gather(s) compared to the first + msg_prefix = ( + "Expected to not all-gather any more parameters in the " + "forward but trying to all-gather parameters for " + ) + else: + expected_handles_key = self.handles_pre_forward_order[ + self.current_order_index + ] + if expected_handles_key != handles_key: + expected_param_names = self._get_names_from_handles( + expected_handles_key + ) + msg_prefix = ( + f"Expected to all-gather for {expected_param_names} " + "but trying to all-gather parameters for " + ) + if msg_prefix is not None: + param_names = self._get_names_from_handles(handles_key) + msg_suffix = ( + f"{param_names}" + if param_names + else "a newly-added parameter since construction time" + ) + warnings.warn( + "Forward order differs from that of the first iteration " + f"on rank {self.rank}. 
Collectives are unchecked and may " + f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}" + ) + self.warn_status = _ExecOrderWarnStatus.WARNING + self.current_order_index += 1 + + def _get_handle_indices( + self, + handles_key: _HandlesKey, + ) -> Tuple[Optional[int], ...]: + """ + Returns the handle indices (i.e. indices into ``self.all_handles``) + corresponding to the handles in ``handles_key``. An entry in the + returned tuple is ``None`` if the handle is invalid. + """ + indices: List[int] = [] + for handle in handles_key: + if handle not in self.handle_to_handle_index: + indices.append(None) + else: + indices.append(self.handle_to_handle_index[handle]) + return tuple(indices) + + def _get_names_from_handle_indices( + self, + handle_indices: Tuple[int, ...], + ) -> List[List[str]]: + """ + Returns a list of prefixed parameter names for each handle in + ``handle_indices``. If a handle index is invalid, then its prefixed + parameter names are omitted from the returned list. + """ + prefixed_param_names: List[List[str]] = [] + for index in handle_indices: + if index is None or index < 0 or index >= len(self.all_handles): + continue + handle = self.all_handles[index] + flat_param = handle.flat_param + prefixed_param_names.append(self.flat_param_to_prefixed_param_names[flat_param]) + return prefixed_param_names + + def _get_names_from_handles( + self, + handles_key: _HandlesKey, + ) -> List[List[str]]: + """ + Returns a list of prefixed parameter names for each handle in + ``handles_key``. If a handle is invalid, then its prefixed parameter + names are omitted from the returned list. + """ + prefixed_param_names: List[List[str]] = [] + for handle in handles_key: + flat_param = handle.flat_param + if flat_param not in self.flat_param_to_prefixed_param_names: + continue + prefixed_param_names.append(self.flat_param_to_prefixed_param_names[flat_param]) + return prefixed_param_names + + def next_iter(self): + """ + Advances the internal data structures per iteration. This should be + called in the post-backward callback since that marks the true end of + an iteration. + """ + self.is_first_iter = False + self.handles_to_post_forward_order_index.clear() + self.handles_post_forward_order.clear() + if self._checking_order: + self.current_order_index = 0 + if self.warn_status == _ExecOrderWarnStatus.WARNING: + self.warn_status = _ExecOrderWarnStatus.WARNED + + +class _FreeEventQueue: + """ + This tracks all pending frees corresponding to inflight all-gathers. The + queueing pattern is iterative enqueues with a single dequeue per iteration + once the limit ``_max_num_inflight_all_gathers`` is reached. 
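+
+    A minimal doctest-style sketch of the queueing pattern, assuming the
+    default limit of 2 and that NPU events can be constructed::
+
+        >>> # xdoctest: +SKIP("illustrative sketch only")
+        >>> queue = _FreeEventQueue()
+        >>> queue.enqueue(torch_npu.npu.Event())
+        >>> queue.dequeue_if_needed() is None  # below the limit, nothing dequeued
+        True
+        >>> queue.enqueue(torch_npu.npu.Event())
+        >>> queue.dequeue_if_needed() is None  # limit reached, oldest event dequeued
+        False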
+ """ + + def __init__(self) -> None: + self._queue: Deque[torch_npu.npu.Event] = collections.deque() + self._max_num_inflight_all_gathers = 2 # empirically chosen + + def enqueue(self, free_event: torch_npu.npu.Event) -> None: + """Enqueues a free event.""" + self._queue.append(free_event) + + def dequeue_if_needed(self) -> Optional[torch_npu.npu.Event]: + """Dequeues a single event if the limit is reached.""" + if len(self._queue) >= self._max_num_inflight_all_gathers: + return self._dequeue() + return None + + def _dequeue(self) -> Optional[torch_npu.npu.Event]: + """Dequeues a free event if possible.""" + if self._queue: + event = self._queue.popleft() + return event + return None + + +# TODO (awgu): Refactor this later +sharding_strategy_map = { + ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD, + ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD, + ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP, +} + + +class FullyShardedDataParallel(nn.Module): + """ + A wrapper for sharding Module parameters across data parallel workers. This + is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_. + FullyShardedDataParallel is commonly shortened to FSDP. + + .. _`Xu et al.`: https://arxiv.org/abs/2004.13336 + .. _DeepSpeed: https://www.deepspeed.ai/ + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> torch_npu.npu.set_device(device_id) + >>> sharded_module = FSDP(my_module) + >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) + >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) + >>> loss = x.sum() + >>> loss.backward() + >>> optim.step() + + .. warning:: + The optimizer must be initialized *after* the module has been wrapped, + since FSDP will shard parameters in-place and this will break any + previously initialized optimizers. + + .. warning:: + If the destination NPU device has ID ``dev_id``, either (1) + ``module`` should already be placed on that device, (2) the device + should be set using ``torch_npu.npu.set_device(dev_id)``, or (3) + ``dev_id`` should be passed into the ``device_id`` constructor + argument. This FSDP instance's compute device will be that destination + device. For (1) and (3), the FSDP initialization always occurs on NPU. + For (2), the FSDP initialization happens on ``module`` 's current + device, which may be CPU. + + .. warning:: + FSDP currently does not support gradient accumulation outside + ``no_sync()`` when using CPU offloading. Trying to do so yields + incorrect results since FSDP will use the newly-reduced gradient + instead of accumulating with any existing gradient. + + .. warning:: + Changing the original parameter variable names after construction will + lead to undefined behavior. + + .. warning:: + Passing in `sync_module_states=True` flag requires module to be put + on NPU, or to use ``device_id`` argument to specify a NPU device that + FSDP will move module to. This is because ``sync_module_states=True`` + requires NPU communication. + + .. warning:: + As of PyTorch 1.12, FSDP only offers limited support for shared parameters + (for example, setting one ``Linear`` layer's weight to another's). In + particular, modules that share parameters must be wrapped as part of the + same FSDP unit. If enhanced shared parameter support is needed for your + use case, please ping https://github.com/pytorch/pytorch/issues/77724 + + .. 
note:: + Inputs into FSDP ``forward`` function will be moved to compute device + (same device FSDP module is on) before running ``forward``, so user does + not have to manually move inputs from CPU -> NPU. + + Args: + module (nn.Module): + module to be wrapped with FSDP. + process_group (Optional[ProcessGroup]): + process group for sharding + sharding_strategy (Optional[ShardingStrategy]): + Config sharding algorithm, different sharding algorithm has trade + off between memory saving and communication overhead. ``FULL_SHARD`` + will be chosen if sharding_strategy is not specified. + cpu_offload (Optional[CPUOffload]): + CPU offloading config. Currently, only parameter and gradient CPU + offload is supported. It can be enabled via passing in + ``cpu_offload=CPUOffload(offload_params=True)``. Note that this + currently implicitly enables gradient offloading to CPU in order for + params and grads to be on same device to work with optimizer. This + API is subject to change. Default is ``None`` in which case there + will be no offloading. + auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): + A callable specifying a policy to recursively wrap layers with FSDP. + Note that this policy currently will only apply to child modules of + the passed in module. The remainder modules are always wrapped in + the returned FSDP root instance. + ``size_based_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is + an example of ``auto_wrap_policy`` callable, this policy wraps layers + with the number of parameters larger than 100M. ``transformer_auto_wrap_policy`` + written in ``torch.distributed.fsdp.wrap`` is an example of ``auto_wrap_policy`` + callable for transformer-like model architectures. Users can supply the customized + ``auto_wrap_policy`` callable that should accept following arguments: + ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, and return + a ``bool`` specifying whether the passed in ``module``` should be wrapped + (if ``recurse=False``) or whether we should recurse down the subgraph of ``module`` + children (if ``recurse=True``). Extra customized arguments could be added to + the customized ``auto_wrap_policy`` callable as well. It is a good practice to + print out the sharded model and check whether the sharded model is what + the application wants and then adjust accordingly. + + Example:: + + >>> def custom_auto_wrap_policy( + >>> module: nn.Module, + >>> recurse: bool, + >>> unwrapped_params: int, + >>> # These are customizable for this policy function. + >>> min_num_params: int = int(1e8), + >>> ) -> bool: + >>> return unwrapped_params >= min_num_params + >>> # Configure a custom min_num_params + >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=1e5) + + backward_prefetch (Optional[BackwardPrefetch]): + This is an experimental feature that is subject to change in the + the near future. It allows users to enable two different backward_prefetch + algorithms to help backward communication and computation overlapping. + Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. + mixed_precision (Optional[MixedPrecision]): A ``MixedPrecision`` instance + describing the mixed precision training config to be used. ``MixedPrecision`` + supports configuring parameter, buffer, and gradient communication dtype. Note + that only floating point data is cast to the reduced precision. 
This allows + users potential memory saving and training speedup while trading off + accuracy during model training. If ``None``, no mixed precision is applied. + Note that if ``mixed_precision`` is enabled for FSDP model that + contains ``BatchNorm`` with ``auto_wrap_policy``, FSDP will take + care to disable mixed precision for ``BatchNorm`` units by wrapping + them separately in their own FSDP unit with ``mixed_precision=None``. + This is done because several ``BatchNorm`` kernels do not implement + reduced type support at the moment. If individually wrapping the model, + users must take care to set ``mixed_precision=None`` for + ``BatchNorm`` units. + (Default: ``None``) + ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose + own parameters and child modules' parameters and buffers are + ignored by this instance. None of the modules directly in + ``ignored_modules`` should be :class:`FullyShardedDataParallel` + instances, and any child modules that are already-constructed + :class:`FullyShardedDataParallel` instances will not be ignored if + they are nested under this instance. This argument may be used to + avoid sharding specific parameters at module granularity when using an + ``auto_wrap_policy`` or if parameters' sharding is not managed by + FSDP. (Default: ``None``) + param_init_fn (Optional[Callable[[nn.Module], None]]): + A ``Callable[torch.nn.Module] -> None`` that + specifies how modules that are currently on the meta device should be initialized + onto an actual device. Note that as of v1.12, we detect modules on the meta + device via ``is_meta`` check and apply a default initialization that calls + ``reset_parameters`` method on the passed in ``nn.Module`` if ``param_init_fn`` + is not specified, otherwise we run ``param_init_fn`` to initialize the passed + in ``nn.Module``. In particular, this means that if ``is_meta=True`` for any + module parameters for modules that will be wrapped with FSDP and ``param_init_fn`` + is not specified, we assume your module properly implements a ``reset_paramters()`` + and will throw errors if not. Note that additionally, we offer support for modules + initialized with torchdistX's (https://github.com/pytorch/torchdistX) + ``deferred_init`` API. In this case, deferred modules would be initialized + by a default initialization function that calls torchdistX's + ``materialize_module``, or the passed in ``param_init_fn``, if it is not + ``None``. The same ``Callable`` is applied to initialize all meta modules. + Note that this initialization function is applied before doing any FSDP sharding + logic. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> module = MyModule(device="meta") + >>> def my_init_fn(module): + >>> # responsible for initializing a module, such as with reset_parameters + >>> ... + >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) + >>> print(next(fsdp_model.parameters()).device) # current NPU device + >>> # With torchdistX + >>> module = deferred_init.deferred_init(MyModule, device="npu") + >>> # Will initialize via deferred_init.materialize_module(). + >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy) + + device_id (Optional[Union[int, torch.device]]): An ``int`` or ``torch.device`` + describing the NPU device the FSDP module should be moved to determining where + initialization such as sharding takes place. 
If this argument is not specified + and ``module`` is on CPU, we issue a warning mentioning that this argument can + be specified for faster initialization. If specified, resulting FSDP instances + will reside on this device, including moving ignored modules' parameters if + needed. Note that if ``device_id`` is specified but ``module`` is already on a + different NPU device, an error will be thrown. (Default: ``None``) + sync_module_states (bool): If ``True``, each individually wrapped FSDP unit will broadcast + module parameters from rank 0 to ensure they are the same across all ranks after + initialization. This helps ensure model parameters are the same across ranks + before starting training, but adds communication overhead to ``__init__``, as at least + one broadcast is triggered per individually wrapped FSDP unit. + This can also help load checkpoints taken by ``state_dict`` and to be loaded by + ``load_state_dict`` in a memory efficient way. See documentation for + :class:`FullStateDictConfig` for an example of this. (Default: ``False``) + forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches + the next upcoming all-gather while executing in the forward pass. + This may improve communication and computation overlap for CPU + bound workloads. This should only be used for static graph models + since the forward order is fixed based on the first iteration's + execution. (Default: ``False``) + limit_all_gathers (bool): If ``False``, then FSDP allows the CPU + thread to schedule all-gathers without any extra synchronization. + If ``True``, then FSDP explicitly synchronizes the CPU thread to + prevent too many in-flight all-gathers. This ``bool`` only affects + the sharded strategies that schedule all-gathers. Enabling this can + help lower the number of NPU malloc retries. 
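+            (Default: ``False``)
+
+    The following sketch (``MyModel`` is a placeholder module class, not part
+    of this package) shows one way the arguments above can be combined::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> my_auto_wrap_policy = functools.partial(
+        >>>     size_based_auto_wrap_policy, min_num_params=int(1e6)
+        >>> )
+        >>> fsdp_model = FSDP(
+        >>>     MyModel().npu(),
+        >>>     auto_wrap_policy=my_auto_wrap_policy,
+        >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16),
+        >>>     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+        >>>     limit_all_gathers=True,
+        >>> )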
+ """ + def __init__( + self, + module: nn.Module, + process_group: Optional[ProcessGroup] = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[Callable] = None, + backward_prefetch: Optional[BackwardPrefetch] = None, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + device_id: Optional[Union[int, torch.device]] = None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + limit_all_gathers: bool = False, + ): + if isinstance(auto_wrap_policy, ParamExecOrderWrapPolicy): + self._init_param_exec_order_wrap_policy( + module=module, + process_group=process_group, + sharding_strategy=sharding_strategy, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + ignored_modules=ignored_modules, + param_init_fn=param_init_fn, + device_id=device_id, + sync_module_states=sync_module_states, + forward_prefetch=forward_prefetch, + limit_all_gathers=limit_all_gathers, + ) + return + + torch._C._log_api_usage_once("torch.distributed.fsdp") + super().__init__() + + self._ignored_modules = self._get_ignored_modules(module, ignored_modules) + ignored_params, self._ignored_param_names = self._get_ignored_params( + module, self._ignored_modules + ) + self._buffer_names = self._get_buffer_names(module) + if auto_wrap_policy is not None: + auto_wrap_kwargs = { + "module": module, + "auto_wrap_policy": auto_wrap_policy, + "wrapper_cls": FullyShardedDataParallel, + "ignored_modules": self._ignored_modules, + "ignored_params": ignored_params, + "only_wrap_children": True, # avoid double wrapping the root + } + fsdp_kwargs = { + "process_group": process_group, + "sharding_strategy": sharding_strategy, + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + "mixed_precision": mixed_precision, + "param_init_fn": param_init_fn, + "device_id": device_id, + "sync_module_states": sync_module_states, + "forward_prefetch": forward_prefetch, + "limit_all_gathers": limit_all_gathers, + } + self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs) + + self.process_group = process_group or _get_default_group() + self.rank = self.process_group.rank() + self.world_size = self.process_group.size() + self.training_state = TrainingState_.IDLE + self.cpu_offload = cpu_offload or CPUOffload() + self.backward_prefetch = backward_prefetch + self.forward_prefetch = forward_prefetch + self.limit_all_gathers = limit_all_gathers + backward_prefetch_limit = 1 + forward_prefetch_limit = 1 + # We clamp the strategy to `NO_SHARD` for world size of 1 since they + # are currently functionally equivalent. This may change if/when we + # integrate FSDP with MoE. 
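+        # (Note that the clamp below is applied before the `FULL_SHARD`
+        # default is chosen, so a world size of 1 always runs unsharded.)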
+ if self.world_size == 1: + sharding_strategy = ShardingStrategy.NO_SHARD + self.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD + self.mixed_precision = mixed_precision or MixedPrecision() + # Save a mapping from fully prefixed buffer name to its original dtype + # since for mixed precision, buffers are restored to their original + # dtype for model checkpointing + self._buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {} + + self._check_single_device_module(module, ignored_params) + device_from_device_id: Optional[torch.device] = self._get_device_from_device_id(device_id) + self._materialize_module(module, param_init_fn, device_from_device_id) + self._move_module_to_device(module, ignored_params, device_from_device_id) + self.compute_device = self._get_compute_device(module, ignored_params, device_from_device_id) + params_to_flatten = list(self._get_orig_params(module, ignored_params)) + if sync_module_states: + self._sync_module_states(module, params_to_flatten) + + # This FSDP instance's handles should inherit the same process group, + # compute device, CPU offload, and mixed precision settings. However, + # different sharding strategies are allowed. + config = HandleConfig( + sharding_strategy_map[self.sharding_strategy], + self.cpu_offload.offload_params, + self.mixed_precision.param_dtype, + self.mixed_precision.reduce_dtype, + self.mixed_precision.keep_low_precision_grads, + ) + self._fsdp_wrapped_module = FlattenParamsWrapper( + module, + params_to_flatten, + self.compute_device, + config, + ) + self._check_orig_params_flattened(ignored_params) + # Invariant: `self.params` contains exactly the `FlatParameter`s of the + # handles in `self._handles` + self._handles: List[FlatParamHandle] = [] + self.params: List[FlatParameter] = [] + if self._fsdp_wrapped_module.has_params: + handle = self._fsdp_wrapped_module.handle + self.params.append(handle.flat_param) + self._register_param_handle(handle) + handle.shard(self.process_group) + if self.cpu_offload.offload_params and handle.flat_param.device != torch.device("cpu"): + with torch.no_grad(): + handle._flat_param_to(torch.device("cpu")) + + self._sync_gradients = True + self._communication_hook = self._get_default_comm_hook() + self._communication_hook_state = self._get_default_comm_hook_state() + self._hook_registered = False + + # Used to prevent running the pre-backward hook multiple times + self._ran_pre_backward_hook: Dict[_HandlesKey, bool] = {} + self._is_root: Optional[bool] = None # `None` indicates not yet set + # The following attributes are owned by the root FSDP instance and + # shared with non-root FSDP instances + self._streams: Dict[str, torch_npu.npu.Stream] = {} + self._free_event_queue = _FreeEventQueue() + self._debug_level = dist.get_debug_level() + self._exec_order_data = _ExecOrderData( + self._debug_level, + backward_prefetch_limit, + forward_prefetch_limit, + ) + self._handles_prefetched: Dict[_HandlesKey, bool] = {} + # Used for guarding against mistargeted backward prefetches + self._needs_pre_backward_unshard: Dict[_HandlesKey, bool] = {} + # Used for guarding against mistargeted forward prefetches + self._needs_pre_forward_unshard: Dict[_HandlesKey, bool] = {} + # The data structures use tuples of handles to generalize over the case + # where a module's forward involves multiple handles. 
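+        # For this instance, such a key is simply `tuple(self._handles)`,
+        # which is a 1-tuple in the common single-handle case.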
+ + # `_state_dict_type` controls the `state_dict()` behavior, which is + # implemented using post-save and pre-load hooks + self._state_dict_type = StateDictType.FULL_STATE_DICT + self._state_dict_config = FullStateDictConfig() + self._register_state_dict_hook(self._post_state_dict_hook) + self._post_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: self._full_post_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: self._local_post_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: self._sharded_post_state_dict_hook, + } + self._register_load_state_dict_pre_hook( + self._pre_load_state_dict_hook, with_module=True + ) + self._pre_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: self._full_pre_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: self._local_pre_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: self._sharded_pre_load_state_dict_hook, + } + self.register_load_state_dict_post_hook( + self._post_load_state_dict_hook + ) + self._post_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: self._full_post_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: self._local_post_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: self._sharded_post_load_state_dict_hook, + } + + def _get_ignored_modules( + self, + root_module: nn.Module, + _ignored_modules: Optional[Iterable[torch.nn.Module]], + ) -> Set[nn.Module]: + """ + Checks that ``_ignored_modules`` is an iterable of ``nn.Module`` s + without any FSDP instances, and returns the modules contained in their + module subtrees as a :class:`set`. Nested FSDP instances are excluded, + but their already-computed ignored modules are included. + """ + if _ignored_modules is None: + return set() + msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s " + try: + ignored_root_modules = set(_ignored_modules) + except TypeError: + raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") + for module in ignored_root_modules: + if not isinstance(module, torch.nn.Module): + raise TypeError(msg_prefix + f"but got an iterable with {type(module)}") + if isinstance(module, FullyShardedDataParallel): + raise ValueError("`ignored_modules` should not include FSDP modules") + # Include child modules and exclude nested FSDP modules themselves + ignored_modules = set( + child + for module in ignored_root_modules + for child in module.modules() + if not isinstance(child, (FullyShardedDataParallel, FlattenParamsWrapper)) + ) + if root_module in ignored_modules: + warnings.warn( + "Trying to ignore the top-level module passed into the FSDP " + "constructor itself will result in all parameters being " + f"ignored and is not well-supported: {module}" + ) + # Include nested FSDP modules' ignored modules + for submodule in root_module.modules(): + if isinstance(submodule, FullyShardedDataParallel): + assert hasattr(submodule, "_ignored_modules") + ignored_modules.update(submodule._ignored_modules) + return ignored_modules + + def _get_ignored_params( + self, + root_module: torch.nn.Module, + ignored_modules: Set[torch.nn.Module], + ) -> Tuple[Set[torch.nn.Parameter], Set[str]]: + """ + Returns the parameters of the modules in ``ignored_modules``, + excluding any :class:`FlatParameter` s, and their fully prefixed names, + both as :class:`set` s. 
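+
+        A rough sketch of the expected result, where ``model.head`` stands in
+        for a hypothetical ignored ``nn.Linear`` submodule named ``head``::
+
+            >>> # xdoctest: +SKIP("illustrative sketch only")
+            >>> params, names = self._get_ignored_params(model, {model.head})
+            >>> sorted(names)
+            ['head.bias', 'head.weight']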
+ """ + ignored_params = set( + p + for m in ignored_modules + for p in m.parameters() + if not _is_fsdp_flattened(p) + ) + # Conservatively include all shared parameters' names + param_to_unflat_param_names = _get_param_to_unflat_param_names( + root_module, + dedup_shared_params=False, + ) + ignored_param_names = set() + for param in ignored_params: + unflat_param_names = param_to_unflat_param_names[param] + clean_names = [] + for k in unflat_param_names: + # Clean any module wrapper prefixes in case of nested wrapping + clean_names.append(clean_tensor_name(k)) + ignored_param_names.update(clean_names) + return ignored_params, ignored_param_names + + def _get_buffer_names(self, root_module: nn.Module) -> Set[str]: + """ + Returns the fully prefixed names of all buffers in the module hierarchy + rooted at ``root_module`` as a class:`set`. + """ + + def module_fn(module: nn.Module, prefix: str, buffer_names: Set[str]): + # For FSDP modules, only add the entry when considering the + # contained `FlattenParamsWrapper` to avoid duplication + if not isinstance(module, FullyShardedDataParallel): + for buffer_name, _ in module.named_buffers(recurse=False): + # Clean module wrapper prefixes in case of nested wrapping + prefixed_buffer_name = clean_tensor_name(prefix + buffer_name) + buffer_names.add(prefixed_buffer_name) + + def return_fn(buffer_names: Set[str], *args): + return buffer_names + + buffer_names: Set[str] = set() + return _apply_to_modules( + root_module, + module_fn, + return_fn, + buffer_names, + ) + + def _auto_wrap( + self, + auto_wrap_kwargs: Dict[str, Any], + fsdp_kwargs: Dict[str, Any], + ) -> None: + """ + Recursively auto wraps the root module given by the key "module" in + ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and + ``fsdp_kwargs``. + + Precondition: ``auto_wrap_policy`` contains the arguments expected by + ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``. + ``fsdp_kwargs`` contains all FSDP arguments except ``module``. + """ + auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"] + root_module = auto_wrap_kwargs["module"] + assert auto_wrap_policy is not None + # For auto wrapping, submodules should not already be wrapped with FSDP + # since double wrapping is not supported + for module_name, module in root_module.named_modules(): + if isinstance(module, FullyShardedDataParallel): + raise ValueError( + f"Expected {module_name} to NOT be FullyShardedDataParallel " + "if using an `auto_wrap_policy`" + ) + mixed_precision = fsdp_kwargs["mixed_precision"] + if mixed_precision is not None and _contains_batchnorm(root_module): + _override_batchnorm_mixed_precision(root_module) + auto_wrap_policy = functools.partial( + _or_policy, policies=[_wrap_batchnorm_individually, auto_wrap_policy] + ) + warnings.warn( + "Both mixed precision and an `auto_wrap_policy` were specified " + "for FSDP, where the wrapped module has batch norm submodules. " + "The batch norm submodules will be wrapped as separate FSDP " + "instances with mixed precision disabled since some batch norm " + "kernels do not support low precision." + ) + auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy + _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs) + + def _check_single_device_module( + self, + module: nn.Module, + ignored_params: Set[nn.Parameter], + ) -> None: + """ + Raises an error if ``module`` has original parameters on multiple + devices, ignoring the parameters in ``ignored_params``. 
Thus, after + this method, the module must be either fully on the CPU or fully on a + non-CPU device. + """ + devices = set( + param.device for param in self._get_orig_params(module, ignored_params) + ) + if len(devices) > 1: + raise RuntimeError( + f"FSDP only supports single device modules but got params on {devices}" + ) + + def _get_device_from_device_id( + self, + device_id: Optional[Union[int, torch.device]], + ) -> Optional[torch.device]: + """ + """ + if device_id is None: + return None + device = ( + device_id + if isinstance(device_id, torch.device) + else torch.device(device_id) + ) + if device == torch.device("npu"): + warnings.warn( + f"FSDP got the argument `device_id` {device_id} on rank " + f"{self.rank}, which does not have an explicit index. " + f"FSDP will use the current device {torch_npu.npu.current_device()}. " + "If this is incorrect, please explicitly call `torch_npu.npu.set_device()` " + "before FSDP initialization or pass in the explicit device " + "index as the `device_id` argument." + ) + device = torch.device("npu", torch_npu.npu.current_device()) + return device + + def _materialize_module( + self, + module: nn.Module, + param_init_fn: Optional[Callable[[nn.Module], None]], + device_from_device_id: Optional[torch.device], + ) -> None: + """ + Materializes the wrapped module ``module`` in place if needed: either + if the module has parameters that use meta device or are torchdistX + fake tensors. + + This method uses ``param_init_fn`` to materialize the module if the + function is not ``None`` and falls back to default behavior otherwise. + For meta device, this moves the module to ``device_from_device_id`` if + it is not ``None`` or the current device otherwise and calls + ``reset_parameters()``, and for torchdistX fake tensors, this calls + ``deferred_init.materialize_module()``. + """ + is_meta_module = any(p.is_meta for p in module.parameters()) + is_torchdistX_deferred_init = ( + not is_meta_module + and _TORCHDISTX_AVAIL + and any(fake.is_fake(p) for p in module.parameters()) + ) + if ( + is_meta_module or is_torchdistX_deferred_init + ) and param_init_fn is not None: + if not callable(param_init_fn): + raise ValueError( + f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}" + ) + param_init_fn(module) + elif is_meta_module: + # Run default meta device initialization + materialization_device = device_from_device_id or torch_npu.npu.current_device() + module.to_empty(device=materialization_device) + try: + with torch.no_grad(): + module.reset_parameters() + except BaseException as e: + warnings.warn( + "Unable to call `reset_parameters()` for module on meta " + f"device with error {str(e)}. Please ensure your " + "module implements a `reset_parameters()` method." + ) + raise e + elif is_torchdistX_deferred_init: + # Run default torchdistX initialization + deferred_init.materialize_module( + module, + check_fn=lambda k: not isinstance(k, FullyShardedDataParallel), + ) + + def _move_module_to_device( + self, + module: nn.Module, + ignored_params: Set[nn.Parameter], + device_from_device_id: Optional[torch.device], + ): + """ + Moves ``module`` depending on ``device_from_device_id`` and its current + device. This includes moving ignored modules' parameters. + + - If ``device_from_device_id`` is not ``None``, then this moves + ``module`` to the device. + - If ``device_from_device_id`` is ``None``, then this does not move + ``module`` but warns the user if it is on CPU. + + Precondition: ``_check_single_device_module()``. 
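+
+        A minimal sketch of the intended effect, assuming ``module`` is a CPU
+        module and ``npu:0`` is the target device::
+
+            >>> # xdoctest: +SKIP("illustrative sketch only")
+            >>> self._move_module_to_device(module, set(), torch.device("npu", 0))
+            >>> {p.device for p in module.parameters()}
+            {device(type='npu', index=0)}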
+ """ + cpu_device = torch.device("cpu") + param = next(self._get_orig_params(module, ignored_params), None) + if param is None: + return # no original parameters to manage + if device_from_device_id is not None: + if param.device == cpu_device: + # NOTE: This includes moving ignored modules' parameters. + module = module.to(device_from_device_id) + # TODO: This is a temporary fix to move already- constructed + # `FlatParameter`s back to CPU if needed. This is needed to + # make CPU offload work with `device_id`. + for submodule in module.modules(): + if ( + isinstance(submodule, FullyShardedDataParallel) + and submodule.cpu_offload.offload_params + ): + with torch.no_grad(): + for handle in submodule._handles: + handle._flat_param_to(torch.device("cpu")) + elif param.device == cpu_device: + warnings.warn( + "Module is put on CPU and will thus have flattening and sharding" + " run on CPU, which is less efficient than on NPU. We recommend passing in " + "`device_id` argument which will enable FSDP to put module on NPU device," + " module must also be on NPU device to work with `sync_module_states=True` flag" + " which requires NPU communication." + ) + + def _get_compute_device( + self, + module: nn.Module, + ignored_params: Set[nn.Parameter], + device_from_device_id: Optional[torch.device], + ) -> torch.device: + """ + Determines and returns this FSDP instance's compute device. If the + module is already on a non-CPU device, then the compute device is that + non-CPU device. If the module is on CPU, then the compute device is the + current device. + + Since this method should be called after materializing the module, any + non-CPU device should not be meta device. For now, the compute device + is always a NPU device with its explicit index. + + Precondition: ``_check_single_device_module()`` and + ``_move_module_to_device()``. + """ + # If the module is on NPU already, then that NPU device has priority + # over the current device + param = next(self._get_orig_params(module, ignored_params), None) + if param is not None and param.device.type == "npu": + compute_device = param.device + else: + compute_device = torch.device("npu", torch_npu.npu.current_device()) + if ( + device_from_device_id is not None + and compute_device != device_from_device_id + ): + raise ValueError( + "Inconsistent compute device and `device_id` on rank " + f"{self.rank}: {compute_device} vs {device_from_device_id}" + ) + return compute_device + + def _sync_module_states( + self, module: nn.Module, params: List[nn.Parameter] + ) -> None: + """ + Synchronizes module states (i.e. parameters ``params`` and all + not-yet-synced buffers) by broadcasting from rank 0 to all ranks. + + Precondition: ``sync_module_states == True`` and ``self.process_group`` + has been set. + """ + if params and any(param.device == torch.device("cpu") for param in params): + raise ValueError( + "Module has CPU parameters, but sync_module_states=True is specified." + "This only works for NPU module, please specify `device_id` argument or move" + " module to NPU before init." + ) + module_states: List[torch.Tensor] = [] + # TODO (awgu): When exposing the original parameters, we need to also + # use this attribute to prevent re-synchronizing parameters. 
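+        # Broadcast buffers first and then the (still unsharded) parameters,
+        # all from rank 0 in bucketed chunks.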
+ for buffer in module.buffers(): + # Avoid re-synchronizing buffers in case of nested wrapping + if not getattr(buffer, "_fsdp_synced", False): + buffer._fsdp_synced = True + module_states.append(buffer.detach()) + module_states.extend(param.detach() for param in params) + _sync_params_and_buffers( + self.process_group, module_states, _PARAM_BROADCAST_BUCKET_SIZE, src=0, + ) + + def _get_orig_params( + self, + module: nn.Module, + ignored_params: Set[nn.Parameter], + ) -> Iterator[nn.Parameter]: + """ + Returns an iterator over the original parameters in ``module``, + ignoring the parameters in ``ignored_params`` and any ``FlatParameter`` + s (which may be present due to nested FSDP wrapping). + """ + param_gen = module.parameters() + try: + while True: + param = next(param_gen) + if param not in ignored_params and not _is_fsdp_flattened(param): + yield param + except StopIteration: + pass + + def _check_orig_params_flattened(self, ignored_params: Set[nn.Parameter]) -> None: + """ + Checks that all original parameters have been flattened and hence made + invisible to ``named_parameters()``. This should be called as a sanity + check after flattening the wrapped module's parameters. + """ + for param_name, param in self.named_parameters(): + if param not in ignored_params and not _is_fsdp_flattened(param): + raise RuntimeError( + f"Found an unflattened parameter: {param_name}; " + f"{param.size()} {param.__class__}" + ) + + def _register_param_handle(self, handle: FlatParamHandle) -> None: + """Registers the parameter handle to this FSDP instance.""" + if handle not in self._handles: + self._handles.append(handle) + + @torch.no_grad() + def _unshard( + self, + handles: List[FlatParamHandle], + ) -> None: + """ + Unshards the handles in ``handles``. If the handles are in + :meth:`summon_full_params` and are using mixed precision, then they are + forced to full precision. + + Postcondition: Each handle's ``FlatParameter`` 's data is the padded + unsharded flattened parameter on the compute device. + """ + if not handles: + return + if self.limit_all_gathers: + event = self._free_event_queue.dequeue_if_needed() + if event: + event.synchronize() + any_ran_pre_unshard = False + with torch_npu.npu.stream(self._streams["pre_all_gather"]): + for handle in handles: + ran_pre_unshard = handle.pre_unshard() + any_ran_pre_unshard = any_ran_pre_unshard or ran_pre_unshard + if any_ran_pre_unshard: + self._streams["all_gather"].wait_stream(self._streams["pre_all_gather"]) + with torch_npu.npu.stream(self._streams["all_gather"]): + for handle in handles: + handle.unshard() + handle.post_unshard() + + def _reshard( + self, # unused + handles: List[FlatParamHandle], + free_unsharded_flat_params: List[bool], + ) -> None: + """ + Reshards the handles in ``handles``. ``free_unsharded_flat_params`` + should have the same length as ``handles``, and each element should + give whether the corresponding handle should free its padded unsharded + flattened parameter. 
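+
+        When ``limit_all_gathers`` is enabled, an NPU event is recorded for
+        each freed flattened parameter and enqueued so that later unshards
+        can rate-limit themselves against it.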
+ """ + if not handles: + return + p_assert( + len(handles) == len(free_unsharded_flat_params), + "Expects both lists to have equal length but got " + f"{len(handles)} and {len(free_unsharded_flat_params)}" + ) + for handle, free_unsharded_flat_param in zip( + handles, + free_unsharded_flat_params, + ): + handle.reshard(free_unsharded_flat_param) + if self.limit_all_gathers and free_unsharded_flat_param: + free_event = torch_npu.npu.Event() + free_event.record() + self._free_event_queue.enqueue(free_event) + handle.post_reshard() + # Since we prefetch entire handles keys at a time, conservatively mark + # the entire key as no longer prefetched once we free at least one + handles_key = tuple(handles) + if any(free_unsharded_flat_params): + self._handles_prefetched.pop(handles_key, None) + + @property + def module(self) -> nn.Module: + """ + Returns the wrapped module (like :class:`DistributedDataParallel`). + """ + assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper) + return self._fsdp_wrapped_module.module + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._fsdp_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is a nn.Sequential.""" + return self._fsdp_wrapped_module.__getitem__(key) # type: ignore[operator] + + def check_is_root(self) -> bool: + self._lazy_init() + assert self._is_root is not None + return self._is_root + + @staticmethod + def fsdp_modules( + module: nn.Module, + root_only: bool = False, + ) -> List["FullyShardedDataParallel"]: + """ + Returns all nested FSDP instances, possibly including ``module`` itself + and only including FSDP root modules if ``root_only=True``. + + Args: + module (torch.nn.Module): Root module, which may or may not be an + ``FSDP`` module. + root_only (bool): Whether to return only FSDP root modules. + (Default: ``False``) + + Returns: + List[FullyShardedDataParallel]: FSDP modules that are nested in + the input ``module``. + """ + return [ + submodule for submodule in module.modules() + if isinstance(submodule, FullyShardedDataParallel) and + (not root_only or submodule.check_is_root()) + ] + + def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel": + r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``) + as well as self. Typical use includes initializing the parameters of a model + (see also :ref:`nn-init-doc`). + + Compared to ``torch.nn.Module.apply``, this version additionally gathers + the full parameters before applying ``fn``. It should not be called from + within another ``summon_full_params`` context. + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + """ + uninitialized = self._is_root is None + self._assert_state(TrainingState_.IDLE) + with self._summon_full_params(recurse=False, writeback=True): + ret = super().apply(fn) + + # Reset lazy init that might be called by _summon_full_params, since + # it could have set is_root incorrectly for non-root FSDP instances. + if uninitialized and self._is_root: + for module in self.fsdp_modules(self): + module._reset_lazy_init() + + return ret + + def _mixed_precision_enabled_for_params(self) -> bool: + """ + Whether user explicitly enabled mixed precision for + parameters or not. 
+ """ + return self.mixed_precision.param_dtype is not None + + def _mixed_precision_enabled_for_buffers(self) -> bool: + """ + Whether user explicitly enabled mixed precision for + buffers or not. + """ + return self.mixed_precision.buffer_dtype is not None + + def _mixed_precision_enabled_for_reduce(self) -> bool: + """ + Whether user explicitly enabled mixed precision for + gradient reduction or not. + """ + return self.mixed_precision.reduce_dtype is not None + + def _mixed_precision_keep_low_precision_grads(self) -> bool: + return ( + self.mixed_precision is not None + and self.mixed_precision.keep_low_precision_grads + ) + + def _low_precision_hook_enabled(self) -> bool: + """ + Wether a low precision hook is registered or not. + """ + return ( + self._communication_hook is not None + and self._communication_hook in LOW_PRECISION_HOOKS + ) + + def _cast_fp_inputs_to_dtype( + self, dtype: torch.dtype, *args: Any, **kwargs: Any + ) -> Tuple[Any, Any]: + """ + Casts floating point tensors in ``args`` and ``kwargs`` to the + precision given by ``dtype``, while respecting the existing + ``requires_grad`` on the tensors. + """ + def cast_fn(x: torch.Tensor) -> torch.Tensor: + if not torch.is_floating_point(x): + return x + y = x.to(dtype) + # Explicitly copy over `requires_grad` since this runs inside + # `torch.no_grad()` + if x.is_leaf: + y.requires_grad = x.requires_grad + return y + + with torch.no_grad(): + return ( + _apply_to_tensors(cast_fn, args), + _apply_to_tensors(cast_fn, kwargs) + ) + + def _cast_buffers( + self, + device: Optional[torch.device] = None, + dtype: Optional[Dict[str, torch.dtype]] = None, + memo: Optional[Set] = None, + recurse: bool = True, + ) -> None: + """Move all buffers to the given *device* and *dtype*. + If *device* is not given, then it will default to + ``self.compute_device``, otherwise buffer will be moved to ``device``. + In the case of nested FSDP instances, we will respect the child instance's + ``compute_device`` configuration. + If *dtype* is given, it must be a mapping of buffer name to buffer dtype, + and this argument is currently only given to restore back to original + buffer types during checkpoint. If *dtype* is not given, and we are + in mixed precision training, the buffer will be cast to buffer_dtype, + otherwise the buffer will not be cast. + Args: + device (torch.device, Optional): + device to cast buffers to (defaults to compute_device) + dtype: (Dict[str, torch.dtype], Optional): + Mapping of buffer name to their dtype to cast to. + memo (Set, Optional): + set of modules that have already been processed + recurse (bool, Optional): + Whether to call _cast_buffers recursively on nested FSDP + instances (default is True). + """ + if memo is None: + memo = set() + for module in self.modules(): + if module is not self and isinstance(module, FullyShardedDataParallel) and recurse: + # Allow any child FSDP instances to handle their own buffers. + module._cast_buffers(device=device, dtype=dtype, memo=memo, recurse=recurse) + elif module not in memo: + memo.add(module) + for name, buf in module.named_buffers(recurse=False): + if buf is None: + continue + buf = buf.to(device=device or self.compute_device) + if name not in self._buffer_name_to_orig_dtype: + self._buffer_name_to_orig_dtype[name] = buf.dtype + # If given, cast buffer to the given dtype. 
This is used to + # suppport mixed precision for buffers + # (given by self.mixed_precision.buffer_dtype) and also used + # to restore the buffer dtype to the original precision for + # state_dict() calls. + # Note that non-floating point buffers are not casted. + if torch.is_floating_point(buf): + # We are restoring the original buffer type in + # preparation for checkpoint. + if dtype: + buf = buf.to(dtype=dtype[name]) + # Note that we don't pass in self.mixed_precision.buffer_dtype + # recursively into _cast_buffers, as we want to respect + # mp config for child FSDP instances. + elif self._mixed_precision_enabled_for_buffers(): + buf = buf.to(self.mixed_precision.buffer_dtype) + + setattr(module, name, buf) + + def _reset_lazy_init(self) -> None: + """ + Reset instance so :func:`_lazy_init` will run on the next forward. + """ + self._is_root: Optional[bool] = None + for p in self.params: + if hasattr(p, "_local_shard"): + # We only need to `del` `_local_shard` because + # `_init_param_attributes()` gates the logic based on its + # existence (and not any of the other attributes). + del p._local_shard # type: ignore[attr-defined] + + def _lazy_init(self) -> None: + """ + Performs initialization lazily, typically right before the first + forward pass. The laziness is needed to ensure that the parameter + device/dtype and the FSDP hierarchy have finalized. + + This method's actual logic only runs on the root FSDP instance, which + performs initialization for all non-root FSDP instances to avoid + partial initialization. + """ + if self._is_root is not None: + return # no-op: already initialized + if not torch_npu.npu.is_available(): + # Allow the FSDP constructor to run even with NPU but check this + # once we start real execution + raise RuntimeError("FSDP does not support CPU only execution") + # The following logic is only run on the root FSDP instance since it + # will set `_is_root=False` for the non-root instances + self._is_root = True + self._assert_state(TrainingState_.IDLE) + self._init_streams() + self._cast_buffers(recurse=True) + for handle in self._handles: + self._init_param_attributes(handle) + self._exec_order_data.init(self, self.process_group) + # Initialize non-root FSDP instances and share attributes from the root + # to non-root instances + inconsistent_limit_all_gathers = False + for fsdp_module in self.fsdp_modules(self): + if fsdp_module is not self: + # Relax the assert for non-root FSDP instances in case the + # nested initialized module is wrapped again in FSDP later (e.g. + # after training to run inference) + assert fsdp_module._is_root is None or not fsdp_module._is_root, ( + "Non-root FSDP instance's `_is_root` should not have been " + "set yet or should have been set to `False`" + ) + fsdp_module._is_root = False + fsdp_module._streams = self._streams + fsdp_module._exec_order_data = self._exec_order_data + if fsdp_module.limit_all_gathers != self.limit_all_gathers: + # Prefer the root's value + inconsistent_limit_all_gathers = True + fsdp_module.limit_all_gathers = self.limit_all_gathers + fsdp_module._free_event_queue = self._free_event_queue + fsdp_module._handles_prefetched = self._handles_prefetched + fsdp_module._needs_pre_backward_unshard = self._needs_pre_backward_unshard + for handle in fsdp_module._handles: + fsdp_module._init_param_attributes(handle) + if inconsistent_limit_all_gathers: + warnings.warn( + "Found inconsistent `limit_all_gathers` values across FSDP " + f"instances on rank {self.rank}. 
Using the root FSDP's value " + f"of {self.limit_all_gathers} for all instances." + ) + + # TODO (awgu): Move this to the `FlatParamHandle` class later + @torch.no_grad() + def _init_param_attributes(self, handle: FlatParamHandle) -> None: + """ + We manage several attributes on each Parameter instance. + A few attributes are set here: + ``_local_shard``: a single shard of the parameter. This is needed to + recover the shard after rebuilding full parameter in forward + and backward. + ``_full_param_padded``: the full weight (padded to be evenly + divisible by ``world_size``), used for computation in the + forward and backward pass. It is initialized with the + appropriate size and then has its storage freed. This will be + resized in place and only materialized (via all-gather) as needed. + Another attribute is set by :func:`_register_post_backward_hooks`: + ``_post_backward_hook_state``: it holds the parameter's AccumulateGrad object + and the registered post hook handle. + """ + p = handle.flat_param + # If _local_shard has been set in the first lazy init and + # current parameter is pointed to _local_shard, no need to + # set the _local_shard again. + if hasattr(p, "_local_shard"): + # If CPU offloading, p._local_shard should have been placed on CPU + # during its first lazy construction. + if self.cpu_offload.offload_params: + assert p._local_shard.device == torch.device( # type: ignore[attr-defined] + "cpu" + ), ( + "Expected p._local_shard to be on CPU, " # type: ignore[attr-defined] + f"but it's on {p._local_shard.device}" # type: ignore[attr-defined] + ) + return + + # A single shard of the parameters. Also makes p._local_shard to be on + # CPU if we are CPU offloading, since p.data would be on CPU during + # init. + if self.cpu_offload.offload_params: + assert p.device == torch.device("cpu"), ( + "Expected param to be on CPU when cpu_offloading is enabled. " + "If CPU offloading is enabled correctly, you may be " + "accidentally moving the model to NPU after FSDP initialization." + ) + p._local_shard = p.data # type: ignore[attr-defined] + # If CPU offloading, pin the memory to enable faster CPU -> NPU device + # transfer. + if self.cpu_offload.offload_params: + assert p._local_shard.device == torch.device("cpu") # type: ignore[attr-defined] + p._local_shard = p._local_shard.pin_memory() # type: ignore[attr-defined] + # When offloading parameters, also move the grad shard to CPU during + # backward pass. In this case, it's important to pre-allocate the + # CPU grad shard in pinned memory so that we can do a non-blocking + # transfer. + p._cpu_grad = torch.zeros_like( # type: ignore[attr-defined] + p, device=torch.device("cpu") + ).pin_memory() + + # If mixed_precision, maintain reduced precision param shard on + # compute_device for computation in fwd/bwd. We resize storage to 0 here + # and rematerialize before building the full param when needed. After + # fwd/bwd, it is freed and we only hold on to the full precision shard. + # As a result, this reduced precision shard is not allocated if we are + # not in the forward/backward pass. + if ( + self._mixed_precision_enabled_for_params() + ): + p._mp_shard = torch.zeros_like( + p._local_shard, + device=self.compute_device, + dtype=self.mixed_precision.param_dtype + ) + _free_storage(p._mp_shard) + + # We also maintain a full-sized parameter of type self.compute_dtype. + # We resize the storage to size 0 at init (here) and only materialize + # as needed. 
The storage may contain padding elements so that it is + # evenly divisible by world_size, although these padding elements will + # be removed before the relevant computation. + if handle.uses_sharded_strategy: # type: ignore[attr-defined] + # We set p._full_param_padded's dtype to the desired parameter dtype + # in the case of mixed precision. This is so that when we all_gather + # into full_param_padded it can occur without issues and result in + # full_param_padded having the expected param_dtype. + full_param_dtype = ( + p.dtype if not self._mixed_precision_enabled_for_params() + else self.mixed_precision.param_dtype + ) + p._full_param_padded = torch.zeros( # type: ignore[attr-defined] + p.numel() * self.world_size, + device=self.compute_device, + dtype=full_param_dtype, + ) + p._padded_unsharded_size = p._full_param_padded.size() # type: ignore[attr-defined] + _free_storage(p._full_param_padded) # type: ignore[attr-defined] + + if self._mixed_precision_enabled_for_params(): + p._full_prec_full_param_padded = torch.zeros( # type: ignore[attr-defined] + p.numel() * self.world_size, + device=self.compute_device, + dtype=p.dtype, # full precision + ) + _free_storage(p._full_prec_full_param_padded) + + # Track whether the `FlatParameter`'s post-backward hook has been + # called for validation in `_wait_for_post_backward()` + p._post_backward_called = False + + def _init_streams(self) -> None: + """Initializes NPU streams for overlapping data transfer and + computation. This should only be called on the root FSDP instance.""" + assert self._is_root + assert torch_npu.npu.is_available() + # Stream for all-gathering parameters. + self._streams["all_gather"] = torch_npu.npu.Stream() + # Stream for overlapping grad reduction with the backward pass. + self._streams["post_backward"] = torch_npu.npu.Stream() + # Stream for pre-all-gather copies (e.g. H2D or precision cast). + self._streams["pre_all_gather"] = torch_npu.npu.Stream() + + def _wait_for_previous_optim_step(self) -> None: + """ + The root :class:`FullyShardedDataParallel` instance needs to + synchronize with the default stream to ensure that the previous + optimizer step is done. + """ + if not self._is_root: + return + current_stream = torch_npu.npu.current_stream() + self._streams["all_gather"].wait_stream(current_stream) + # Having the pre-all-gather stream wait for the current stream even if + # we do not leverage the pre-all-gather stream is tolerable since this + # only runs once per iteration + self._streams["pre_all_gather"].wait_stream(current_stream) + + def _prefetch_handles( + self, + current_handles_key: _HandlesKey, + ) -> None: + """ + Prefetches the next handles if needed (without synchronization). An + empty handles key cannot prefetch. + """ + if not current_handles_key: + return + handles_to_prefetch = self._get_handles_to_prefetch(current_handles_key) + for handles_key in handles_to_prefetch: + # Prefetch the next set of handles without synchronizing to allow + # the sync to happen as late as possible to maximize overlap + self._unshard(handles_key) + self._handles_prefetched[handles_key] = True + + def _get_handles_to_prefetch( + self, + current_handles_key: _HandlesKey, + ) -> List[_HandlesKey]: + """ + Returns a :class:`list` of the handles keys to prefetch for the next + module(s), where ``current_handles_key`` represents the current module. 
+ + "Prefetching" refers to running the unshard logic early (without + synchronization), and the "next" modules depend on the recorded + execution order and the current training state. + """ + training_state = self._get_training_state(current_handles_key) + valid_training_states = ( + HandleTrainingState.BACKWARD_PRE, + HandleTrainingState.BACKWARD_POST, + HandleTrainingState.FORWARD, + ) + p_assert( + training_state in valid_training_states, + f"Prefetching is only supported in {valid_training_states} but " + f"currently in {training_state}" + ) + eod = self._exec_order_data + target_handles_keys: List[_HandlesKey] = [] + if ( + ( + training_state == HandleTrainingState.BACKWARD_PRE + and self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE + ) + or ( + training_state == HandleTrainingState.BACKWARD_POST + and self.backward_prefetch == BackwardPrefetch.BACKWARD_POST + ) + ): + target_handles_keys = [ + target_handles_key for target_handles_key in + eod.get_handles_to_backward_prefetch(current_handles_key) + if self._needs_pre_backward_unshard.get(target_handles_key, False) + and not self._handles_prefetched.get(target_handles_key, False) + ] + elif ( + training_state == HandleTrainingState.FORWARD + and self.forward_prefetch + ): + target_handles_keys = [ + target_handles_key for target_handles_key in + eod.get_handles_to_forward_prefetch(current_handles_key) + if self._needs_pre_forward_unshard.get(target_handles_key, False) + and not self._handles_prefetched.get(target_handles_key, False) + ] + return target_handles_keys + + def _get_training_state( + self, + handles_key: _HandlesKey, + ) -> HandleTrainingState: + """Returns the training state of the handles in ``handles_key``.""" + p_assert(len(handles_key) > 0, "Expects a non-empty handles key") + training_states = set(handle._training_state for handle in handles_key) + p_assert( + len(training_states) == 1, + f"Expects uniform training state but got {training_states}" + ) + return next(iter(training_states)) + + @staticmethod + @contextlib.contextmanager + def state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + ) -> Generator: + """ + A context manager to set the ``state_dict_type`` of all the descendant + FSDP modules of the target module. The target module does not have to + be a FSDP module. If the target module is a FSDP module, its + ``state_dict_type`` will also be changed. + + .. note:: This API should be called for only the top-level (root) + module. + + .. note:: This API enables users to transparently use the conventional + ``state_dict`` API to take model checkpoints in cases where the + root FSDP module is wrapped by another ``nn.Module``. For example, + the following will ensure ``state_dict`` is called on all non-FSDP + instances, while dispatching into `local_state_dict` implementation + for FSDP: + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = DDP(FSDP(...)) + >>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + >>> checkpoint = model.state_dict() + + Args: + module (torch.nn.Module): Root module. + state_dict_type (StateDictType): the desired ``state_dict_type`` to set. + """ + prev_state_dict_type = None + prev_state_dict_config = None + # Use default config a state_dict config is not set. 
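+        # (i.e. instantiate the config class registered for `state_dict_type`
+        # in `_state_dict_type_to_config`)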
+ if state_dict_config is None: + state_dict_config = _state_dict_type_to_config[state_dict_type]() + for submodule in FullyShardedDataParallel.fsdp_modules(module): + if prev_state_dict_type is None: + prev_state_dict_type = submodule._state_dict_type + if prev_state_dict_config is None: + prev_state_dict_config = submodule._state_dict_config + if prev_state_dict_type != submodule._state_dict_type: + raise RuntimeError("All FSDP module should the same state_dict_type.") + if type(prev_state_dict_config) != type(submodule._state_dict_config): + raise RuntimeError( + "All FSDP modules should have the same type of state_dict_config." + ) + + expected_state_dict_config_type = _state_dict_type_to_config[state_dict_type] + if expected_state_dict_config_type != type(state_dict_config): + raise RuntimeError( + f"Expected state_dict_config of type {expected_state_dict_config_type} but got {type(state_dict_config)}" + ) + submodule._state_dict_type = state_dict_type + submodule._state_dict_config = state_dict_config + try: + yield + finally: + assert prev_state_dict_type is not None # Avoid mypy warning + assert prev_state_dict_config is not None # Avoid mypy warning + for submodule in FullyShardedDataParallel.fsdp_modules(module): + submodule._state_dict_type = prev_state_dict_type + submodule._state_dict_config = prev_state_dict_config + + def _convert_to_wrapped_module_name(self, module_name: str) -> str: + module_name = module_name.replace(f"{FPW_MODULE}.", "") + module_name = module_name.replace(f"{FPW_MODULE}", "") + if module_name: + module_name = f"{module_name}." + # Activation checkpoint adds a prefix that has to be + # removed as well. + module_name = module_name.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + return module_name + + @property + def _param_fqns(self) -> Iterator[Tuple[str, str, str]]: + for param_name, module_name in ( + self._fsdp_wrapped_module.handle.parameter_module_names() + ): + module_name = self._convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + def _full_post_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + ) -> Dict[str, Any]: + """ + Hook that runs after model.state_dict() is called before returning result to + user. For FSDP, we may have to clone the tensors in state_dict as params go + back to sharded version after _summon_full_params ends, and also remove + "_fsdp_wrapped_module" prefix. + """ + _replace_by_prefix(state_dict, prefix + f"{FSDP_WRAPPED_MODULE}.", prefix) + self._assert_state([TrainingState_.SUMMON_FULL_PARAMS]) + # Return early for trivial cases + if not state_dict or not self._fsdp_wrapped_module.has_params: + return state_dict + + # If the `FlatParameter` is registered, then this rank only needed to + # participate in the all-gather but does not actually save the state + # dict (e.g. when `rank0_only=True` and `self.rank != 0`) + if hasattr(self._fsdp_wrapped_module, "flat_param"): + return state_dict + + offload_to_cpu = self._state_dict_config.offload_to_cpu + cpu_device = torch.device("cpu") + + # Loop only the parameters saved in self._fsdp_wrapped_module to avoid + # processing buffers. + for fqn, param_name, module_name in self._param_fqns: + fqn = f"{prefix}{fqn}" + clean_key = fqn + clean_prefix = clean_tensor_name(prefix) + # Strip prefix out of key if needed as buffer names and param names + # do not have prefix considered as they are not computed in `state_dict` + # call. 
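+ # For example (illustrative): with prefix="model." and fqn="model.layer.weight",
+ # the stripped clean_key is "layer.weight", which is what is looked up in
+ # self._ignored_param_names below.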
+ if clean_key.startswith(clean_prefix): + clean_key = clean_key[len(clean_prefix):] + + # Clone non-ignored parameters before exiting the + # `_summon_full_params()` context + assert fqn in state_dict, ( + f"FSDP assumes {fqn} is in the state_dict but the state_dict " + f"only has {state_dict.keys()}. prefix={prefix}, " + f"module_name={module_name} param_name={param_name} rank={self.rank}." + ) + if clean_key not in self._ignored_param_names and \ + not getattr(state_dict[fqn], "_has_been_cloned", False): + try: + state_dict[fqn] = state_dict[fqn].cpu().clone().detach() + state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] + except BaseException as e: + warnings.warn( + f"Failed to clone() tensor with name {fqn}. This may mean " + "that this state_dict entry could point to invalid memory " + "regions after returning from state_dict() call if this " + "parameter is managed by FSDP. Please check clone " + f"implementation of {fqn}. Error: {str(e)}" + ) + + # Offload the buffer to CPU if needed -- we do not do this in + # `_summon_full_params()` since without care, that would free + # the original buffer's NPU memory and require reallocating + # that memory later; this only affects the state dict's buffer + # variable and leaves the original buffer's NPU memory intact + if offload_to_cpu: + for clean_key in self._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_key.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + fqn = f"{prefix}{clean_key}" + if fqn not in state_dict: + # A buffer can be registered as non-persistent. + continue + if state_dict[fqn].device != cpu_device: + state_dict[fqn] = state_dict[fqn].to(cpu_device) + return state_dict + + def _local_post_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + ) -> Dict[str, Any]: + """ + This hook create a ShardedTensor from the local flat_param and replace + the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy + will happen. The underlying storage is the same. + """ + _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix) + if not self._fsdp_wrapped_module.has_params: + return state_dict + + # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor + # value as the flat_param but it is a pure Tensor because + # nn.Module.state_dict() will detach the parameter. Therefore, we need + # to get flat_param from the FlattenParamsWrapper to get the metadata. + flat_param = getattr(self._fsdp_wrapped_module, FLAT_PARAM, None) + assert flat_param is not None + # Construct a ShardedTensor from the flat_param. + full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] + shard_offset = flat_param.numel() * self.rank + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + if valid_data_size > 0 and flat_param._shard_numel_padded > 0: + flat_param = flat_param.narrow(0, 0, valid_data_size) + local_shards = [ + Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank) + ] + state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards( + local_shards, full_numel, process_group=self.process_group + ) # type: ignore[assignment] + + return state_dict + + @torch.no_grad() + def _sharded_post_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + ) -> Dict[str, Any]: + """ + The hook replaces the unflattened, unsharded parameter in the state_dict + with a unflattened, sharded parameter (a ShardedTensor). 
+ """ + _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix) + if not self._fsdp_wrapped_module.has_params: + return state_dict + + assert self.training_state != TrainingState_.SUMMON_FULL_PARAMS, ( + "Inside _sharded_post_load_state_dict_hook, the training_state must " + "not be SUMMON_FULL_PARAMS." + ) + with self._summon_full_params(recurse=False, writeback=False): + for fqn, _, _ in self._param_fqns: + # Create a ShardedTensor for the unflattened, non-sharded parameter. + param = functools.reduce(getattr, fqn.split("."), self.module) + state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor( + tensor=param, + rank=self.rank, + world_size=self.world_size, + num_devices_per_node=torch_npu.npu.device_count(), + pg=self.process_group + ) # type: ignore[assignment] + state_dict.pop(f"{prefix}{FLAT_PARAM}") + return state_dict + + @staticmethod + def _post_state_dict_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this + FSDP module is executed. ``self._state_dict_type`` is used to decide + what postprocessing will be done. + """ + self = cast(FullyShardedDataParallel, module) + processed_state_dict = self._post_state_dict_hook_fn[self._state_dict_type](state_dict, prefix) + # Restore buffers, which currently are in their full precision type, + # back to their mixed precision type. This is because buffers are cast + # during lazy_init() and stay at their mixed precision type before/after + # forward/backward. As a result state_dict() should maintain this. + if ( + self._is_root + and self._mixed_precision_enabled_for_buffers() + ): + self._cast_buffers(recurse=True) + return processed_state_dict + + def state_dict(self, *args, **kwargs): + """ + This is the entry point of all three FSDP ``state_dict`` APIs: full, + local, and sharded. For the full state dict + (``StateDictType.FULL_STATE_DICT``), FSDP attempts to unshard the model + on all ranks, which may result in an OOM error if the full model cannot + fit on a single NPU. In that case, users may pass in a + :class:`FullStateDictConfig` to only save the checkpoint on rank 0 and/ + or to offload it to CPU memory layer by layer, enabling much larger + checkpoints. If the full model cannot fit in CPU memory, then users may + instead take a local state dict (``StateDictType.LOCAL_STATE_DICT``) + that only saves the local shard of the model. The sharded state dict + (``StateDictType.SHARDED_STATE_DICT``) saves the model parameters as + ``ShardedTensor`` s. The ``state_dict`` type can be configured using + the :meth:`state_dict_type` context manager. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> torch_npu.npu.set_device(device_id) + >>> my_module = nn.Linear(...) + >>> sharded_module = FSDP(my_module) + >>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config): + >>> full_dict = sharded_module.state_dict() + >>> full_dict.keys() + >>> odict_keys(['weight', 'bias']) + >>> # using local state dict + >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT): + >>> local_dict = sharded_module.state_dict() + >>> local_dict.keys() + >>> odict_keys(['flat_param', 'inner.flat_param']) + + .. 
warning:: This needs to be called on all ranks, since synchronization + primitives may be used. + """ + # TODO (rohan-varma): separate these out once a state_dict pre-hook + # is available. + if torch_npu.npu.is_available(): + torch_npu.npu.synchronize() + self._lazy_init() + if self._state_dict_type == StateDictType.FULL_STATE_DICT: + # Get config args + full_state_dict_config = ( + self._state_dict_config if self._state_dict_config is not None + else FullStateDictConfig() + ) + rank0_only = full_state_dict_config.rank0_only + offload_to_cpu = full_state_dict_config.offload_to_cpu + summon_ctx = ( + self._summon_full_params( + recurse=False, writeback=False, offload_to_cpu=offload_to_cpu, rank0_only=rank0_only + ) + if self.training_state != TrainingState_.SUMMON_FULL_PARAMS else + contextlib.suppress() + ) + with summon_ctx: + # Since buffers are not sharded and stay casted, restore them to their + # original user module specified types for checkpoint. We take care to + # recast in post_state_dict_hook for consistency with the fact that + # buffers stay casted after forward/backward. We must have the + # call here instead of above because _summon_full_params itself + # calls _lazy_init() which would cast the buffers. + if ( + self._is_root + and self._mixed_precision_enabled_for_buffers() + ): + self._cast_buffers( + dtype=self._buffer_name_to_orig_dtype, recurse=False + ) + state_dict = super().state_dict(*args, **kwargs) + + # TODO: support offload to CPU in post state dict hook. + if not rank0_only or self.rank == 0: + return state_dict + else: + return {} + + elif ( + self._state_dict_type == StateDictType.LOCAL_STATE_DICT or + self._state_dict_type == StateDictType.SHARDED_STATE_DICT + ): + if ( + self._fsdp_wrapped_module.flat_param is not None and + not self._fsdp_wrapped_module.handle.uses_sharded_strategy + ): + raise RuntimeError( + "sharded_state_dict/local_state_dict can only be called " + "when parameters are flatten and sharded." + ) + return super().state_dict(*args, **kwargs) + else: + raise ValueError(f"Unknown StateDictType {self._state_dict_type}.") + + def _local_state_dict(self, *args: Any, **kwargs: Any) -> Any: + """ + Returns the local state of the module. Parameters are flattened and + sharded, so the resulting state_dict can only be loaded after the module + has been wrapped with FSDP. + """ + with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT): + return self.state_dict(*args, **kwargs) + + def _full_post_load_state_dict_hook(self, *args, **kwargs) -> None: + # We should exit summon_full_params context. + self._assert_state([TrainingState_.SUMMON_FULL_PARAMS]) + assert getattr(self, '_full_param_ctx', None) is not None + self._full_param_ctx.__exit__(None, None, None) + self._full_param_ctx = None + + def _sharded_state_dict(self, *args: Any, **kwargs: Any) -> Any: + """ + Returns the sharded states of the module. Parameters are unflattened and + sharded, so the resulting state_dict can be used with any parallelism + (e.g., DPP, model parallelism, and single trainer) after a valid + resharding. + """ + with self.set_state_dict_type(StateDictType.SHARDED_STATE_DICT): + return self.state_dict(self, *args, **kwargs) + + def _full_pre_load_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + ) -> None: + # We do not expect to be calling pre-hooks twice without post-hook + # call in between. + assert getattr(self, '_full_param_ctx', None) is None + # Note that it needs writeback=True to persist. 
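+ # Illustrative pairing of this hook with its post-load counterpart:
+ #   _full_pre_load_state_dict_hook  -> _summon_full_params().__enter__()
+ #   (the module then loads the full, unflattened parameters)
+ #   _full_post_load_state_dict_hook -> __exit__(), writing back local shards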
+ self._full_param_ctx = self._summon_full_params( + recurse=False, writeback=True + ) + self._full_param_ctx.__enter__() + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_WRAPPED_MODULE}.") + + def _local_post_load_state_dict_hook(self, *args, **kwargs) -> None: + pass + + def _local_pre_load_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + ) -> None: + """ + This hook finds the local flat_param for this FSDP module from the + state_dict. The flat_param should be a ShardedTensor. This hook converts + the ShardedTensor to a tensor. No copy happen unless padding is required. + """ + _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.") + fqn = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}" + if fqn not in state_dict: + assert getattr(self._fsdp_wrapped_module, FLAT_PARAM, None) is None, ( + "No flat parameter in state_dict but self._fsdp_wrapped_module.flat_param is not None" + ) + return + load_tensor = state_dict[fqn] + assert isinstance( + load_tensor, ShardedTensor + ), "Tensors in local_state_dict should be ShardedTensor." + + # Convert the ShardedTensor to a Tensor. + shards = load_tensor.local_shards() + assert len(shards), "load_local_state_dict assume one shard per ShardedTensor." + load_tensor = cast(torch.Tensor, shards[0].tensor) + + # Get the metada of the flat_param to decide whether to pad the loaded + # tensor. + flat_param = self._fsdp_wrapped_module.flat_param + assert flat_param is not None + if flat_param._shard_numel_padded not in (0, flat_param.numel()): + assert load_tensor.numel() < flat_param.numel(), ( + f"Local shard size = {flat_param.numel()} and the tensor in " + f"the state_dict is {load_tensor.numel()}." + ) + load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded]) + state_dict[fqn] = load_tensor + + def _sharded_post_load_state_dict_hook(self, *args, **kwargs) -> None: + pass + + def _sharded_pre_load_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + ) -> None: + """ + The hook combines the unflattened, sharded parameters (ShardedTensor) to + a new FlatParameter and shards the new FlatParameter to the local chunk. + """ + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_WRAPPED_MODULE}.") + if not self._fsdp_wrapped_module.has_params: + return + + if not self._fsdp_wrapped_module.handle.uses_sharded_strategy: + raise RuntimeError( + "load_sharded_state_dict can only be called when parameters " + "are flatten and sharded." + ) + + nonsharded_tensors = [] + # TODO: Reduce the communication by using only one _all_gather_base to + # gather all the parameters in this layer. This can be achieved by + # concatenated all the local shards and then append the padding. 
+ # https://github.com/pytorch/pytorch/issues/77461 + for (param_name, _, module_name) in self._fsdp_wrapped_module.handle.flat_param._param_infos: + module_name = self._convert_to_wrapped_module_name(module_name) + fqn = f"{prefix}{FSDP_WRAPPED_MODULE}.{module_name}{param_name}" + param = state_dict.pop(fqn) + + # All-gather the param (ShardedTensor) + param, shards = _ext_pre_load_state_dict_transform(param) + assert len(shards) < 2, ( + f"Expects 0 or 1 shard per rank but got {len(shards)} shards on rank {self.rank}" + ) + param_numel = param.size().numel() + dim_0_size = param.size()[0] + chunk_size = ( + math.ceil(dim_0_size / self.world_size) * param_numel // dim_0_size + ) + if shards: + local_tensor = cast(torch.Tensor, shards[0].tensor).flatten() + if not local_tensor.is_npu: + local_tensor = local_tensor.npu() + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros(chunk_size, dtype=param.dtype).npu() + tensor = torch.empty( + chunk_size * self.world_size, dtype=local_tensor.dtype + ).npu() + dist._all_gather_base(tensor, local_tensor, group=self.process_group) + tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) + nonsharded_tensors.append(tensor) + + # Create a new flat_param from the loaded, non-sharded tensors. + flat_param = self._fsdp_wrapped_module.flat_param + loaded_flat_param = FlatParamHandle.flatten_params(nonsharded_tensors, requires_grad=False) + + # Get the chunk from the loaded flat_param for the local rank. + loaded_flat_param, num_to_pad = FlatParamHandle._get_shard( + loaded_flat_param, self.rank, self.world_size, + ) + loaded_flat_param.to(flat_param.device) + assert flat_param.numel() == loaded_flat_param.numel(), ( + f"The loaded local chunk has different numel({flat_param.numel()}) " + f"from the local chunk {flat_param.numel()}." + ) + assert flat_param._shard_numel_padded == num_to_pad, ( + f"The loaded local chunk has different padding({num_to_pad}) " + f"from the local chunk {flat_param._shard_numel_padded}." + ) + state_dict[f"{prefix}_fsdp_wrapped_module.flat_param"] = loaded_flat_param + + @staticmethod + def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` + is called. ``self._state_dict_type`` is used to decide what preprocessing + will be done. + """ + # Code that is common for all state_dict impls + self = cast(FullyShardedDataParallel, module) + if torch_npu.npu.is_available(): + torch_npu.npu.synchronize() + # Dispatch into state_dict specific implementation of pre-hook. + self._pre_load_state_dict_hook_fn[self._state_dict_type](state_dict, prefix) + + @staticmethod + def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: + # Code that is common for all state_dict impls + self = cast(FullyShardedDataParallel, module) + # Dispatch into state_dict type specific implementation of post-hook for + # loading state_dict. + self._post_load_state_dict_hook_fn[self._state_dict_type]() + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + *args, + **kwargs, + ) -> NamedTuple: + """ + The entry point of all three FSDP ``load_state_dict`` APIs. By default, + calling ``load_state_dict`` on an FSDP module will result in FSDP + attempting to load a "full" state_dict, i.e. a state_dict consisting of + full, unsharded, unflattened original module parameters. 
This requires + FSDP to load the full parameter context on each rank which could result + in NPU OOM. As a result, :func:`state_dict_type` API is available to + configure between ``load_state_dict`` implementations. User can thus use + ``with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)`` context + manager to load a local state dict checkpoint that will restore only + local shards of the module. Currently, the only supported + implementations are ``StateDictType.LOCAL_STATE_DICT`` and + ``StateDictType.FULL_STATE_DICT`` (default). Please see :func:`state_dict` + for documentation around creating an FSDP checkpoint. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> torch_npu.npu.set_device(device_id) + >>> my_module = nn.Linear(...) + >>> sharded_module = FSDP(my_module) + >>> checkpoint = torch.load(PATH) + >>> full_state_dict = checkpoint['full_state_dict'] + >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT): + >>> sharded_module.load_state_dict(full_state_dict) + >>> full_dict.keys() + >>> odict_keys(['weight', 'bias']) + >>> # using local state dict + >>> local_state_dict = checkpoint['local_state_dict'] + >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT): + >>> sharded_module.load_state_dict(local_state_dict) + >>> local_dict.keys() + >>> odict_keys(['flat_param', 'inner.flat_param']) + + .. warning:: This needs to be called on all ranks, since synchronization + primitives may be used. + """ + return super().load_state_dict(state_dict, *args) + + def _load_local_state_dict( + self, + state_dict: Mapping[str, Any], + *args, + ) -> NamedTuple: + """ + Load states from a flattened, sharded state dictionary. + """ + with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT): + return self.load_state_dict(state_dict, *args) + + def _load_sharded_state_dict( + self, + state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], + strict: bool = True, + ) -> NamedTuple: + """ + Load states from a unflattened, sharded state dictionary. + """ + with self.set_state_dict_type(StateDictType.SHARDED_STATE_DICT): + return self.load_state_dict(state_dict, strict) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """ + Runs the forward pass for the wrapped module, inserting FSDP-specific + pre- and post-forward sharding logic. 
+ """ + with torch.autograd.profiler.record_function("FullyShardedDataParallel.forward"): + self._lazy_init() + args, kwargs = self._fsdp_root_pre_forward(*args, **kwargs) + unused = None + unshard_fn = functools.partial(self._pre_forward_unshard, handles=self._handles) + # Do not free the root's parameters in the post-forward for + # `FULL_SHARD` with the intention that they are immediately used + # for backward computation (though this may not be true) + free_unsharded_flat_params = [ + handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD + for handle in self._handles + ] + reshard_fn = functools.partial( + self._reshard, + self._handles, + free_unsharded_flat_params, + ) + self._pre_forward(self._handles, unshard_fn, unused, unused) + for handle in self._handles: + p_assert( + handle.flat_param.device == self.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{self.compute_device} but got {handle.flat_param.device}" + ) + output = self._fsdp_wrapped_module(*args, **kwargs) + return self._post_forward(self._handles, reshard_fn, unused, unused, output) + + def _pre_forward( + self, + handles: List[FlatParamHandle], + unshard_fn: Optional[Callable], + module: nn.Module, + input: Any, + ): + """ + Runs the pre-forward logic. This includes an opportunity to unshard + currently sharded parameters such as those for the current forward and + registering post-backward hooks for these current parameters. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters + used in the current forward. + unshard_fn (Optional[Callable]): A callable to unshard any + currently sharded parameters or ``None`` to not do any + unsharding. + module (nn.Module): Unused; expected by the hook signature. + input (Any): Unused; expected by the hook signature. + """ + self.training_state = TrainingState_.FORWARD + self._exec_order_data.record_pre_forward(handles, self.training) + for handle in handles: + handle._training_state = HandleTrainingState.FORWARD + if unshard_fn is not None: + unshard_fn() + # Register post-backward hooks to reshard the parameters and + # reduce-scatter their gradients. They must be re-registered every + # forward pass in case the `grad_fn` is mutated. + self._register_post_backward_hooks(handles) + + def _pre_forward_unshard( + self, + handles: List[FlatParamHandle], + ) -> None: + """Unshards parameters in the pre-forward.""" + if handles: + self._unshard(handles) + handles_key = tuple(handles) + self._needs_pre_forward_unshard[handles_key] = False + torch_npu.npu.current_stream().wait_stream(self._streams["all_gather"]) + self._prefetch_handles(handles_key) + + def _post_forward( + self, + handles: List[FlatParamHandle], + reshard_fn: Optional[Callable], + module: nn.Module, + input: Any, + output: Any, + ) -> Any: + """ + Runs the post-forward logic. This includes an opportunity to reshard + currently unsharded parameters such as those used in the current + forward and registering pre-backward hooks on the forward outputs. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters + used in the current forward. + reshard_fn (Optional[Callable]): A callable to reshard any + currently unsharded parameters (e.g. from the current forward) + or ``None`` to not do any resharding. + module (nn.Module): Unused; expected by the hook signature. + input (Any): Unused; exepcted by the hook signature. + output (Any): Forward pass output; pre-backward hooks are + registered on the tensors that require gradients in this + output. 
+ + Postcondition: Each ``FlatParameter`` 's data points to the sharded + flattened parameter. + """ + self._exec_order_data.record_post_forward(handles) + if reshard_fn is not None: + reshard_fn() + # Register pre-backward hooks to unshard the flattened parameters + # for the gradient computation (if needed) + output = self._register_pre_backward_hooks(output, handles) + self.training_state = TrainingState_.IDLE + for handle in handles: + handle._training_state = HandleTrainingState.IDLE + return output + + def _cast_forward_inputs(self, *args, **kwargs): + """Moves the forward inputs to the compute device and casts them to the + appropriate dtype if needed.""" + # TODO: Do not use the side stream for tensor copies for now; + # investigate the perf with/without it + # TODO: For mixed precision, move the inputs to the compute device and + # cast to reduced-precision in a single `to()` call + args, kwargs = _to_kwargs(args, kwargs, self.compute_device.index, False) + args = args[0] + kwargs = kwargs[0] + if self._mixed_precision_enabled_for_params(): + input_dtype = self.mixed_precision.param_dtype + args, kwargs = self._cast_fp_inputs_to_dtype( + input_dtype, *args, **kwargs, + ) + return args, kwargs + + def _fsdp_root_pre_forward(self, *args, **kwargs): + """ + Runs pre-forward logic specific to the root FSDP instance, which should + run before any individual module's pre-forward. This includes + synchronizing with the previous iteration and casting the forward + inputs appropriately. If this is called on a non-root FSDP instance, + then the forward inputs are returned directly. + """ + p_assert(self._is_root is not None, "Expects a root FSDP to have been set") + if not self._is_root: + return args, kwargs + if self.forward_prefetch: + for fsdp_module in self.fsdp_modules(self): + handles_key = tuple(fsdp_module._handles) + if handles_key: + self._needs_pre_forward_unshard[handles_key] = True + self._wait_for_previous_optim_step() + args, kwargs = self._cast_forward_inputs(*args, **kwargs) + return args, kwargs + + @staticmethod + @contextlib.contextmanager + def summon_full_params( + module, + recurse: bool = True, + writeback: bool = True, + rank0_only: bool = False, + offload_to_cpu: bool = False, + ) -> Generator: + r""" A context manager to expose full params for FSDP instances. + Can be useful *after* forward/backward for a model to get + the params for additional processing or checking. It can take a non-FSDP + module and will summon full params for all contained FSDP modules as + well as their children, depending on the ``recurse`` argument. + + .. note:: This can be used on inner FSDPs. + .. note:: This can *not* be used within a forward or backward pass. Nor + can forward and backward be started from within this context. + .. note:: Parameters will revert to their local shards after the context + manager exits, storage behavior is the same as forward. + .. note:: The full parameters can be modified, but only the portion + corresponding to the local param shard will persist after the + context manager exits (unless ``writeback=False``, in which case + changes will be discarded). In the case where FSDP does not shard + the parameters, currently only when ``world_size == 1``, or ``NO_SHARD`` + config, the modification is persisted regardless of ``writeback``. + .. note:: This method works on modules which are not FSDP themselves but + may contain multiple independent FSDP units. In that case, the given + arguments will apply to all contained FSDP units. + + .. 
warning:: Note that ``rank0_only=True`` in conjunction with + ``writeback=True`` is not currently supported and will raise an + error. This is because model parameter shapes would be different + across ranks within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + + .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will + result in full parameters being redundantly copied to CPU memory for + NPUs that reside on the same machine, which may incur the risk of + CPU OOM. It is recommended to use ``offload_to_cpu`` with + ``rank0_only=True``. + + Args: + recurse (bool, Optional): recursively summon all params for nested + FSDP instances (default: True). + writeback (bool, Optional): if ``False``, modifications to params are + discarded after the context manager exits; + disabling this can be slightly more efficient (default: True) + rank0_only (bool, Optional): if ``True``, full parameters are + materialized on only global rank 0. This means that within the + context, only rank 0 will have full parameters and the other + ranks will have sharded parameters. Note that setting + ``rank0_only=True`` with ``writeback=True`` is not supported, + as model parameter shapes will be different across ranks + within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + offload_to_cpu (bool, Optional): If ``True``, full parameters are + offloaded to CPU. Note that this offloading currently only + occurs if the parameter is sharded (which is only not the case + for world_size = 1 or ``NO_SHARD`` config). It is recommended + to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid + redundant copies of model parameters being offloaded to the same CPU memory. + """ + # Note that we specify root_only as FSDP roots will handle summoning + # child FSDP instances based on recurse argument. + root_fsdp_modules = FullyShardedDataParallel.fsdp_modules( + module, root_only=True + ) + # Summon all params for all FSDP instances + with contextlib.ExitStack() as stack: + for module in root_fsdp_modules: + stack.enter_context( + module._summon_full_params( + recurse=recurse, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + ) + ) + # Yield to the caller, with full params in all FSDP instances. + yield + # Exiting from the ExitStack will reshard all params. + return + + @contextlib.contextmanager + def _summon_full_params( + self, + recurse: bool = True, + writeback: bool = True, + rank0_only: bool = False, + offload_to_cpu: bool = False, + ): + if writeback and rank0_only: + raise ValueError( + "writeback=True and rank0_only=True is not supported, as model " + "parameter shapes will be different across ranks, and writing " + "to them can lead to inconsistencies across ranks when the " + "context is exited." + ) + if offload_to_cpu and not rank0_only: + warnings.warn( + "offload_to_cpu and rank0_only=False will result in " + "full parameters being redundantly copied to CPU memory for " + "NPUs that reside on the same machine, which may incur the risk of " + "CPU OOM. It is recommended to use ``offload_to_cpu`` with " + "rank0_only=True." 
+ ) + + if recurse: + with contextlib.ExitStack() as stack: + for module in self.fsdp_modules(self): + stack.enter_context( + module._summon_full_params( + recurse=False, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + ) + ) + yield + return + + torch_npu.npu.synchronize() + self._lazy_init() + self._assert_state([TrainingState_.IDLE]) + for handle in self._handles: + assert handle._training_state == HandleTrainingState.IDLE + self.training_state = TrainingState_.SUMMON_FULL_PARAMS + for handle in self._handles: + handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS + + free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles] + self._unshard(self._handles) + torch_npu.npu.current_stream().wait_stream(self._streams["all_gather"]) + + if rank0_only and self.rank != 0: + # Free the unsharded flattened parameter early + self._reshard(self._handles, free_unsharded_flat_params) + try: + yield + finally: + self.training_state = TrainingState_.IDLE + for handle in self._handles: + handle._training_state = HandleTrainingState.IDLE + else: + # Unflatten the unsharded flattened parameters + with contextlib.ExitStack() as stack: + # Invariant: rank == 0 or !rank0_only + for handle in self._handles: + if offload_to_cpu and handle.uses_sharded_strategy: + stack.enter_context(handle.to_cpu()) + # TODO (awgu): This FPW call assumes 1 `FlatParameter` + stack.enter_context(self._fsdp_wrapped_module.unflatten_as_params()) + try: + yield + finally: + stack.close() + if writeback: + self._write_back_to_local_shard(self._handles) + self._reshard(self._handles, free_unsharded_flat_params) + self.training_state = TrainingState_.IDLE + for handle in self._handles: + handle._training_state = HandleTrainingState.IDLE + + @torch.no_grad() + def _write_back_to_local_shard(self, handles: List[FlatParamHandle]): + """ + For each handle, writes back the this rank's shard of the unsharded + flattened parameter to the sharded flattened parameter. + + Precondition: Each handle's ``FlatParameter`` 's data points to the + padded unsharded flattened parameter. + """ + for handle in handles: + # For `NO_SHARD`, `_local_shard` is the unsharded flattened + # parameter as well + if not handle.uses_sharded_strategy: + continue + assert ( + handle.flat_param.ndim == 1 + ), f"Expects `flat_param` to be flattened but got {handle.flat_param.shape}" + # Get the unpadded shard instead of the padded shard to persist + # user changes to the padding (though FSDP does not explicitly + # support this) + shard, _ = FlatParamHandle._get_unpadded_shard(handle.flat_param, handle.rank, handle.world_size) + handle.flat_param._local_shard[:shard.numel()].copy_(shard) + + def named_buffers( + self, + *args, + **kwargs, + ) -> Iterator[Tuple[str, torch.Tensor]]: + """ + Overrides :meth:`named_buffers()` to intercept buffer names and + remove all occurrences of the FSDP-specific flattened buffer prefix + when inside the :meth:`summon_full_params` context manager. 
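+
+ Example (a minimal sketch; names are illustrative)::
+
+ >>> # xdoctest: +SKIP("undefined variables")
+ >>> with FSDP.summon_full_params(fsdp_model):
+ >>>     buffers = dict(fsdp_model.named_buffers())
+ >>> # keys in ``buffers`` carry no FSDP-specific prefix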
+ """ + in_summon_full_params = self.training_state == TrainingState_.SUMMON_FULL_PARAMS + for buffer_name, buffer in super().named_buffers(*args, **kwargs): + if in_summon_full_params: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + buffer_name = buffer_name.replace(FSDP_PREFIX, "") + yield (buffer_name, buffer) + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """ + Overrides :meth:`named_parameters()` to intercept parameter names and + remove all occurrences of the FSDP-specific flattened parameter prefix + when inside the :meth:`summon_full_params` context manager. + """ + # Determine which logic to use based on the context at call time + in_summon_full_params = self.training_state == TrainingState_.SUMMON_FULL_PARAMS + for param_name, param in super().named_parameters(*args, **kwargs): + if in_summon_full_params: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + param_name = param_name.replace(FSDP_PREFIX, "") + yield (param_name, param) + + def _register_pre_backward_hooks( + self, + outputs: Any, + handles: List[FlatParamHandle], + ) -> Any: + """ + Registers pre-backward hooks on the tensors that require gradients in + the forward pass outputs ``outputs``, which were computed using the + ``FlatParameter`` s of ``handles``. + + Returns: + Forward pass outputs with pre-backward hooks registered to tensors + that require gradients. + """ + # If there is no gradient computation, then there is no need for + # pre-backward logic + if not torch.is_grad_enabled(): + return outputs + + if self._is_root: + self._post_backward_callback_queued = False # only defined on the root + + handles_key = tuple(handles) + if handles_key: + # Since these handles' `FlatParameter`s participated in a forward, + # we conservatively assume that they will be used in the backward + self._needs_pre_backward_unshard[handles_key] = False + self._ran_pre_backward_hook[handles_key] = False + + def _pre_backward_hook(_handles: List[FlatParamHandle], *unused: Any) -> None: + """Prepares ``_handles`` 's ``FlatParameter`` s for gradient + computation.""" + _handles_key = tuple(_handles) # avoid shadowing `handles_key` + # Only run the pre-backward hook once per group of handles involved + # in the same module forward computation + if _handles_key and self._ran_pre_backward_hook.get(_handles_key, False): + return + + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel._pre_backward_hook" + ): + # Queue the post-backward callback once for the root FSDP + # instance to attach it to the outermost backward graph task so + # that it is called after all backward calls complete + if self._is_root and not self._post_backward_callback_queued: + self._queue_wait_for_post_backward() + elif _handles_key: + self._assert_state([TrainingState_.IDLE]) + self.training_state = TrainingState_.BACKWARD_PRE + # Queueing the post-backward callback is the only logic that is + # not per-handle in the pre-backward hook, so we can return + # early here if there are no handles. 
+ if not _handles_key: + return + for handle in _handles: + handle._training_state = HandleTrainingState.BACKWARD_PRE + + # If the handles have been prefetched, this `_unshard()` simply + # switches to using the unsharded parameter + self._unshard(_handles) + torch_npu.npu.current_stream().wait_stream(self._streams["all_gather"]) + + # Set this to `False` to ensure that a mistargeted prefetch + # does not actually unshard these handles + self._needs_pre_backward_unshard[_handles_key] = False + self._prefetch_handles(_handles_key) + for handle in _handles: + handle.prepare_gradient() + self._ran_pre_backward_hook[_handles_key] = True + + def _register_hook(t: torch.Tensor) -> torch.Tensor: + if t.requires_grad: + t.register_hook(functools.partial(_pre_backward_hook, handles)) + self._needs_pre_backward_unshard[handles_key] = True + return t + + return _apply_to_tensors(_register_hook, outputs) + + def _register_post_backward_hooks( + self, + handles: List[FlatParamHandle], + ) -> None: + """ + Registers post-backward hooks on the ``FlatParameter`` s' + ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients. + + The ``AccumulateGrad`` object represents the last function that + finalizes the ``FlatParameter`` 's gradient, so it only runs after its + entire gradient computation has finished. + + We register the post-backward hook only once in the *first* forward + that a ``FlatParameter`` participates in. This relies on the + ``AccumulateGrad`` object being preserved through multiple forwards. + """ + # If there is no gradient computation, then there is no need for + # post-backward logic + if not torch.is_grad_enabled(): + return + for handle in handles: + flat_param = handle.flat_param + already_registered = hasattr(flat_param, "_post_backward_hook_state") + if already_registered or not flat_param.requires_grad: + continue + # Get the `AccumulateGrad` object + temp_flat_param = flat_param.expand_as(flat_param) + p_assert( + temp_flat_param.grad_fn is not None, + "The `grad_fn` is needed to access the `AccumulateGrad` and " + "register the post-backward hook" + ) + acc_grad = temp_flat_param.grad_fn.next_functions[0][0] + hook_handle = acc_grad.register_hook( + functools.partial(self._post_backward_hook, handle) + ) + flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined] + + @torch.no_grad() + def _post_backward_hook( + self, + handle: FlatParamHandle, + *unused: Any, + ) -> None: + """ + Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``. + + Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the + unsharded gradient for the local batch. + + Postcondition: + - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced + unsharded gradient. + - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded + gradient (accumulating with any existing gradient). + """ + param = handle.flat_param + param._post_backward_called = True + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel._post_backward_hook" + ): + # First hook callback will see PRE state. If we have multiple params, + # then subsequent hook callbacks will see POST state. 
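+ # Illustrative sequence with two FlatParameters A and B in one backward pass:
+ #   hook(A): training_state == BACKWARD_PRE  -> switched to BACKWARD_POST below
+ #   hook(B): training_state == BACKWARD_POST -> stays BACKWARD_POST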
+ self._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST]) + self.training_state = TrainingState_.BACKWARD_POST + handle._training_state = HandleTrainingState.BACKWARD_POST + + if self._use_param_exec_order_policy() and self._param_exec_order_prep_stage: + # In self._fsdp_params_exec_order, the parameters are ordered based on + # the execution order in the backward pass in the first iteration. + self._fsdp_params_exec_order.append(param) + + if param.grad is None: + return + if param.grad.requires_grad: + raise RuntimeError( + "FSDP only works with gradients that don't require gradients" + ) + + free_unsharded_flat_param = self._should_free_unsharded_flat_param(handle) + self._reshard([handle], [free_unsharded_flat_param]) + + # TODO (awgu): Post-backward prefetching does not support the + # multiple handles per module case (which was why we keyed by + # *tuple*). The post-backward hook runs per handle, not per group + # of handles. To generalize this, we may need a 2-level mapping, + # where we map each individual handle to its groups of handles and + # then from the groups of handles to their indices in the order. + handles_key = (handle,) + self._prefetch_handles(handles_key) + + if not self._sync_gradients: + return + + # Wait for all ops in the current stream (e.g. gradient + # computation) to finish before reduce-scattering the gradient + self._streams["post_backward"].wait_stream(torch_npu.npu.current_stream()) + + with torch_npu.npu.stream(self._streams["post_backward"]): + orig_grad_data = param.grad.data + if ( + self._mixed_precision_enabled_for_reduce() + and not self._low_precision_hook_enabled() + ): + # Cast gradient to precision in which it should be communicated. + # If a low precision hook is registered and reduce_dtype is specified + # in `MixedPrecision`, communication hook will take care of + # casting to lower precision and back. + # TODO: Make this a communication hook when communication hooks + # are implemented for FSDP. Note that this is a noop if the + # reduce_dtype matches the param dtype. + param.grad.data = param.grad.data.to(self.mixed_precision.reduce_dtype) + + if self._exec_order_data.is_first_iter: + # For all sharding strategies communication is performed through `_communication_hook`: + # default comm hooks are: `reduce_scatter` for sharded strategies and + # `all_reduce` for non-sharded strategies. This checks asserts that `_communication_hook` + # and `_communication_hook_state`, required for communication not `None`.` + p_assert( + self._communication_hook is not None, + "Communication hook should not be None" + ) + p_assert( + self._communication_hook_state is not None, + "Communication hook state should not be None" + ) + grad = param.grad.data + if handle.uses_sharded_strategy: + # We clear `param.grad` to permit repeated gradient + # computations when this FSDP module is called multiple times. + # This is to avoid a race among multiple re-entrant backward + # passes. For example, the second backward pass computation + # precedes ahead of the first backward pass reduction, which is + # possible since the reduction is in a different stream and is + # async. Then, the first backward pass may be incorrectly + # reducing the second backward pass's `param.grad`. + # The reduced gradients are accumulated in + # `param._saved_grad_shard`, and the gradient reductions can + # happen in arbitrary order, though we tolerate this due to the + # (approximate) commutativity of floating-point addition. 
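+ # Worked example for the padding below (illustrative): with world_size=4 and
+ # grad.numel()=10, chunks[0].numel() == 3, so num_pad == 4*3 - 10 == 2; the
+ # flattened gradient is padded to 12 elements and each rank receives a
+ # 3-element reduce-scattered shard in `output`.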
+ param.grad = None + grad_flatten = torch.flatten(grad) + chunks = list(grad_flatten.chunk(self.world_size)) + num_pad = self.world_size * chunks[0].numel() - grad.numel() + input_flattened = F.pad(grad_flatten, [0, num_pad]) + output = torch.zeros_like(chunks[0]) + self._communication_hook(self._communication_hook_state, input_flattened, output) + + self._cast_grad_to_param_dtype(output, param) + + # To support gradient accumulation outside `no_sync()`, we save + # the gradient data to `param._saved_grad_shard` before the + # backward pass, accumulate gradients into it here, and set + # `param.grad` with the accumulated value at the end of the + # backward pass in preparation for the optimizer step. + accumulate_grad = hasattr(param, "_saved_grad_shard") + if accumulate_grad: + p_assert( + param._saved_grad_shard.shape == output.shape, # type: ignore[attr-defined] + "Shape mismatch when accumulating gradients: " # type: ignore[attr-defined] + f"existing grad shape={param._saved_grad_shard.shape} " + f"new grad shape={output.shape}" # type: ignore[attr-defined] + ) + p_assert( + param._saved_grad_shard.device == output.device, # type: ignore[attr-defined] + "Device mismatch when accumulating gradients: " # type: ignore[attr-defined] + f"existing grad device={param._saved_grad_shard.device} " + f"new grad device={output.device}" # type: ignore[attr-defined] + ) + param._saved_grad_shard += output # type: ignore[attr-defined] + else: + param._saved_grad_shard = output # type: ignore[attr-defined] + grad = param._saved_grad_shard # type: ignore[attr-defined] + else: + if self.sharding_strategy == ShardingStrategy.NO_SHARD: + self._communication_hook(self._communication_hook_state, param.grad) + + # For NO_SHARD keeping grads in the reduced precision, we + # can simply omit the cast as needed, we can't do this for + # other sharding strategies because grad field is assigned + # in _finalize_params. TODO (rvarm1) this divergence in + # logic is not ideal. + if not self._mixed_precision_keep_low_precision_grads(): + self._cast_grad_to_param_dtype(param.grad, param) + + # Regardless of sharding or not, offload the grad to CPU if we are + # offloading params. This is so param and grad reside on same device + # which is needed for the optimizer step. + if handle._config.offload_params: + # We specify non_blocking=True + # and ensure the appropriate synchronization is done by waiting + # streams in _wait_for_post_backward. + param._cpu_grad.copy_( # type: ignore[attr-defined] + grad.detach(), non_blocking=True + ) + # Don't let this memory get reused until after the transfer. + grad.data.record_stream(torch_npu.npu.current_stream()) + + # After _post_backward_hook returns, orig_grad_data will eventually + # go out of scope, at which point it could otherwise be freed for + # further reuse by the main stream while the div/reduce_scatter/copy + # are underway in the post_backward stream. See: + # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py + orig_grad_data.record_stream(self._streams["post_backward"]) + + def _cast_grad_to_param_dtype( + self, + grad: torch.Tensor, + param: FlatParameter, + ): + """ + Casts gradient ``grad`` back to the full parameter dtype so that the + optimizer step runs with that dtype. This performs an actual cast if + 1. parameters were in reduced precision during the forward since then + gradients would be in that reduced precision, or + 2. parameters were not in reduced precision but gradients were in + reduced precision for communication. 
+ However, if a low precision communication hook is registered, then this + dtype cast happens in the hook instead. + """ + self._assert_state(TrainingState_.BACKWARD_POST) + if ( + not self._low_precision_hook_enabled() + and ( + self._mixed_precision_enabled_for_params() + or self._mixed_precision_enabled_for_reduce() + ) + ): + low_prec_grad_data = grad.data + grad.data = grad.data.to(dtype=param.dtype) + # Do not let the low precision gradient memory get reused until + # the cast to full parameter precision completes + low_prec_grad_data.record_stream(torch_npu.npu.current_stream()) + + def _should_free_unsharded_flat_param(self, handle: FlatParamHandle): + return ( + (self._sync_gradients and handle.uses_sharded_strategy) + or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD + ) + + def _queue_wait_for_post_backward(self) -> None: + """ + Queues a post-backward callback from the root FSDP instance, which + should happen at the beginning of its pre-backward. + """ + p_assert( + self._is_root, + "`_queue_wait_for_post_backward()` should be called on the root FSDP instance" + ) + if self._post_backward_callback_queued: + return + self._assert_state([TrainingState_.IDLE]) + self._post_backward_callback_queued = True + Variable._execution_engine.queue_callback(self._wait_for_post_backward) + + @torch.no_grad() + def _wait_for_post_backward(self) -> None: + """Wait for post-backward to finish. Only called on root instance.""" + assert self._is_root, "_wait_for_post_backward can only be called on root." + # Root's training state might be backward_pre or backward_post depending on + # if root parameter's post backward hook was called. The post-backward hook + # may not have been called if gradient was not computed for this param/FSDP + # module. + + if self._sync_gradients: + torch_npu.npu.current_stream().wait_stream(self._streams["post_backward"]) + if self.cpu_offload.offload_params: + # We need to wait for the non-blocking NPU -> + # CPU grad transfers to finish. We need to do this for NPU -> CPU + # copies because when grad is on CPU, it won't wait for any NPU + # stream to finish NPU -> CPU copies unless we explicitly block the + # host-side with synchronize(). + torch_npu.npu.current_stream().synchronize() + self._exec_order_data.next_iter() + + # A backward pass is done, clean up below. + def _catch_all_reshard(fsdp_module: FullyShardedDataParallel) -> None: + """ + Reshards full parameters that may have not been resharded in + post_backward_hook. This can happen when an FSDP module's output + is used in forward so its pre-backward fires unsharding the param, + but post-backward does not fire since the output was not ultimately + used in loss computation so FSDP parameter did not get a gradient. + """ + # Note that we wrap resharding logic in a try-catch as a defensive + # approach, as if an error is thrown, we are in the backwards pass, + # and autograd would not print out much useful info about the actual + # error hit. 
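+ # (Illustrative) "already resharded" below means the handle's flat_param
+ # currently aliases its local shard, detected by comparing data_ptr() of
+ # flat_param and flat_param._local_shard.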
+ try: + free_unsharded_flat_params: List[bool] = [] + handles_to_reshard: List[FlatParamHandle] = [] + for handle in fsdp_module._handles: + # TODO: This already-resharded check is brittle: + # https://github.com/pytorch/pytorch/issues/83956 + already_resharded = ( + handle.flat_param.data_ptr() == handle.flat_param._local_shard.data_ptr() + ) + if already_resharded: + continue + free_unsharded_flat_params.append(self._should_free_unsharded_flat_param(handle)) + handles_to_reshard.append(handle) + self._reshard(handles_to_reshard, free_unsharded_flat_params) + except Exception as e: + p_assert( + False, + f"Got exception while resharding module {fsdp_module}: {str(e)}", + raise_assertion_error=False + ) + raise e + + def _finalize_params(fsdp_module: FullyShardedDataParallel) -> None: + """Helper used below on all fsdp modules.""" + for handle in fsdp_module._handles: + p = handle.flat_param + if p.requires_grad: + if hasattr(p, "_post_backward_hook_state"): + p_assert( + len(p._post_backward_hook_state) == 2, # type: ignore[attr-defined] + "p._post_backward_hook_state fields are not valid." + ) + p._post_backward_hook_state[1].remove() # type: ignore[attr-defined] + delattr(p, "_post_backward_hook_state") + # Preserve the gradient accumulation state if not + # synchronizing: `p.grad` remains the unsharded gradient + # accumulated from prior `no_sync()` iterations, and + # `p._saved_grad_shard` remains the sharded gradient from + # the last synchronized iteration + if not self._sync_gradients: + continue + # Set `p.grad` as needed to ensure optimizer correctness + # since optimizers operate on the `grad` attribute + if hasattr(p, "_cpu_grad"): + p_assert( + p.device == torch.device("cpu"), + f"Device mismatch: p={p.device} " # type: ignore[attr-defined] + f"p._cpu_grad={p._cpu_grad}" + ) + p.grad = p._cpu_grad # type: ignore[attr-defined] + elif hasattr(p, "_saved_grad_shard"): + p_assert( + p.device == p._saved_grad_shard.device, # type: ignore[attr-defined] + f"Device mismatch: p={p.device} " # type: ignore[attr-defined] + f"p._saved_grad_shard={p._saved_grad_shard.device}" + ) + # Check if post-backward was called for this param (FSDP unit). + # TODO: This logic will have to be revisited when non-recursive wrapping + # lands. If it was not called, there is no new gradient to accumulate + if p._post_backward_called: + p.grad = p._saved_grad_shard + if fsdp_module._mixed_precision_keep_low_precision_grads(): + p.grad.data = p.grad.to( + fsdp_module.mixed_precision.param_dtype + ) + else: + p_assert( + not handle.uses_sharded_strategy or not p._post_backward_called, + "All sharded parameters that received a gradient " + "should use `_saved_grad_shard`" + ) + if hasattr(p, "_saved_grad_shard"): + delattr(p, "_saved_grad_shard") + + p_assert( + hasattr(p, '_post_backward_called'), + "Expected flag _post_backward_called to be set on param." + ) + # Reset _post_backward_called in preparation for the next iteration. + p._post_backward_called = False + + # Update root and nested FSDP's hooks and flags. 
+ for m in self.fsdp_modules(self): # includes self + _finalize_params(m) + _catch_all_reshard(m) + m._ran_pre_backward_hook.clear() + m.training_state = TrainingState_.IDLE + for handle in m._handles: + handle._training_state = HandleTrainingState.IDLE + m._handles_prefetched.clear() + if m._is_root: + # reset this flag for cases like "one forward pass + multiple backward passes" + self._post_backward_callback_queued = False + + if self._use_param_exec_order_policy() and self._param_exec_order_prep_stage: + self._param_exec_order_policy_second_iter_init() + + def _param_exec_order_policy_second_iter_init(self) -> None: + self._param_exec_order_prep_stage = False + # Let the parameters in self._fsdp_params_exec_order ordered based on + # the execution order in the forward pass. + self._fsdp_params_exec_order.reverse() + for m in self.modules(): + if m is not self and isinstance(m, FullyShardedDataParallel): + assert hasattr( + m, "_param_exec_order_policy" + ), "Non-root FSDP modules should also have _param_exec_order_policy attribute" + assert hasattr( + m, "_param_exec_order_prep_stage" + ), "Non-root FSDP modules should also have _param_exec_order_prep_stage attribute" + m._param_exec_order_prep_stage = False + # TODO (linjianma): Construct a fsdp_wrap_map whose keys are all children modules with a FSDP wrap, + # and values are its FSDP wraps. These children FSDP wraps will be detached from the root FSDP module + # and will be used to schedule the parameters (rebuild_full_params and reshard). + # TODO (linjianma): Remove all internal FSDP wraps from the root FSDP module. + # TODO (linjianma): Based on self._fsdp_params_exec_order, get the information + # needed to patch the forward() function of each key in the fsdp_wrap_map. The rules are as follows: + # 1: Before each forward(), rebuild_full_params of all parameters that are currently sharded and + # will be used in the forward, and reshard all parameters that are currently full and will not be + # used in the next forward() + # 2: After each forward(), reshard all parameters just used in the forward, and rebuild_full_params of + # all parameters that will be used next. + # TODO (linjianma): Patch the forward of each model in the keys + # of fsdp_wrap_map based on the information above. + + def _assert_state(self, state: Union[TrainingState_, List[TrainingState_]]) -> None: + """Assert we are in the given state.""" + # Since assert can be turned off and this error checking + # is really important, we use explicit error checking + # and raise a ValueError if needed. + if isinstance(state, TrainingState_): + state = [state] + if self.training_state not in state: + msg = ( + f"expected to be in states {state} but current state " + f"is {self.training_state}" + ) + # In case we are failing in the context of autograd hook, asserting + # may not generate useful msg. So, let's print it to be sure. + if self.rank == 0: + print(f"Asserting FSDP instance is: {self}") + print(f"ERROR: {msg}") + traceback.print_stack() + raise ValueError(msg) + + @contextmanager + def no_sync(self) -> Generator: + """ + A context manager to disable gradient synchronizations across FSDP + instances. Within this context, gradients will be accumulated in module + variables, which will later be synchronized in the first + forward-backward pass after exiting the context. This should only be + used on the root FSDP instance and will recursively apply to all + children FSDP instances. + + .. 
note:: This likely results in higher memory usage because FSDP will + accumulate the full model gradients (instead of gradient shards) + until the eventual sync. + + .. note:: When used with CPU offloading, the gradients will not be + offloaded to CPU when inside the context manager. Instead, they + will only be offloaded right after the eventual sync. + """ + self._lazy_init() + assert self._is_root, "`no_sync()` on inner FSDP instances is not supported" + self._assert_state(TrainingState_.IDLE) + old_flags = [] + for m in self.modules(): + if isinstance(m, FullyShardedDataParallel): + old_flags.append((m, m._sync_gradients)) + m._sync_gradients = False + try: + yield + finally: + for m, old_flag in old_flags: + assert not m._sync_gradients, ( + "`_sync_gradients` was incorrectly set to " + "`True` while in the `no_sync()` context manager" + ) + m._sync_gradients = old_flag + + @property + def params_with_grad(self) -> List[Parameter]: + """ + Recursively returns a list of all module parameters that have a gradient. + """ + return [p for p in self.parameters() if p.grad is not None] + + @torch.no_grad() + def clip_grad_norm_( + self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0 + ) -> None: + """ + Clip all gradients at this point in time. The norm is computed over all + gradients together, as if they were concatenated into a single vector. + Gradients are modified in-place. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` + for infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + + .. note:: This is analogous to ``torch.nn.utils.clip_grad_norm_`` but + handles the partitioning and multiple devices per rank under the + hood. The default torch util is not applicable here, because each + rank only has a partial view of all the grads in the model, so + calling it for FSDP models would lead to different scaling being + applied per subset of model parameters. + + .. warning:: This needs to be called on all ranks, since synchronization + primitives will be used. + """ + self._lazy_init() + self._wait_for_previous_optim_step() + assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance" + self._assert_state(TrainingState_.IDLE) + + max_norm = float(max_norm) + norm_type = float(norm_type) + # Computes the max norm for this shard's gradients and sync's across workers + local_norm = _calc_grad_norm(self.params_with_grad, norm_type).npu() # type: ignore[arg-type] + if norm_type == math.inf: + total_norm = local_norm + dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group) + else: + total_norm = local_norm ** norm_type + dist.all_reduce(total_norm, group=self.process_group) + total_norm = total_norm ** (1.0 / norm_type) + + if self.cpu_offload: + total_norm = total_norm.cpu() + + clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6) + if clip_coef < 1: + # multiply by clip_coef, aka, (max_norm/total_norm). + for p in self.params_with_grad: + assert p.grad is not None + p.grad.detach().mul_(clip_coef.to(p.grad.device)) + + @staticmethod + def _warn_optim_input(optim_input): + if optim_input is not None: + warnings.warn( + "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. You may remove it " + "from your code without changing its functionality." 
+ ) + + @staticmethod + def _is_using_optim_input(optim_input, optim) -> bool: + if optim_input is None and optim is None: + # Use the default behavior of `optim_input`` + return True + if optim_input is not None: + # Use the `optim_input` code path + return True + # Use the `optim` code path + return False + + @staticmethod + def full_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[Union[ + List[Dict[str, Any]], Iterable[torch.nn.Parameter], + ]] = None, + rank0_only: bool = True, + group: Optional[dist.ProcessGroup] = None, + ) -> Dict[str, Any]: + """ + Consolidates the full optimizer state on rank 0 and returns it + as a :class:`dict` following the convention of + :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"`` + and ``"param_groups"``. The flattened parameters in ``FSDP`` modules + contained in ``model`` are mapped back to their unflattened parameters. + + .. warning:: This needs to be called on all ranks since synchronization + primitives are used. However, if ``rank0_only=True``, then the + state dict is only populated on rank 0, and all other ranks return + an empty :class:`dict`. + + .. warning:: Unlike ``torch.optim.Optimizer.state_dict()``, this method + uses full parameter names as keys instead of parameter IDs. + + .. note:: Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors + contained in the optimizer state dict are not cloned, so there may + be aliasing surprises. For best practices, consider saving the + returned optimizer state dict immediately, e.g. using + ``torch.save()``. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer ``optim`` representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + group (dist.ProcessGroup): Model's process group or ``None`` if using + the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``, + then nonzero ranks return an empty :class:`dict`. 
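+
+        Example (an illustrative sketch only; ``model`` and ``optim`` are
+        assumed to be an FSDP-wrapped module and its optimizer, and ``PATH``
+        is a placeholder save location)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            >>> full_osd = FSDP.full_optim_state_dict(model, optim)
+            >>> # With the default ``rank0_only=True``, only rank 0 holds the
+            >>> # populated dict, so only rank 0 should save it
+            >>> if dist.get_rank() == 0:
+            ...     torch.save(full_osd, PATH)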
+ """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, optim, + ) + return _optim_state_dict( + model=model, + optim=optim, + optim_input=optim_input, + rank0_only=rank0_only, + shard_state=False, + group=group, + using_optim_input=using_optim_input, + ) + + @staticmethod + def sharded_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], Iterable[torch.nn.Parameter], + ] + ] = None, + group: Optional[dist.ProcessGroup] = None, + ) -> Dict[str, Any]: + """ + The API is similar to :meth:`full_optim_state_dict` but this API chunks + all non-zero-dimension states to :class:`ShardedTensor` to save memory. + This API should only be used when the model ``state_dict`` is derived + with the context manager ``with state_dict_type(SHARDED_STATE_DICT):``. + + For the detailed usage, refer to :meth:`full_optim_state_dict`. + + .. warning:: The returned state dict contains ``ShardedTensor`` and + cannot be directly used by the regular ``optim.load_state_dict``. + """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, optim, + ) + # TODO: The ultimate goal of the optimizer state APIs should be the same + # as state_dict/load_state_dict -- using one API to get optimizer states + # and one API to load optimizer states. ``state_dict_type`` will be used + # to decide which optimizer states should be returned. + # There are currently two APIs to load a full optimizer state. So the + # first step of the unification is to merge the two full optimizer state + # loading APIs. + # Task: https://github.com/pytorch/pytorch/issues/82232 + return _optim_state_dict( + model=model, + optim=optim, + optim_input=optim_input, + rank0_only=False, + shard_state=True, + group=group, + using_optim_input=using_optim_input, + ) + + @staticmethod + def shard_full_optim_state_dict( + full_optim_state_dict: Dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> Dict[str, Any]: + """ + Shards the full optimizer state dict ``full_optim_state_dict`` by + remapping the state to flattened parameters instead of unflattened + parameters and restricting to only this rank's part of the optimizer + state. The first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) + >>> torch.save(full_osd, PATH) + >>> # Define new model with possibly different world size + >>> new_model, new_optim = ... + >>> full_osd = torch.load(PATH) + >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. 
Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to NPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + full non-sharded optimizer state. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, optim, + ) + sharded_osd = _flatten_optim_state_dict( + full_optim_state_dict, model, True, + ) + return _rekey_sharded_optim_state_dict( + sharded_osd, model, optim, optim_input, using_optim_input, + ) + + @staticmethod + def flatten_sharded_optim_state_dict( + sharded_optim_state_dict: Dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> Dict[str, Any]: + """ + The API is similar to :meth:`shard_full_optim_state_dict`. The only + difference is that the input ``sharded_optim_state_dict`` should be + returned from :meth:`sharded_optim_state_dict`. Therefore, there will + be all-gather calls on each rank to gather ``ShardedTensor`` s. + + Args: + sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + sharded optimizer state. + model (torch.nn.Module): + Refer to :meth:``shard_full_optim_state_dict``. + + Returns: + Refer to :meth:`shard_full_optim_state_dict`. + """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, optim, + ) + # TODO: The implementation is the same as ``shard_full_optim_state_dict``. + # See the TODO in ``shard_full_optim_state_dict`` for the future + # unification plan. 
+ flattened_osd = _flatten_optim_state_dict( + sharded_optim_state_dict, + model=model, + shard_state=True, + ) + return _rekey_sharded_optim_state_dict( + flattened_osd, model, optim, optim_input, using_optim_input, + ) + + @staticmethod + def scatter_full_optim_state_dict( + full_optim_state_dict: Optional[Dict[str, Any]], + model: torch.nn.Module, + optim_input: Optional[Union[ + List[Dict[str, Any]], Iterable[torch.nn.Parameter], + ]] = None, + optim: Optional[torch.optim.Optimizer] = None, + group: Optional[Any] = None, + ) -> Dict[str, Any]: + """ + Scatters the full optimizer state dict from rank 0 to all other ranks, + returning the sharded optimizer state dict on each rank. The return + value is the same as :meth:`shard_full_optim_state_dict`, and on rank + 0, the first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 + >>> # Define new model with possibly different world size + >>> new_model, new_optim, new_group = ... + >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to NPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state + dict corresponding to the unflattened parameters and holding + the full non-sharded optimizer state if on rank 0; the argument + is ignored on nonzero ranks. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + group (dist.ProcessGroup): Model's process group or ``None`` if + using the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. 
+ """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, optim, + ) + # Try to use the passed-in process group, the model's process group, + # or the default process group (i.e. `None`) in that priority order + if group is None and hasattr(model, "process_group"): + group = model.process_group + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + # Check for a valid broadcast device, preferring NPU when available + using_nccl = dist.distributed_c10d._check_for_nccl_backend(group) + broadcast_device = torch.device("npu") if torch_npu.npu.is_available() \ + else torch.device("cpu") + if using_nccl and not torch_npu.npu.is_available(): + raise RuntimeError("NCCL requires a NPU for collectives") + # Flatten the optimizer state dict and construct a copy with the + # positive-dimension tensors' shapes in place of the tensors themselves + # since those tensors will be broadcast separately to avoid copying + if rank == 0: + if full_optim_state_dict is None: + raise ValueError("Rank 0 must pass in the full optimizer state dict") + flat_osd = _flatten_optim_state_dict( + full_optim_state_dict, + model=model, + shard_state=False, + ) + processed_osd = _process_pos_dim_tensor_state(flat_osd, world_size) + # Broadcast the optim state dict without positive-dimension tensor + # state and the FSDP parameter IDs from rank 0 to all ranks + processed_osd = _broadcast_processed_optim_state_dict( + processed_osd if rank == 0 else None, rank, group, + ) + # Broadcast positive-dimension tensor state (both sharded tensors for + # FSDP parameters and unsharded tensors for non-FSDP parameters) + sharded_osd = _broadcast_pos_dim_tensor_states( + processed_osd, flat_osd if rank == 0 else None, rank, world_size, + group, broadcast_device, + ) + # Rekey the optimizer state dict to use parameter IDs according to this + # rank's `optim` + sharded_osd = _rekey_sharded_optim_state_dict( + sharded_osd, model, optim, optim_input, using_optim_input, + ) + return sharded_osd + + @staticmethod + def rekey_optim_state_dict( + optim_state_dict: Dict[str, Any], + optim_state_key_type: OptimStateKeyType, + model: torch.nn.Module, + optim_input: Optional[Union[ + List[Dict[str, Any]], Iterable[torch.nn.Parameter], + ]] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> Dict[str, Any]: + """ + Re-keys the optimizer state dict ``optim_state_dict`` to use the key + type ``optim_state_key_type``. This can be used to achieve + compatibility between optimizer state dicts from models with FSDP + instances and ones without. + + To re-key an FSDP full optimizer state dict (i.e. from + :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to + a non-wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> wrapped_model, wrapped_optim = ... + >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) + >>> nonwrapped_model, nonwrapped_optim = ... + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) + >>> nonwrapped_optim.load_state_dict(rekeyed_osd) + + To re-key a normal optimizer state dict from a non-wrapped model to be + loadable to a wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> nonwrapped_model, nonwrapped_optim = ... + >>> osd = nonwrapped_optim.state_dict() + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) + >>> wrapped_model, wrapped_optim = ... 
+ >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) + >>> wrapped_optim.load_state_dict(sharded_osd) + + Returns: + Dict[str, Any]: The optimizer state dict re-keyed using the + parameter keys specified by ``optim_state_key_type``. + """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, optim, + ) + assert optim_state_key_type in ( + OptimStateKeyType.PARAM_NAME, OptimStateKeyType.PARAM_ID, + ) + osd = optim_state_dict # alias + # Validate that the existing parameter keys are uniformly typed + uses_param_name_mask = [ + type(param_key) is str for param_key in osd["state"] + ] + uses_param_id_mask = [ + type(param_key) is int for param_key in osd["state"] + ] + if ( + (any(uses_param_name_mask) and not all(uses_param_name_mask)) + or (any(uses_param_id_mask) and not all(uses_param_id_mask)) + ): + error_msg = f"Invalid parameter keys: {osd['state'].keys()}" + raise ValueError(error_msg) + # Return directly if the existing key type matches the target key type + if (optim_state_key_type == OptimStateKeyType.PARAM_NAME and + all(uses_param_name_mask)) or \ + (optim_state_key_type == OptimStateKeyType.PARAM_ID and + all(uses_param_id_mask)): + return osd + # Otherwise, actually perform the re-keying + new_osd = {} + if optim_state_key_type == OptimStateKeyType.PARAM_NAME: # ID -> name + param_id_to_param = ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_id_to_param(optim) + ) + param_to_param_name = _get_param_to_param_name(model) + param_id_to_param_name: List[str] = [ + param_to_param_name[param] for param in param_id_to_param + ] + new_osd["state"] = { + param_id_to_param_name[param_id]: param_state + for param_id, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted([ + param_id_to_param_name[param_id] + for param_id in param_group["params"] + ]) + return new_osd + elif optim_state_key_type == OptimStateKeyType.PARAM_ID: # name -> ID + param_name_to_param = _get_param_name_to_param(model) + param_to_param_id = ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_id(optim) + ) + # Because not all model parameters may be passed as the optimizer + # input, we may need to drop some parameters from this mapping + param_name_to_param_id = { + param_name: param_to_param_id[param] + for param_name, param in param_name_to_param.items() + if param in param_to_param_id + } + new_osd["state"] = { + param_name_to_param_id[param_name]: param_state + for param_name, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted([ + param_name_to_param_id[param_name] + for param_name in param_group["params"] + ]) + return new_osd + return new_osd # should never reach here + + def _get_default_comm_hook(self) -> Any: + r""" + Returns a default communication hook based on a sharding strategy. + """ + if self.sharding_strategy != ShardingStrategy.NO_SHARD: + return default_hooks.reduce_scatter_hook + else: + return default_hooks.allreduce_hook + + def _get_default_comm_hook_state(self) -> Any: + r""" + Returns a default communication hook state based on a sharding strategy. 
+ """ + return default_hooks.DefaultState(process_group=self.process_group) + + def register_comm_hook(self, state: object, hook: callable): + """ + Registers a communication hook which is an enhancement that provides a + flexible hook to users where they can specify how FSDP aggregates gradients + across multiple workers. + This hook can be used to implement several algorithms like + `GossipGrad `_ and gradient compression + which involve different communication strategies for + parameter syncs while training with :class:`FullyShardedDataParallel`. + + .. warning :: + FSDP communication hook should be registered before running an initial forward pass + and only once. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + Examples include error feedback in gradient compression, + peers to communicate with next in `GossipGrad `_, etc. + It is locally stored by each worker + and shared by all the gradient tensors on the worker. + hook (Callable): Callable, which has one of the following signatures: + 1) ``hook: Callable[torch.Tensor] -> None``: + This function takes in a Python tensor, which represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). + It then performs all necessary processing and returns ``None``; + 2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``: + This function takes in two Python tensors, the first one represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). The latter + represents a pre-sized tensor to store a chunk of a sharded gradient after + reduction. + In both cases, callable performs all necessary processing and returns ``None``. + Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case. + Callables with signature 2 are expected to handle gradient communication for sharded cases. 
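+
+        Example (an illustrative sketch; ``fsdp_model`` is assumed to be a
+        root FSDP instance, and ``my_state`` / ``my_hook`` are hypothetical
+        user-defined objects following one of the signatures above)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> fsdp_model = FSDP(model)
+            >>> # Register once, before the first forward pass
+            >>> fsdp_model.register_comm_hook(my_state, my_hook)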
+ + """ + if not self.check_is_root(): + raise AssertionError("register_comm_hook can only be called on a root instance.") + for submodule in self.fsdp_modules(self): + assert not submodule._hook_registered, "communication hook can be only registered once" + submodule._hook_registered = True + assert submodule._communication_hook == self._get_default_comm_hook(),\ + f"communication hook should be default, but it is {submodule._communication_hook.__name__} instead" + submodule._communication_hook_state = state + submodule._communication_hook = hook + + + def _init_param_exec_order_wrap_policy(self, *args, **kwargs) -> None: + auto_wrap_policy = kwargs["auto_wrap_policy"] + module = kwargs["module"] + assert hasattr(auto_wrap_policy, "tracing_config") + if not _TORCH_FX_AVAIL: + assert ( + auto_wrap_policy.tracing_config is None + ), "tracing_config should be None when torch.fx is not enabled" + elif isinstance( + auto_wrap_policy.tracing_config, + TracingConfig + ): + tracer = auto_wrap_policy.tracing_config.tracer + execution_info = _init_execution_info(module) + + for m in module.modules(): + assert not isinstance( + m, FullyShardedDataParallel + ), "The input module of _patch_tracer should not contain FSDP modules" + + with _patch_tracer( + tracer=tracer, + root_module=module, + execution_info=execution_info, + ): + try: + tracer.trace(module, auto_wrap_policy.tracing_config.concrete_args) + except BaseException as e: + raise RuntimeError( + "tracer.trace failed inside _init_param_exec_order_wrap_policy" + f" with the error: {e}." + ) + else: + assert ( + auto_wrap_policy.tracing_config is None + ), "tracing_config should either be an instance of TracingConfig or be None" + # The initial FSDP wrapping is done with auto_wrap_policy.init_policy + kwargs["auto_wrap_policy"] = auto_wrap_policy.init_policy + self.__init__(*args, **kwargs) + self._param_exec_order_policy: bool = True + # self._param_exec_order_prep_stage is set to True before we get the execution order + self._param_exec_order_prep_stage: bool = True + # A list that stores the flatten parameters and its name based on the parameter execution order + self._fsdp_params_exec_order: List[FlatParameter] = [] + if _TORCH_FX_AVAIL and isinstance( + auto_wrap_policy.tracing_config, + TracingConfig + ): + # Initialize a dict that maps each module to its parent FSDP wrap + module_to_fsdp: Dict[nn.Module, FullyShardedDataParallel] = dict() + for wrap in self.fsdp_modules(self): + module_to_fsdp[wrap.module] = wrap + # Set self._fsdp_params_exec_order based on execution_info.module_forward_order. + # TODO (linjianma): self._fsdp_params_exec_order will be set based on + # the parameter execution order rather than module_forward_order, + # once the non-recursive wrapping policy is fully implemented. 
+ for m in execution_info.module_forward_order: + if m in module_to_fsdp: + for flat_param in module_to_fsdp[m].params: + self._fsdp_params_exec_order.append(flat_param) + self._param_exec_order_prep_stage = False + + for m in self.modules(): + if m is not self and isinstance(m, FullyShardedDataParallel): + # Assignment by reference, so each children FSDP wrap has access to + # the _fsdp_params_exec_order of the root module + m._fsdp_params_exec_order = self._fsdp_params_exec_order + m._param_exec_order_policy = self._param_exec_order_policy + m._param_exec_order_prep_stage = self._param_exec_order_prep_stage + + def _use_param_exec_order_policy(self) -> bool: + return ( + hasattr(self, "_param_exec_order_policy") + and self._param_exec_order_policy + ) + + def _is_param_exec_order_prep_stage(self) -> bool: + is_prep_stage = ( + hasattr(self, "_param_exec_order_prep_stage") + and self._param_exec_order_prep_stage + ) + if not is_prep_stage: + for p in self.parameters(): + assert ( + not hasattr(p, "_params_exec_order_hook_handle") + ), "When not in execution order prep stage, all _params_exec_order_hook_handle should be removed." + return is_prep_stage + + +def _calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor: + r"""Calculate gradient norm of an iterable of parameters. + Returns: + Total norm of the parameters (viewed as a single vector). + """ + parameters = [p for p in parameters if p.grad is not None] + + if len(parameters) == 0: + return torch.tensor(0.0) + if p == math.inf: + local_norm = torch.tensor(max(par.grad.detach().abs().max() for par in parameters)) + else: + # Compute the norm in full precision no matter what + local_norm = torch.linalg.vector_norm( + torch.stack( + [ + torch.linalg.vector_norm(par.grad.detach(), p, dtype=torch.float32) + for par in parameters + ] + ), + p, + ) + local_norm.to(dtype=parameters[0].dtype) + return local_norm + + +def _get_param_to_unflat_param_names( + model: torch.nn.Module, + dedup_shared_params: bool = True, +) -> Dict[torch.nn.Parameter, List[str]]: + """ + Constructs a mapping from flattened parameter (including non-FSDP-module + parameters) to its unflattened parameter names. For non-FSDP-module + parameters, these mapped-to lists always contain a single element. The + unflattened parameter names should match the keys of the model state dict. + + For shared parameters, only the first parameter name is included (following + the ``torch.nn.Module.parameters()`` order). + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance). + dedup_shared_params (bool): If ``True``, only includes the first + list of unflattened parameter names corresponding to a parameter + in the module walk order; if ``False``, then includes all of the + unflattened parameter names. 
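+
+    Example (an illustrative sketch of the returned mapping; the module and
+    parameter names are hypothetical)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> mapping = _get_param_to_unflat_param_names(fsdp_model)
+        >>> # A ``FlatParameter`` maps to every name it flattens, while a
+        >>> # non-FSDP parameter maps to a singleton list, e.g.
+        >>> # {flat_param: ["layer1.weight", "layer1.bias"], bias: ["head.bias"]}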
+ """ + def module_fn(module, prefix, param_to_unflat_param_names): + # For FSDP modules, only add the entry when considering the contained + # `FlattenParamsWrapper` to avoid duplication + if not isinstance(module, FullyShardedDataParallel): + for param_name, param in module.named_parameters(recurse=False): + module_prefixed_param_names = ( + param._prefixed_param_names if type(param) is FlatParameter + else [param_name] + ) # prefixed from `module` + fully_prefixed_param_names = [ + clean_tensor_name(prefix + name) + for name in module_prefixed_param_names + ] # fully prefixed from the top level including `prefix` + # If this parameter has already been visited, then it is a + # shared parameter; then, only take the first parameter name + is_shared_param = param in param_to_unflat_param_names + if not is_shared_param: + param_to_unflat_param_names[param] = fully_prefixed_param_names + elif not dedup_shared_params: + param_to_unflat_param_names[param].extend(fully_prefixed_param_names) + + def return_fn(param_to_unflat_param_names): + return param_to_unflat_param_names + + param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {} + return _apply_to_modules( + model, module_fn, return_fn, param_to_unflat_param_names, + ) + + +def _get_param_to_param_name( + model: torch.nn.Module, +) -> Dict[torch.nn.Parameter, str]: + """ + Constructs a mapping from parameters to their parameter names. ``model`` + should not contain any :class:`FullyShardedDataParallel` instances, which + means that none of the parameters should be ``FlatParameter`` s. As a + result, compared to :meth:`_get_param_to_unflat_param_names`, the mapped + values may be flattened from singleton :class:`list` s to the contained + names themselves. + + Args: + model (torch.nn.Module): Root module, which should not contain any + :class:`FullyShardedDataParallel` instances. + """ + param_to_param_names = _get_param_to_unflat_param_names(model) + for param_names in param_to_param_names.values(): + assert len(param_names) > 0, "`_get_param_to_unflat_param_names()` " \ + "should not construct empty lists" + if len(param_names) > 1: + raise RuntimeError( + "Each parameter should only map to one parameter name but got " + f"{len(param_names)}: {param_names}" + ) + param_to_param_name = { + param: param_names[0] + for param, param_names in param_to_param_names.items() + } + return param_to_param_name + + +def _get_param_name_to_param( + model: torch.nn.Module, +) -> Dict[str, torch.nn.Parameter]: + """Constructs the inverse mapping of :meth:`_get_param_to_param_name`.""" + param_to_param_name = _get_param_to_param_name(model) + return dict(zip(param_to_param_name.values(), param_to_param_name.keys())) + + +def clean_tensor_name(tensor_name: str) -> str: + """Cleans the parameter or buffer name by removing any module wrapper + prefixes.""" + # Call `replace()` twice separately since the name may not have both + tensor_name = tensor_name.replace(FSDP_WRAPPED_MODULE + ".", "") + tensor_name = tensor_name.replace(FPW_MODULE + ".", "") + # TODO: Explicitly replacing checkpoint_wrapper prefix is not ideal, + # as it increases coupling between CheckpointWrapper and FSDP. This is also not + # scalable for additional wrapped modules, we should come up with a general solution + # for this issue. 
+ tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX + ".", "") + return tensor_name + + +def apply_fsdp(): + torch.distributed.fsdp.fully_sharded_data_parallel.clean_tensor_name = clean_tensor_name + torch.distributed.fsdp.fully_sharded_data_parallel._get_param_name_to_param = _get_param_name_to_param + torch.distributed.fsdp.fully_sharded_data_parallel._get_param_to_param_name = _get_param_to_param_name + torch.distributed.fsdp.fully_sharded_data_parallel._get_param_to_unflat_param_names = _get_param_to_unflat_param_names + torch.distributed.fsdp.fully_sharded_data_parallel._calc_grad_norm = _calc_grad_norm + torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel = FullyShardedDataParallel + torch.distributed.fsdp.fully_sharded_data_parallel.sharding_strategy_map = sharding_strategy_map + torch.distributed.fsdp.fully_sharded_data_parallel._FreeEventQueue = _FreeEventQueue + torch.distributed.fsdp.fully_sharded_data_parallel._ExecOrderData = _ExecOrderData + torch.distributed.fsdp.fully_sharded_data_parallel._ExecOrderWarnStatus = _ExecOrderWarnStatus + torch.distributed.fsdp.fully_sharded_data_parallel._HandlesKey = _HandlesKey + torch.distributed.fsdp.fully_sharded_data_parallel.OptimStateKeyType = OptimStateKeyType + torch.distributed.fsdp.fully_sharded_data_parallel._state_dict_type_to_config = _state_dict_type_to_config + torch.distributed.fsdp.fully_sharded_data_parallel.ShardedStateDictConfig = ShardedStateDictConfig + torch.distributed.fsdp.fully_sharded_data_parallel.LocalStateDictConfig = LocalStateDictConfig + torch.distributed.fsdp.fully_sharded_data_parallel.FullStateDictConfig = FullStateDictConfig + torch.distributed.fsdp.fully_sharded_data_parallel.StateDictConfig = StateDictConfig + torch.distributed.fsdp.fully_sharded_data_parallel.StateDictType = StateDictType + torch.distributed.fsdp.fully_sharded_data_parallel.TrainingState_ = TrainingState_ + torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch = BackwardPrefetch + torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload = CPUOffload + torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision = MixedPrecision + torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy = ShardingStrategy + torch.distributed.fsdp.fully_sharded_data_parallel._PARAM_BROADCAST_BUCKET_SIZE = _PARAM_BROADCAST_BUCKET_SIZE + torch.distributed.fsdp.fully_sharded_data_parallel.FSDP_PREFIX = FSDP_PREFIX + torch.distributed.fsdp.fully_sharded_data_parallel.FSDP_WRAPPED_MODULE = FSDP_WRAPPED_MODULE + torch.distributed.fsdp.fully_sharded_data_parallel.__all__ = __all__ + torch.distributed.fsdp.fully_sharded_data_parallel._TORCH_FX_AVAIL = _TORCH_FX_AVAIL + torch.distributed.fsdp.fully_sharded_data_parallel._TORCHDISTX_AVAIL = _TORCHDISTX_AVAIL diff --git a/torch_npu/distributed/fsdp/sharded_grad_scaler.py b/torch_npu/distributed/fsdp/sharded_grad_scaler.py new file mode 100644 index 0000000000..84350bce7a --- /dev/null +++ b/torch_npu/distributed/fsdp/sharded_grad_scaler.py @@ -0,0 +1,355 @@ +from collections import abc, defaultdict +import logging +from typing import Dict, List, Optional, Union + +import torch +import torch_npu +from torch.cuda import FloatTensor # type: ignore[attr-defined] +from torch_npu.npu.amp.grad_scaler import GradScaler, OptState, _NpuMultiDeviceReplicator +from torch.distributed.distributed_c10d import ProcessGroup +import torch.distributed as dist +from torch.optim.sgd import SGD + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, 
"found_inf_per_device": {}} + + +def _is_supported_device(tensor: torch.Tensor): + return tensor.is_npu or tensor.device.type in ("xla", "cpu") + + +class _GeneralMultiDeviceReplicator(_NpuMultiDeviceReplicator): + """ + Lazily serves tensor to request device. This class extends + _NpuMultiDeviceReplicator to allow support for "cpu" as a device. + """ + def __init__(self, master_tensor: torch.Tensor) -> None: + assert _is_supported_device(master_tensor) + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + +class ShardedGradScaler(GradScaler): + """ + ShardedGradScaler helps perform gradient scaling in a shard aware manner. It extends + functionality from GradScaler: + * Suports Pytorch DDP and FSDP implementations + * Support CPU offloaded tensors (as used in fully sharded data parallel[FSDP]) + * Supports the custom Mixed Precision loss dtype (fp16, bf16) that FSDP returns + * Sync inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across + nodes + + Example:: + + # Creates a ShardedGradScaler once at the beginning of training. + scaler = ShardedGradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See :class:`GradScaler` for explanation of scaling/unscaling and more use cases. + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. 
+ Default: ``True`` + process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD): + process group for sharding + """ + def __init__( + self, + init_scale: float = 2.0 ** 16, + backoff_factor: float = 0.5, + growth_factor: float = 2.0, + growth_interval: int = 2000, + enabled: bool = True, + process_group: Optional[ProcessGroup] = dist.group.WORLD, + ): + super().__init__( + init_scale=init_scale, + backoff_factor=backoff_factor, + growth_factor=growth_factor, + growth_interval=growth_interval, + enabled=enabled, + ) + if self._enabled: + self.process_group = process_group + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def scale(self, outputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]: + if not self._enabled: + return outputs + + if self._dist_overflow_count is None: + self._lazy_init_dist_flag_and_dist_overflow_count() + assert self._dist_overflow_count is not None + + if self._dynamic and not self._clear_overflow_flag: + if not torch_npu.npu.utils.is_support_inf_nan(): + GradScaler.clear_npu_overflow_flag() + self._clear_overflow_flag = True + + if isinstance(outputs, torch.Tensor): + assert _is_supported_device(outputs) + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + scaled_output = outputs * self._scale.to(device=outputs.device, non_blocking=True) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_output.type(outputs.dtype) + + stash: List[_GeneralMultiDeviceReplicator] = [] + + def apply_scale(val: Union[torch.Tensor, abc.Iterable]) -> Union[torch.Tensor, abc.Iterable]: + if isinstance(val, torch.Tensor): + assert _is_supported_device(val) + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_GeneralMultiDeviceReplicator(self._scale)) + scaled_val = val * stash[0].get(val.device) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_val.type(val.dtype) + elif isinstance(val, abc.Iterable): + iterator = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterator) + else: + return iterator + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) # type: ignore[return-value] + + def _foreach_non_finite_check_and_unscale_cpu_( + self, grads: List, found_inf: torch.Tensor, inv_scale: torch.Tensor + ) -> None: + if len(grads) == 0: + return + assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor." + assert found_inf.numel() == 1, "found_inf must be a 1-element tensor." + + expected_device = grads[0].device + for grad in grads: + for tensor in grad: + if tensor.device != expected_device: + logging.error("tensor device is %s and expected device is %s" % (tensor.device, expected_device)) + raise ValueError("Gradients must be on the same device.") + + # check for non_overlapping_and_dense doesn't exist in the python world + # we assume tensor is not MTA(multi tensor apply) safe. 
iterate through each item regardless of dtype + if torch.isinf(tensor).any().item() is True or torch.isnan(tensor).any().item() is True: + found_inf.data = torch.tensor([1.0]) + break + else: + tensor.data *= inv_scale.item() + + def _unscale_grads_( + self, optimizer: SGD, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool = True + ) -> Dict[torch.device, torch.Tensor]: + per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale) + per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be thousands of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + if hasattr(optimizer, 'is_npu_fused_optimizer') and optimizer.is_npu_fused_optimizer: + if not optimizer.is_params_grads_combined: + optimizer._maybe_init_combined_params_and_grads() + + device = found_inf.device + for grads_combined_one_dtype in optimizer.grads_all_group_combined: + if grads_combined_one_dtype is None: + continue + if self._dynamic: + torch._amp_foreach_non_finite_check_and_unscale_( + [grads_combined_one_dtype], + per_device_found_inf.get(device), + per_device_inv_scale.get(device)) + if per_device_found_inf.get(device)[0].item() > 0: + self._has_overflow = True + else: + grads_combined_one_dtype.mul_( + per_device_inv_scale.get(device)) + else: + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). 
+ if param.grad.dtype is torch.float16: + # coalesce is not suported in torch.float16 + param_grad_fp32 = param.grad.type(torch.float32).coalesce() + param.grad = param_grad_fp32.type(torch.float16) + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + if grads[0].device.type == "cpu": + self._foreach_non_finite_check_and_unscale_cpu_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) + else: + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) + self._sync_dist_overflow_count() + if self._has_overflow: + per_device_found_inf.get(found_inf.device).add_(1) + else: + per_device_found_inf.get(found_inf.device) + + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer: SGD) -> None: + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, True) + optimizer_state["stage"] = OptState.UNSCALED + + # Synchronize the detected inf across the ranks + optimizer_state = self._per_optimizer_states[id(optimizer)] + + def step(self, optimizer: SGD, *args, **kwargs) -> Optional[float]: + return super().step(optimizer, *args, **kwargs) + + def _amp_update_scale_cpu_(self, found_inf) -> None: + """ + If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero. + Otherwise, scale is multiplied by the growth factor when the growth interval is reached. + """ + if found_inf.item() >= 1.0: + self._scale *= self._backoff_factor # type: ignore[arg-type] + self._growth_tracker = 0 + else: + successful = self._growth_tracker + 1 # type: ignore[operator] + if successful == self._growth_interval: # type: ignore[arg-type] + self._scale *= self._growth_factor # type: ignore[arg-type] + self._growth_tracker = 0 + else: + self._growth_tracker = successful + + def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None: + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.npu.FloatTensor`, optional, default=None): New scale factor. + .. 
warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") # type: ignore[var-annotated] + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale = torch.full((1,), new_scale, dtype=torch.float32) + self._scale = self._scale.pin_memory().to(_scale.device, non_blocking=True) + else: + reason = "new_scale should be a float or a 1-element torch.npu.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.npu.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + elif self._dynamic: + self._npu_update_scale() + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + self._amp_update_scale_cpu_(found_inf_combined) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + self._has_overflow = False + self._clear_overflow_flag = False + + +def apply_fsdp_shard_grad_scaler(): + torch.distributed.fsdp.shard_grad_scaler.ShardedGradScaler = ShardedGradScaler + torch.distributed.fsdp.shard_grad_scaler._GeneralMultiDeviceReplicator = _GeneralMultiDeviceReplicator + torch.distributed.fsdp.shard_grad_scaler._is_supported_device = _is_supported_device + torch.distributed.fsdp.shard_grad_scaler._refresh_per_optimizer_state = _refresh_per_optimizer_state diff --git a/torch_npu/distributed/fsdp/utils.py b/torch_npu/distributed/fsdp/utils.py new file mode 100644 index 0000000000..d5bf87736c --- /dev/null +++ b/torch_npu/distributed/fsdp/utils.py @@ -0,0 +1,148 @@ +import dataclasses +import traceback +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Set, Tuple, Union + +import torch +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined] + _is_namedtuple, +) +from torch.nn.utils.rnn import PackedSequence + +"""Useful functions to deal with tensor types with other python container types.""" + +__all__ = ["p_assert"] + +def _contains_batchnorm(module): + return any( + isinstance(mod, _BatchNorm) for mod in module.modules() + ) + + +def _override_batchnorm_mixed_precision(module): + for mod in module.modules(): + if isinstance(mod, _BatchNorm): + mod._wrap_overrides = {"mixed_precision": None} # type: ignore[assignment] + + +def _apply_to_tensors( + fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence] +) -> Any: + """Recursively apply to all tensor in different kinds of container types.""" + + def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]) -> Any: + 
if torch.is_tensor(x): + return fn(x) + elif hasattr(x, "__dataclass_fields__"): + dc = dataclasses.replace(x) + for f in dataclasses.fields(dc): + name = f.name + setattr(dc, name, apply(getattr(dc, name))) + return dc + elif isinstance(x, OrderedDict): + od = x.__class__() + for key, value in x.items(): + od[key] = apply(value) + return od + elif isinstance(x, PackedSequence): + apply(x.data) + return x + elif isinstance(x, dict): + return {key: apply(value) for key, value in x.items()} + elif _is_namedtuple(x): + res = (apply(el) for el in x) + return type(x)(*res) + elif isinstance(x, (list, tuple, set)): + return type(x)(apply(el) for el in x) + else: + return x + + return apply(container) + + +def _apply_to_modules( + root_module: torch.nn.Module, + module_fn: Callable, + return_fn: Callable, + *args, + **kwargs, +): + """ + Performs a pre-order traversal of the modules in the hierarchy rooted at + ``root_module``, applying ``module_fn`` at each module and finally + returning a value using ``return_fn``. The traversal constructs the full + module prefix name (e.g. "module.submodule." just like in model state dict) + and makes that available to ``module_fn``. + """ + def f(module: torch.nn.Module, prefix: str, *args, **kwargs): + # Call the module function before recursing over children (pre-order) + module_fn(module, prefix, *args, **kwargs) + for submodule_name, submodule in module.named_children(): + if submodule is not None: + new_prefix = prefix + submodule_name + "." + f(submodule, new_prefix, *args, **kwargs) + + f(root_module, "", *args, **kwargs) + return return_fn(*args, **kwargs) + + +@torch.no_grad() +def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool: + """ + Allocate storage for ``tensor`` with the given size. + + Returns: + bool: ``True`` if this method allocated storage and ``False`` if the + storage was already allocated. + """ + already_allocated = tensor.storage().size() == size.numel() + if not already_allocated: + tensor_storage_size = tensor.storage().size() + p_assert( + tensor_storage_size == 0, + f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}", + ) + tensor.storage().resize_(size.numel()) + return not already_allocated + + +@torch.no_grad() +def _free_storage(tensor: torch.Tensor) -> bool: + """ + Frees the underlying storage of ``tensor``. + + Returns: + bool: ``True`` if the method freed the storage and ``False`` if the + storage was already freed. 
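+
+    Example (an illustrative sketch; ``t`` is any tensor that solely owns its
+    storage, i.e. has a zero storage offset)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> _free_storage(t)               # storage size drops to 0
+        >>> _alloc_storage(t, t.size())    # re-allocate before reusing ``t``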
+ """ + already_freed = tensor.storage().size() == 0 + if not already_freed: + p_assert( + tensor.storage_offset() == 0, + "Freeing a tensor's storage is unsafe when it is not the sole occupant", + ) + tensor.storage().resize_(0) + return not already_freed + + +def p_assert(cond: Any, s: Any, raise_assertion_error: bool = True) -> None: + """This is used as an alternate to ``assert`` when in the backward context + to print the error message ``s`` since otherwise, it is swallowed.""" + if not cond: + print(s) + traceback.print_stack() + if raise_assertion_error: + raise AssertionError + + +def apply_fsdp_utils(): + torch.distributed.fsdp.utils.p_assert = p_assert + torch.distributed.fsdp.utils._free_storage = _free_storage + torch.distributed.fsdp.utils._alloc_storage = _alloc_storage + torch.distributed.fsdp.utils._apply_to_modules = _apply_to_modules + torch.distributed.fsdp.utils._apply_to_tensors = _apply_to_tensors + torch.distributed.fsdp.utils._override_batchnorm_mixed_precision = _override_batchnorm_mixed_precision + torch.distributed.fsdp.utils._contains_batchnorm = _contains_batchnorm + torch.distributed.fsdp.utils.__all__ = __all__ + diff --git a/torch_npu/distributed/fsdp/wrap.py b/torch_npu/distributed/fsdp/wrap.py new file mode 100644 index 0000000000..525b774714 --- /dev/null +++ b/torch_npu/distributed/fsdp/wrap.py @@ -0,0 +1,482 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + Generator, + Optional, + Set, + Tuple, + Type, + cast, +) +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + + +__all__ = [ + "always_wrap_policy", + "lambda_auto_wrap_policy", + "transformer_auto_wrap_policy", + "size_based_auto_wrap_policy", + "enable_wrap", + "wrap", + "ParamExecOrderWrapPolicy", +] + + +def always_wrap_policy(*args, **kwargs) -> bool: + """ + A simple wrapper policy that always returns ``True``, + i.e. when passed as the `auto_wrap_policy` into FSDP, + this will result in all submodules being wrapped as + distinct FSDP instances. + """ + return True + +def lambda_auto_wrap_policy( + module: nn.Module, + recurse: bool, + unwrapped_params: int, + lambda_fn: Callable +) -> bool: + """ + A convenient auto wrap policy to wrap submodules based on an arbitrary user + function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as + a `wrapper_cls` unit. + + Return if a module should be wrapped during auto wrapping. + + The first three parameters are required by :func:`_recursive_wrap`. + + Args: + module (nn.Module): + The module to be considered in this decision. + recurse (bool): + Indicate if this is called to make a decision on whether we + should recurse down a subgraph of the module structure. + If False, it means this function is called to make a decision + on whether we should wrap the said module. + unwrapped_params (int): + The number of parameters yet to be wrapped in this module. + + lambda_fn (Callable[nn.Module] -> bool): + If this returns ``True``, this module will be wrapped by + wrapper_cls individually. 
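+
+    Example (an illustrative sketch; ``model`` is a hypothetical module and
+    ``FSDP`` refers to :class:`FullyShardedDataParallel`)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import functools
+        >>> # Wrap every ``nn.Linear`` submodule in its own FSDP unit
+        >>> policy = functools.partial(
+        ...     lambda_auto_wrap_policy,
+        ...     lambda_fn=lambda m: isinstance(m, nn.Linear),
+        ... )
+        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)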
+    """
+    if recurse:
+        # always recurse
+        return True
+    else:
+        # if not recursing, decide whether we should wrap for the leaf node or remainder
+        return lambda_fn(module)
+
+def transformer_auto_wrap_policy(
+    module: nn.Module,
+    recurse: bool,
+    unwrapped_params: int,
+    transformer_layer_cls: Set[Type[nn.Module]],
+) -> bool:
+    """
+    A convenient auto wrap policy for transformer models. If the submodule
+    is an instance of transformer_layer_cls, the submodule will be wrapped
+    as an FSDP unit. Otherwise, all the other remainder submodules are wrapped
+    by the outermost FSDP unit. Right now, FSDP requires submodules that share
+    weights to be wrapped in the same FSDP unit, so this auto wrap policy can
+    conveniently wrap the shared embeddings into the same FSDP unit for transformer
+    models. In the near future, FSDP will support submodules that share weights
+    to be wrapped in separate FSDP units.
+
+    Return if a module should be wrapped during FSDP auto wrapping.
+
+    The first three parameters are required by :func:`_recursive_wrap`.
+
+    Args:
+        module (nn.Module):
+            The module to be considered in this decision.
+        recurse (bool):
+            Indicate if this is called to make a decision on whether we
+            should recurse down a subgraph of the module structure.
+            If False, it means this function is called to make a decision
+            on whether we should wrap the said module.
+        unwrapped_params (int):
+            The number of parameters yet to be wrapped in this module.
+        transformer_layer_cls (Set[Type[nn.Module]]):
+            Submodules of one of the ``transformer_layer_cls`` types
+            will be wrapped as separate FSDP units.
+    """
+    if recurse:
+        # always recurse
+        return True
+    else:
+        # if not recursing, decide whether we should wrap for the leaf node or remainder
+        return isinstance(module, tuple(transformer_layer_cls))
+
+def _wrap_batchnorm_individually(
+    module: nn.Module,
+    recurse: bool,
+    *args,
+    **kwargs,
+) -> bool:
+    """
+    A policy that wraps ``BatchNorm`` instances in their own FSDP unit.
+    """
+    if recurse:
+        # always recurse
+        return True
+    else:
+        # if not recursing, decide whether we should wrap based on whether it is a
+        # BN layer or not.
+        return isinstance(module, _BatchNorm)
+
+def _or_policy(
+    module: nn.Module,
+    recurse: bool,
+    unwrapped_params: int,
+    policies,
+) -> bool:
+    """
+    A policy that wraps ``module`` if any policy in the passed in iterable of
+    ``policies`` returns ``True``.
+    """
+    return any(
+        policy(module, recurse, unwrapped_params) for policy in policies
+    )
+
+
+def size_based_auto_wrap_policy(
+    module: nn.Module,
+    recurse: bool,
+    unwrapped_params: int,
+    # These are customizable for this policy function.
+    min_num_params: int = int(1e8),
+    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
+    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
+) -> bool:
+    """A size-based auto_wrap_policy function for the FSDP API.
+
+    Return if a module should be wrapped during FSDP auto wrapping.
+
+    The first three parameters are used by :func:`_recursive_wrap`. If
+    you write a custom version of this policy function, your version
+    needs to at least accept the first three parameters and is free
+    to do whatever you want in the function.
+
+    Args:
+        module (nn.Module):
+            The module to be considered in this decision.
+        recurse (bool):
+            Indicate if this is called to make a decision on whether we
+            should recurse down a subgraph of the module structure.
+            If False, it means this function is called to make a decision
+            on whether we should wrap the said module.
+ unwrapped_params (int): + The number of parameters yet to be wrapped in this module. + + min_num_params (int): + Customizable policy input. It controls the size threshold + on how big should a module be to be considered wrapped. + force_leaf_modules (Set[Type[nn.Module]]): set of module types to + keep as leaves, i.e., their children will never be wrapped. + exclude_wrap_modules (Set[Type[nn.Module]]): + Customizable set of module types to be excluded in wrapping. + """ + force_leaf_modules = ( + size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] + if force_leaf_modules is None + else force_leaf_modules + ) + exclude_wrap_modules = ( + size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined] + if exclude_wrap_modules is None + else exclude_wrap_modules + ) + + is_large = unwrapped_params >= min_num_params + if recurse: + # We should recurse if the module is big enough but not in force_leaf_modules list. + return is_large and not isinstance(module, tuple(force_leaf_modules)) + else: + # If we are not recursing, determine if we should wrap. + return is_large and not isinstance(module, tuple(exclude_wrap_modules)) + + +# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported. +size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined] +size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined] + + +@contextlib.contextmanager +def enable_wrap( + *, wrapper_cls: Any, **wrapper_kwargs: Any +) -> Generator[None, None, None]: + """ + Context manager to wrap modules using a wrapper. + + Useful for when you'd like to apply the same configuration arguments to all + child modules that you wrap. A particularly important use case is wrapping + large layers so that they get sharded (in-place) during initialization, to + avoid running out of system memory. Large layers can indicate that they + should be sharded via the ``wrap`` annotation and this context manager can + provide the exact configuration for these nested instances. + + Usage:: + + with enable_wrap(wrapper_cls, **params): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + wrapper_cls: + Class that `wrap` annotation will `wrap` modules with, such as + `FullyShardedDataParallel`. + **wrapper_kwargs: + Configuration settings that will be passed to all ``wrap`` + instances inside the context + """ + kwargs = { + **{"wrapper_cls": wrapper_cls}, + **wrapper_kwargs, + } + with _ConfigAutoWrap(**kwargs): + yield + + +def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: + """ + Annotate that a module should be wrapped. Annotated modules will only be + wrapped if inside of an :func:`enable_wrap` context manager. This allows + a module to be initialized both with and without a wrapper without code + change. + + The class that this function wraps the passed in ``nn.Module`` with is the + passed in ``wrapper_cls`` argument into ``enable_wrap``. Both + ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct + the ``wrapper_cls`` instance. In the case of duplicate kwargs in + ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be + respected. 
+ + Usage:: + + with enable_wrap(wrapper_cls=FSDP, **fsdp_config): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + module (nn.Module): module to wrap (if in :func:`enable_wrap` context) + **wrap_overrides: configuration overrides that will take priority over + the values provided by the :func:`enable_wrap` context + """ + if _ConfigAutoWrap.in_autowrap_context: + assert _ConfigAutoWrap.wrapper_cls is not None + + wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides} + return _wrap( + module, + _ConfigAutoWrap.wrapper_cls, + **wrap_overrides, + ) + return module + + +@dataclass +class ParamExecOrderWrapPolicy: + """ + This is the class used for the wrapping policy that wraps parameters and performs + the communication scheduling based on the parameter execution order in the forward pass + (also called non-recursive wrapping policy). + + The policy contains multiple wraps. Each wrap contains original parameters that will be executed together, + and the wrap transfers these parameters into one ``FlattenParameter``. In both forward and the backward passes, + the sharded parameters in each wrap will be gathered just before these parameters are used in the passes. + These parameters will then be reshaded once they have been used. + + TODO (linjianma): For now, the parameters contained in each wrap of ``ParamExecOrderWrapPolicy`` + are the parameters in each wrap of the ``init_policy`` (a recursive wrapping policy). + Later we will wrap parameters based on bucket size. + + Args: + init_policy (Callable): + The initial recursive wrapping policy used to guide the wrapping of + this policy. If tracing_config is none, in the first forward and + backward iteration, ``init_policy`` is used to record parameter + execution order. Otherwise, init_policy is only used in FSDP + constructor for module level wrapping. + + The default ``always_wrap_policy`` might not be the best choice for every model. For example, for + transformer based models, setting ``transformer_auto_wrap_policy`` as the ``init_policy`` will guarantee + wrapping each transformer layer into one FSDP unit, and can be easily combined with checkpointing + within each transformer layer. + + tracing_config (Optional[TracingConfig]): + The configuration used to perform symbolic tracing at FSDP + constructor to get the module and parameter execution order. The + type of ``tracing_config`` needs to be either ``None`` or + ``TracingConfig``. If set as ``None``, then symbolic tracing is not + enabled, and one forward as well as backward iteration are needed to + get the parameter execution order. + + ..warning :: Note that not all modules can be successfully traced when + ``tracing_config`` is not None and symbolic tracing is enabled. The two + cases below may be unable to trace: 1. when there is a data-dependent + branch, 2. when the forward pass contains operators that don't support + ``torch.fx.Proxy`` as the input type (e.g. ``arange``, ``zeros``, ``ones``, + ``full``, ``full_like``, ``eye``, ``empty``, ``tensor``). For those cases, + users can set ``tracing_config = None`` to disable symbolic tracing. + """ + init_policy: Callable = always_wrap_policy + tracing_config: Any = None + + +def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: + assert wrapper_cls is not None + if hasattr(module, '_wrap_overrides'): + # If module has a _wrap_overrides attribute, we force overriding the + # FSDP config with these attributes for this module. 
Currently this + # is only used to disable mixed precision for BatchNorm when + # auto_wrapping. + overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type] + return wrapper_cls(module, **overrides) + + return wrapper_cls(module, **kwargs) + + +def _recursive_wrap( + module: nn.Module, + auto_wrap_policy: Callable, + wrapper_cls: Callable, + ignored_modules: Set[nn.Module], + ignored_params: Set[nn.Parameter], + only_wrap_children: bool = False, + **kwargs: Any +) -> Tuple[nn.Module, int]: + """ + Automatically wrap child modules of *module* that meet the given + criteria with :func:`auto_wrap`. Does not rely on _ConfigAutoWrap. + Args: + module (nn.Module): + module to recursively wrap + auto_wrap_policy (Callable): + A callable specifying a policy to recursively wrap layers with FSDP. + ignored_modules (Set[torch.nn.Module]): Modules to ignore when + wrapping. + ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when + wrapping; these should be the parameters contained in the modules + in ``ignored_modules``. + Returns: + (nn.Module, int): + Wrapped module and the number parameters wrapped recursively. + """ + assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." + assert wrapper_cls is not None, "Must specify wrapper_cls" + # Make sure no child is already wrapped. + for _, child in module.named_modules(): + if child in ignored_modules: + continue + try: + assert not isinstance(child, cast(type, wrapper_cls)) + except TypeError: + # wrapper_cls is a function as opposed to a class type, just bypass above check. + pass + + # We count all params, assuming none of them are already wrapped. + num_params = sum( + p.numel() for p in module.parameters() if p not in ignored_params + ) + + assert auto_wrap_policy is not None + if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params): + total_wrapped_params = 0 + # Iterate through the children, recursively wrap if necessary + for name, child in module.named_children(): + if child in ignored_modules: + continue + wrapped_child, num_wrapped_params = _recursive_wrap( + module=child, + auto_wrap_policy=auto_wrap_policy, + wrapper_cls=wrapper_cls, + ignored_modules=ignored_modules, + ignored_params=ignored_params, + **kwargs, + ) + setattr(module, name, wrapped_child) + # Keep track of how many parameters have been wrapped + total_wrapped_params += num_wrapped_params + # decide if we need to wrap the current module, + # since the left over parameters exceed the number of params to wrap + remainder = num_params - total_wrapped_params + if not only_wrap_children and auto_wrap_policy( + module=module, recurse=False, unwrapped_params=remainder + ): + # Leaf node or final wrapping of the remainder both happen here. + return _wrap(module, wrapper_cls, **kwargs), num_params + else: + return module, total_wrapped_params + return module, 0 + + +class _ConfigAutoWrap: + """ + Helper class to wrap modules based on default config args via a context manager. + See :func:`enable_wrap` for more information. + """ + + in_autowrap_context: bool = False # Context flag + wrapper_cls: Optional[Callable] = None # The wrapper class + kwargs: Dict[str, Any] = {} # Wrapper's args + + def __init__(self, **kwargs: Dict[str, Any]): + self.kwargs = kwargs + + @staticmethod + def enable_autowrap_context(kwargs: Any) -> None: + if _ConfigAutoWrap.in_autowrap_context: + raise NotImplementedError( + "You are already within an autowrap context and we currently do not supported nested autowrap." 
+            )
+        _ConfigAutoWrap.in_autowrap_context = True
+        # Get and save the wrapper cls for the context.
+        assert (
+            "wrapper_cls" in kwargs.keys()
+        ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
+        _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
+        del kwargs["wrapper_cls"]
+        # Save the rest.
+        _ConfigAutoWrap.kwargs = kwargs
+
+    @staticmethod
+    def disable_autowrap_context() -> None:
+        _ConfigAutoWrap.in_autowrap_context = False
+        _ConfigAutoWrap.wrapper_cls = None
+        _ConfigAutoWrap.kwargs = {}
+
+    def __enter__(self) -> None:
+        self.enable_autowrap_context(self.kwargs)
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.disable_autowrap_context()
+
+
+def apply_fsdp_wrap():
+    torch.distributed.fsdp.wrap._ConfigAutoWrap = _ConfigAutoWrap
+    torch.distributed.fsdp.wrap._recursive_wrap = _recursive_wrap
+    torch.distributed.fsdp.wrap._wrap = _wrap
+    torch.distributed.fsdp.wrap.wrap = wrap
+    torch.distributed.fsdp.wrap.enable_wrap = enable_wrap
+    torch.distributed.fsdp.wrap.size_based_auto_wrap_policy = size_based_auto_wrap_policy
+    torch.distributed.fsdp.wrap._or_policy = _or_policy
+    torch.distributed.fsdp.wrap._wrap_batchnorm_individually = _wrap_batchnorm_individually
+    torch.distributed.fsdp.wrap.transformer_auto_wrap_policy = transformer_auto_wrap_policy
+    torch.distributed.fsdp.wrap.lambda_auto_wrap_policy = lambda_auto_wrap_policy
+    torch.distributed.fsdp.wrap.always_wrap_policy = always_wrap_policy
+    torch.distributed.fsdp.wrap.__all__ = __all__
+
diff --git a/torch_npu/distributed/utils.py b/torch_npu/distributed/utils.py
new file mode 100644
index 0000000000..a860e9d696
--- /dev/null
+++ b/torch_npu/distributed/utils.py
@@ -0,0 +1,192 @@
+import torch
+import torch_npu
+import torch.distributed as dist
+from torch.nn.parallel._functions import _get_stream
+from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
+    _is_namedtuple
+)
+from typing import Any, Dict, List, Tuple
+
+__all__ = []  # type: ignore[var-annotated]
+
+def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
+    """
+    Turn an argument list into separate key list and value list (``_unpack_kwargs`` does the opposite).
+    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70
+    Usage::
+
+        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
+        assert kwarg_keys == ("a", "b")
+        assert flat_args == (1, 2, 3, 4)
+        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
+        assert args == (1, 2)
+        assert kwargs == {"a": 3, "b": 4}
+    Returns:
+        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
+        both positional args and kwarg values, where the positional args
+        precede the kwarg values and the kwarg values are ordered consistently
+        with the kwarg keys. The second tuple element gives the kwarg keys.
+        The second tuple element's length is at most the first tuple element's length.
+    """
+    kwarg_keys: List[str] = []
+    flat_args: List[Any] = list(args)
+    for k, v in kwargs.items():
+        kwarg_keys.append(k)
+        flat_args.append(v)
+
+    return tuple(flat_args), tuple(kwarg_keys)
+
+
+def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """See _pack_kwargs."""
+    assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs.
{len(flat_args)}" + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} + return args, kwargs + +def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies): + r""" + Recursively moves input to the target_gpu. + """ + + def to_map(obj): + if isinstance(obj, torch.Tensor): + if obj.device == torch.device("npu", target_gpu): + return (obj,) + if not use_side_stream_for_tensor_copies: + return (obj.to(target_gpu),) + else: + # Perform CPU -> GPU copies in a background stream. This code is + # motivated from similar logic in torch/nn/parallel/_functions.py + stream = _get_stream(target_gpu) + with torch_npu.npu.stream(stream): + output = obj.to(target_gpu) + # synchronize with the copy stream + with torch_npu.npu.device(target_gpu): + current_stream = torch_npu.npu.current_stream() + # Sync the current stream with the copy stream + current_stream.wait_stream(stream) + # Ensure tensor memory is not reused until work on + # main stream is complete + output.record_stream(current_stream) # type: ignore[arg-type] + return (output,) + if _is_namedtuple(obj): + return [type(obj)(*args) for args in zip(*map(to_map, obj))] + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(to_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return [list(i) for i in zip(*map(to_map, obj))] + if isinstance(obj, dict) and len(obj) > 0: + return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] + return [obj] + + # Avoid reference cycle + try: + res = to_map(inputs) + finally: + to_map = None # type: ignore[assignment] + return res + + +def _to_kwargs(inputs, kwargs, device_id, use_side_stream_for_tensor_copies): + inputs = ( + _recursive_to(inputs, device_id, use_side_stream_for_tensor_copies) + if inputs + else [] + ) + kwargs = ( + _recursive_to(kwargs, device_id, use_side_stream_for_tensor_copies) + if kwargs + else [] + ) + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs + +def _verify_param_shape_across_processes(process_group, tensors, logger=None): + return dist._verify_params_across_processes(process_group, tensors, logger) + +def _sync_module_states( + module, + process_group, + broadcast_bucket_size, + src, + params_and_buffers_to_ignore, +): + """ + Syncs ``module``'s parameters and buffers state so that all ranks contain + the same module state across all ranks. Note that this API assumes that all + parameter shapes are consistent before running the synchronization. This can + be checked with ``_verify_param_shape_across_processes``. + """ + module_states = [] + for name, param in module.named_parameters(): + if name not in params_and_buffers_to_ignore: + module_states.append(param.detach()) + + for name, buffer in module.named_buffers(): + if name not in params_and_buffers_to_ignore: + module_states.append(buffer.detach()) + + _sync_params_and_buffers( + process_group, + module_states, + broadcast_bucket_size, + src + ) + +def _sync_params_and_buffers( + process_group: dist.ProcessGroup, + module_states: List[torch.Tensor], + broadcast_bucket_size: int, + src: int, +): + """ + Synchronizes ``module_states`` (list of tensors) across all processes by + broadcasting them from rank 0. 
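+
+    A minimal sketch of how ``_sync_module_states`` above drives this helper
+    (the module, bucket size and process group below are illustrative
+    assumptions)::
+
+        pg = dist.group.WORLD
+        _sync_module_states(
+            module=my_module,
+            process_group=pg,
+            broadcast_bucket_size=int(250 * 1024 * 1024),
+            src=0,
+            params_and_buffers_to_ignore=set(),
+        )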
+ """ + if len(module_states) > 0: + dist._broadcast_coalesced( + process_group, module_states, broadcast_bucket_size, src + ) + +def _replace_by_prefix( + state_dict: Dict[str, Any], + old_prefix: str, + new_prefix: str, +) -> None: + """ + Replace all keys that match a given old_prefix with a new_prefix (in-place). + + Usage:: + + state_dict = {"layer.xyz": torch.tensor(1)} + replace_by_prefix_(state_dict, "layer.", "module.layer.") + assert state_dict == {"module.layer.xyz": torch.tensor(1)} + """ + if old_prefix == new_prefix: + raise ValueError("old_prefix and new_prefix must be distinct") + for key in list(state_dict.keys()): + if not key.startswith(old_prefix): + continue + new_key = new_prefix + key[len(old_prefix) :] + state_dict[new_key] = state_dict[key] + del state_dict[key] + + +def apply_utils_func(): + torch.distributed.utils._replace_by_prefix = _replace_by_prefix + torch.distributed.utils._sync_params_and_buffers = _sync_params_and_buffers + torch.distributed.utils._sync_module_states = _sync_module_states + torch.distributed.utils._verify_param_shape_across_processes = _verify_param_shape_across_processes + torch.distributed.utils._to_kwargs = _to_kwargs + torch.distributed.utils._recursive_to = _recursive_to + torch.distributed.utils._unpack_kwargs = _unpack_kwargs + torch.distributed.utils._pack_kwargs = _pack_kwargs + torch.distributed.utils.__all__ = __all__ diff --git a/torch_npu/nn/modules/module.py b/torch_npu/nn/modules/module.py new file mode 100644 index 0000000000..d807fb91be --- /dev/null +++ b/torch_npu/nn/modules/module.py @@ -0,0 +1,162 @@ +from collections import OrderedDict, namedtuple + +import torch +from torch.nn.parameter import Parameter +import torch.utils.hooks as hooks + +from torch import Tensor +from typing import Tuple, Any, Callable, Set, Optional, Mapping, Dict, List +from torch.utils.hooks import RemovableHandle +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.modules.module import Module + +def __init__(self) -> None: + """ + Initializes internal Module state, shared by both nn.Module and ScriptModule. 
+ """ + torch._C._log_api_usage_once("python.nn_module") + + self.training = True + self._parameters: Dict[str, Optional[Parameter]] = OrderedDict() + self._buffers: Dict[str, Optional[Tensor]] = OrderedDict() + self._non_persistent_buffers_set: Set[str] = set() + self._backward_hooks: Dict[int, Callable] = OrderedDict() + self._is_full_backward_hook = None + self._forward_hooks: Dict[int, Callable] = OrderedDict() + self._forward_pre_hooks: Dict[int, Callable] = OrderedDict() + self._state_dict_hooks: Dict[int, Callable] = OrderedDict() + self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict() + self._load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict() + self._modules: Dict[str, Optional['Module']] = OrderedDict() + + +def __setstate__(self, state): + self.__dict__.update(state) + # Support loading old checkpoints that don't have the following attrs: + if '_forward_pre_hooks' not in self.__dict__: + self._forward_pre_hooks = OrderedDict() + if '_state_dict_hooks' not in self.__dict__: + self._state_dict_hooks = OrderedDict() + if '_load_state_dict_pre_hooks' not in self.__dict__: + self._load_state_dict_pre_hooks = OrderedDict() + if '_load_state_dict_post_hooks' not in self.__dict__: + self._load_state_dict_post_hooks = OrderedDict() + if '_non_persistent_buffers_set' not in self.__dict__: + self._non_persistent_buffers_set = set() + if '_is_full_backward_hook' not in self.__dict__: + self._is_full_backward_hook = None + + +def register_load_state_dict_post_hook(self, hook): + r"""Registers a post hook to be run after module's ``load_state_dict`` + is called. + + It should have the following signature:: + hook(module, incompatible_keys) -> None + + The ``module`` argument is the current module that this hook is registered + on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting + of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` + is a ``list`` of ``str`` containing the missing keys and + ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. + + The given incompatible_keys can be modified inplace if needed. + + Note that the checks performed when calling :func:`load_state_dict` with + ``strict=True`` are affected by modifications the hook makes to + ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either + set of keys will result in an error being thrown when ``strict=True``, and + clearning out both missing and unexpected keys will avoid an error. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._load_state_dict_post_hooks) + self._load_state_dict_post_hooks[handle.id] = hook + return handle + + +def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. 
Default: ``True`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + if not isinstance(state_dict, Mapping): + raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + def load(module, local_state_dict, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + child_prefix = prefix + name + '.' + child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} + load(child, child_state_dict, child_prefix) + + # Note that the hook can modify missing_keys and unexpected_keys. + incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible_keys) + assert out is None, ( + "Hooks registered with ``register_load_state_dict_post_hook`` are not" + "expected to return new values, if incompatible_keys need to be modified," + "it should be done inplace." + ) + + load(self, state_dict) + del load + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. 
'.format( + ', '.join('"{}"'.format(k) for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + +def add_nn_module_api(): + Module.__setstate__ = __setstate__ + Module.__init__ = __init__ + Module._register_load_state_dict_post_hook = register_load_state_dict_post_hook + Module.load_state_dict = load_state_dict + -- Gitee From b9d9e157c3b6ab04474620824977c72e9a17847d Mon Sep 17 00:00:00 2001 From: wangqianren Date: Sun, 11 Jun 2023 12:26:46 +0800 Subject: [PATCH 3/3] fix mokey patch bugs of FSDP --- torch_npu/__init__.py | 29 +++++++++++++++++++ torch_npu/distributed/__init__.py | 26 ----------------- .../_checkpoint/checkpoint_wrapper.py | 4 +-- torch_npu/distributed/fsdp/__init__.py | 1 - torch_npu/distributed/fsdp/_optim_utils.py | 8 ++--- torch_npu/distributed/fsdp/_utils.py | 10 +++++-- .../fsdp/flatten_params_wrapper.py | 3 +- .../fsdp/fully_sharded_data_parallel.py | 6 ++-- .../distributed/fsdp/sharded_grad_scaler.py | 10 +++---- torch_npu/distributed/fsdp/utils.py | 11 +++++-- torch_npu/distributed/utils.py | 11 +++++-- torch_npu/nn/modules/module.py | 2 +- 12 files changed, 68 insertions(+), 53 deletions(-) diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 68f83a8f95..d7580e59c9 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -68,6 +68,19 @@ from torch_npu.utils import apply_module_patch, add_tensor_methods, add_torch_fu from torch_npu.distributed.hccl_dtype_wraper import wrap_dtype_for_hccl from torch_npu.npu.amp.autocast_mode import apply_autocast_patch from torch_npu.nn.modules.module import add_nn_module_api +from torch_npu.distributed.fsdp import apply_fsdp_init +from torch_npu.distributed.fsdp._optim_utils import apply_fsdp_optim_utils +from torch_npu.distributed.fsdp._shard_utils import apply_fsdp_shard_utils +from torch_npu.distributed.fsdp.flat_param import apply_fsdp_flat_param_handle +from torch_npu.distributed.fsdp.flatten_params_wrapper import apply_fsdp_flatten_params_wrapper +from torch_npu.distributed.fsdp.fully_sharded_data_parallel import apply_fsdp +from torch_npu.distributed.fsdp.sharded_grad_scaler import apply_fsdp_sharded_grad_scaler +from torch_npu.distributed.fsdp.utils import apply_fsdp_utils +from torch_npu.distributed.fsdp.wrap import apply_fsdp_wrap +from torch_npu.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_algorithms_checkpoint_wrapper +from torch_npu.distributed.algorithms._comm_hooks.default_hooks import apply_algorithms_comm_hooks_default_hooks +from torch_npu.distributed.algorithms._comm_hooks import apply_algorithms_comm_hooks_init +from torch_npu.distributed.utils import apply_utils_func from .version import __version__ as __version__ @@ -195,6 +208,9 @@ all_monkey_patches = [ ["nn", npu_modules], ["_C.Generator", torch_npu._C.Generator], ["device", torch_npu._C.device], + ["distributed.fsdp", torch_npu.distributed.fsdp], + ["distributed.algorithms", torch_npu.distributed.algorithms], + ["distributed.utils", torch_npu.distributed.utils], ] all_monkey_patches += serialization_patches @@ -244,6 +260,19 @@ def apply_class_patches(): add_checkpoint_methods() apply_autocast_patch() add_nn_module_api() + apply_fsdp_init() + apply_fsdp_optim_utils() + apply_fsdp_shard_utils() + apply_fsdp_flat_param_handle() + apply_fsdp_flatten_params_wrapper() + apply_fsdp() + apply_fsdp_sharded_grad_scaler() + 
apply_fsdp_utils() + apply_fsdp_wrap() + apply_algorithms_checkpoint_wrapper() + apply_algorithms_comm_hooks_default_hooks() + apply_algorithms_comm_hooks_init() + apply_utils_func() # Apply monkey-patches. diff --git a/torch_npu/distributed/__init__.py b/torch_npu/distributed/__init__.py index b2f49a93f7..83d79ec752 100644 --- a/torch_npu/distributed/__init__.py +++ b/torch_npu/distributed/__init__.py @@ -51,33 +51,7 @@ from .distributed_c10d import ( _rank_not_in_group, Logger, all_gather_object, broadcast_object_list, all_gather_togather, _reduce_scatter_base ) -from .fsdp import apply_fsdp_init -from .fsdp._optim_utils import apply_fsdp_optim_utils -from .fsdp._shard_utils import apply_fsdp_shard_utils -from .fsdp.flat_param import apply_fsdp_flat_param_handle -from .fsdp.flatten_params_wrapper import apply_fsdp_flatten_params_wrapper -from .fsdp.fully_sharded_data_parallel import apply_fsdp -from .fsdp.sharded_grad_scaler import apply_fsdp_shard_grad_scaler -from .fsdp.utils import apply_fsdp_utils -from .fsdp.wrap import apply_fsdp_wrap -from .algorithms._checkpoint.checkpoint_wrapper import apply_algorithms_checkpoint_wrapper -from .algorithms._comm_hooks.default_hooks import apply_algorithms_comm_hooks_default_hooks -from .algorithms._comm_hooks import apply_algorithms_comm_hooks_init -from .utils import apply_utils_func set_debug_level_from_env() -apply_fsdp_init() -apply_fsdp_optim_utils() -apply_fsdp_shard_utils() -apply_fsdp_flat_param_handle() -apply_fsdp_flatten_params_wrapper() -apply_fsdp() -apply_fsdp_shard_grad_scaler() -apply_fsdp_utils() -apply_fsdp_wrap() -apply_algorithms_checkpoint_wrapper() -apply_algorithms_comm_hooks_default_hooks() -apply_algorithms_comm_hooks_init() -apply_utils_func() diff --git a/torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index a53241f320..affb4ede7e 100644 --- a/torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch_npu/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Iterator, Tuple import torch import torch.nn as nn from torch.autograd.graph import save_on_cpu -from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs -from torch.utils.checkpoint import checkpoint +from torch_npu.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs +from torch_npu.utils.checkpoint import checkpoint _CHECKPOINT_PREFIX = "_checkpoint_wrapped_module" diff --git a/torch_npu/distributed/fsdp/__init__.py b/torch_npu/distributed/fsdp/__init__.py index f7b0018cd8..eda1eab8c2 100644 --- a/torch_npu/distributed/fsdp/__init__.py +++ b/torch_npu/distributed/fsdp/__init__.py @@ -1,5 +1,4 @@ import torch -import torch_npu from .flat_param import FlatParameter from .fully_sharded_data_parallel import ( BackwardPrefetch, diff --git a/torch_npu/distributed/fsdp/_optim_utils.py b/torch_npu/distributed/fsdp/_optim_utils.py index 1451ca2456..8dbfb76098 100644 --- a/torch_npu/distributed/fsdp/_optim_utils.py +++ b/torch_npu/distributed/fsdp/_optim_utils.py @@ -19,12 +19,12 @@ import torch import torch_npu import torch.distributed as dist # Import the entire FSDP file to avoid circular imports -import .fully_sharded_data_parallel as FSDP +import torch_npu.distributed.fsdp.fully_sharded_data_parallel as FSDP import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor -from ._shard_utils import _gather_state_dict -from .flat_param import 
FlatParameter, FlatParamHandle -from ._fsdp_extensions import _ext_chunk_tensor +from torch_npu.distributed.fsdp._shard_utils import _gather_state_dict +from torch_npu.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle +from torch_npu.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]: diff --git a/torch_npu/distributed/fsdp/_utils.py b/torch_npu/distributed/fsdp/_utils.py index 80688e5dec..ff3b6c8c58 100644 --- a/torch_npu/distributed/fsdp/_utils.py +++ b/torch_npu/distributed/fsdp/_utils.py @@ -5,15 +5,19 @@ from typing import Any, Callable, Dict, List, Set, Tuple, Union import torch from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined] - _is_namedtuple, -) from torch.nn.utils.rnn import PackedSequence FSDP_FLATTENED = "_fsdp_flattened" +def _is_namedtuple(obj): + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + def _contains_batchnorm(module): return any( isinstance(mod, _BatchNorm) for mod in module.modules() diff --git a/torch_npu/distributed/fsdp/flatten_params_wrapper.py b/torch_npu/distributed/fsdp/flatten_params_wrapper.py index c4872ebd00..81e4fb9b57 100644 --- a/torch_npu/distributed/fsdp/flatten_params_wrapper.py +++ b/torch_npu/distributed/fsdp/flatten_params_wrapper.py @@ -10,9 +10,8 @@ import contextlib from typing import Any, Dict, Generator, List import torch -import torch_npu import torch.nn as nn -from torch.distributed.utils import _replace_by_prefix +from torch_npu.distributed.utils import _replace_by_prefix from .flat_param import FlatParamHandle, HandleConfig diff --git a/torch_npu/distributed/fsdp/fully_sharded_data_parallel.py b/torch_npu/distributed/fsdp/fully_sharded_data_parallel.py index c88f2d3747..c3971cf681 100644 --- a/torch_npu/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch_npu/distributed/fsdp/fully_sharded_data_parallel.py @@ -4328,13 +4328,13 @@ def _calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Ten local_norm = torch.tensor(max(par.grad.detach().abs().max() for par in parameters)) else: # Compute the norm in full precision no matter what - local_norm = torch.linalg.vector_norm( + local_norm = torch.norm( torch.stack( [ - torch.linalg.vector_norm(par.grad.detach(), p, dtype=torch.float32) + torch.norm(par.grad.detach().view(-1), p, dtype=torch.float32) for par in parameters ] - ), + ).view(-1), p, ) local_norm.to(dtype=parameters[0].dtype) diff --git a/torch_npu/distributed/fsdp/sharded_grad_scaler.py b/torch_npu/distributed/fsdp/sharded_grad_scaler.py index 84350bce7a..733cce1fd8 100644 --- a/torch_npu/distributed/fsdp/sharded_grad_scaler.py +++ b/torch_npu/distributed/fsdp/sharded_grad_scaler.py @@ -348,8 +348,8 @@ class ShardedGradScaler(GradScaler): self._clear_overflow_flag = False -def apply_fsdp_shard_grad_scaler(): - torch.distributed.fsdp.shard_grad_scaler.ShardedGradScaler = ShardedGradScaler - torch.distributed.fsdp.shard_grad_scaler._GeneralMultiDeviceReplicator = _GeneralMultiDeviceReplicator - torch.distributed.fsdp.shard_grad_scaler._is_supported_device = _is_supported_device - torch.distributed.fsdp.shard_grad_scaler._refresh_per_optimizer_state = _refresh_per_optimizer_state +def apply_fsdp_sharded_grad_scaler(): + torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler = ShardedGradScaler + 
torch.distributed.fsdp.sharded_grad_scaler._GeneralMultiDeviceReplicator = _GeneralMultiDeviceReplicator + torch.distributed.fsdp.sharded_grad_scaler._is_supported_device = _is_supported_device + torch.distributed.fsdp.sharded_grad_scaler._refresh_per_optimizer_state = _refresh_per_optimizer_state diff --git a/torch_npu/distributed/fsdp/utils.py b/torch_npu/distributed/fsdp/utils.py index d5bf87736c..5b06519ab8 100644 --- a/torch_npu/distributed/fsdp/utils.py +++ b/torch_npu/distributed/fsdp/utils.py @@ -5,15 +5,20 @@ from typing import Any, Callable, Dict, List, Set, Tuple, Union import torch from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined] - _is_namedtuple, -) from torch.nn.utils.rnn import PackedSequence """Useful functions to deal with tensor types with other python container types.""" __all__ = ["p_assert"] + +def _is_namedtuple(obj): + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + def _contains_batchnorm(module): return any( isinstance(mod, _BatchNorm) for mod in module.modules() diff --git a/torch_npu/distributed/utils.py b/torch_npu/distributed/utils.py index a860e9d696..582a42dd1c 100644 --- a/torch_npu/distributed/utils.py +++ b/torch_npu/distributed/utils.py @@ -2,13 +2,18 @@ import torch import torch_npu import torch.distributed as dist from torch.nn.parallel._functions import _get_stream -from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined] - _is_namedtuple -) from typing import Any, Dict, List, Tuple __all__ = [] # type: ignore[var-annotated] + +def _is_namedtuple(obj): + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]: """ Turn argument list into separate key list and value list (unpack_kwargs does the opposite) diff --git a/torch_npu/nn/modules/module.py b/torch_npu/nn/modules/module.py index d807fb91be..36065c2f11 100644 --- a/torch_npu/nn/modules/module.py +++ b/torch_npu/nn/modules/module.py @@ -157,6 +157,6 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): def add_nn_module_api(): Module.__setstate__ = __setstate__ Module.__init__ = __init__ - Module._register_load_state_dict_post_hook = register_load_state_dict_post_hook + Module.register_load_state_dict_post_hook = register_load_state_dict_post_hook Module.load_state_dict = load_state_dict -- Gitee
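A minimal usage sketch for the renamed ``register_load_state_dict_post_hook``
API patched in above (this assumes the patches in this series are installed on
``import torch_npu``, as torch_npu/__init__.py now does; the Linear module and
hook below are illustrative, not part of the patch):

    import torch.nn as nn
    import torch_npu  # installs the Module patches on import

    model = nn.Linear(4, 4)

    def report_incompatible_keys(module, incompatible_keys):
        # incompatible_keys carries .missing_keys and .unexpected_keys lists
        print(module.__class__.__name__,
              incompatible_keys.missing_keys,
              incompatible_keys.unexpected_keys)

    handle = model.register_load_state_dict_post_hook(report_incompatible_keys)
    model.load_state_dict(model.state_dict())  # hook fires; both lists are empty
    handle.remove()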