diff --git a/mindspeed_llm/core/__init__.py b/mindspeed_llm/core/__init__.py index 1831c630adcd8f807a0d4aeda5ad0bccc88a2476..45d75a751614e6ac15010be7b091eedc08abd7a8 100644 --- a/mindspeed_llm/core/__init__.py +++ b/mindspeed_llm/core/__init__.py @@ -34,7 +34,7 @@ from .models.gpt.gpt_layer_specs import get_gpt_layer_local_spec_wrapper, build_ from .distributed.param_and_grad_buffer import start_grad_sync_wrapper from .distributed.distributed_data_parallel import distributed_data_parallel_init_wrapper from .optimizer import get_megatron_optimizer_wrapper -from .optimizer.clip_grads import clip_grad_norm_fp32_wrapper +from .optimizer.clip_grads import get_grad_norm_fp32_wrapper from .optimizer.distrib_optimizer import distributed_optimizer_init_wrapper, distributed_optimizer_init_for_reuse_fp32_wrapper from .optimizer.distrib_optimizer import get_parameter_state_dp_zero_with_high_availability_wrapper from .transformer.custom_layers.transformer_engine import PTNorm diff --git a/mindspeed_llm/core/distributed/distributed_data_parallel.py b/mindspeed_llm/core/distributed/distributed_data_parallel.py index ebd334045a8a5fc46bc9187869a2ffbd9c8d239b..aceeab4fbb0e7e8807089df1917bc4611572f2e7 100644 --- a/mindspeed_llm/core/distributed/distributed_data_parallel.py +++ b/mindspeed_llm/core/distributed/distributed_data_parallel.py @@ -1,7 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. -from typing import Optional +import logging from functools import wraps import torch from megatron.training import get_args @@ -9,42 +9,49 @@ from megatron.core import parallel_state from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.distributed import ParamAndGradBuffer - +from megatron.core.datasets.utils import log_single_rank +from megatron.core.distributed import DistributedDataParallelConfig +logger = logging.getLogger(__name__) def distributed_data_parallel_init( self, config: TransformerConfig, + ddp_config: DistributedDataParallelConfig, module: torch.nn.Module, - data_parallel_group: torch.distributed.ProcessGroup, - accumulate_allreduce_grads_in_fp32: bool, - overlap_grad_reduce: bool, - use_distributed_optimizer: bool, - expert_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None, disable_bucketing: bool = False, - check_for_nan_in_grad: bool = False, - bucket_size: int = 40000000 + ): MegatronModule.__init__(self, config=config) self.module = module + # If bucket_size is not provided as an input, use sane default. + # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL + # ring-reduce implementations are large enough to remain bandwidth-bound rather than + # latency-bound. + if ddp_config.bucket_size is None: + ddp_config.bucket_size = max( + 40000000, 1000000 * parallel_state.get_data_parallel_world_size() + ) # Set bucket_size to infinity if overlap_grad_reduce is False. - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - - # Turn off bucketing if overlap_grad_reduce is False, if we are on a pipeline stage - # that is not the first (since data-parallel communication on these stages is not on - # the critical path), or if disable_bucketing is True (e.g., we might not want to - # break up model parameters into buckets for model chunks after the first - # in the interleaved schedule). 
- if not self.overlap_grad_reduce: - bucket_size = None + if not ddp_config.overlap_grad_reduce: + ddp_config.bucket_size = None + + self.ddp_config = ddp_config + log_single_rank( + logger, + logging.INFO, + f'Setting up DistributedDataParallel with config {self.ddp_config}', + ) + + # Turn off bucketing if we are on a pipeline stage that is not the first (since + # data-parallel communication on these stages is not on the critical path), or if + # disable_bucketing is True (e.g., we might not want to break up model parameters + # into buckets for model chunks after the first in the interleaved schedule). + self.bucket_size = self.ddp_config.bucket_size if parallel_state.get_pipeline_model_parallel_rank() > 0: - bucket_size = None + self.bucket_size = None if disable_bucketing: - bucket_size = None - - self.check_for_nan_in_grad = check_for_nan_in_grad - self.bucket_size = bucket_size + self.bucket_size = None self.module = module self.param_to_buffer = {} @@ -66,7 +73,7 @@ def distributed_data_parallel_init( expert_parallel_params.append(param) def allocate_buffers_for_parameters( - input_params, data_parallel_group, gradient_scaling_factor=1.0, + input_params, data_parallel_group, gradient_scaling_factor, ): param_and_grad_dtype_to_params = {} @@ -76,27 +83,38 @@ def distributed_data_parallel_init( continue param_dtype = param.dtype - grad_dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype + grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), []) params.append(param) param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params + if not config.calculate_per_token_loss: + target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size( + with_context_parallel=True + ) + if self.ddp_config.average_in_collective: + # Collective is averaging gradients in collective with data_parallel_group. + assert ( + gradient_scaling_factor + / torch.distributed.get_world_size(group=data_parallel_group) + == target_gradient_scaling_factor + ) + else: + assert gradient_scaling_factor == target_gradient_scaling_factor # Allocate the grad buffers and map the grads. buffers = [] for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items(): buffers.append( ParamAndGradBuffer( + self.ddp_config, param_dtype, grad_dtype, params, data_parallel_group, - bucket_size, + self.bucket_size, param_to_name, - self.overlap_grad_reduce, - self.use_distributed_optimizer, gradient_scaling_factor, - self.check_for_nan_in_grad, ) ) for param in params: @@ -104,27 +122,41 @@ def distributed_data_parallel_init( return buffers - data_parallel_world_size = torch.distributed.get_world_size(data_parallel_group) + if config.calculate_per_token_loss: + gradient_scaling_factor = 1.0 + expert_gradient_scaling_factor = 1.0 + else: + if self.ddp_config.average_in_collective: + gradient_scaling_factor = 1.0 + expert_gradient_scaling_factor = ( + 1.0 / parallel_state.get_expert_model_parallel_world_size() + ) + else: + data_parallel_world_size = parallel_state.get_data_parallel_world_size( + with_context_parallel=True + ) + gradient_scaling_factor = 1.0 / data_parallel_world_size + expert_gradient_scaling_factor = 1.0 / data_parallel_world_size # Allocate the param+grad buffers for dense params' grads. 
self.buffers = allocate_buffers_for_parameters( dense_params, - data_parallel_group, - gradient_scaling_factor=1.0 / data_parallel_world_size, + parallel_state.get_data_parallel_group(with_context_parallel=True), + gradient_scaling_factor=gradient_scaling_factor, ) # Allocate separate param+grad buffers for expert parallel params' grads. self.expert_parallel_buffers = allocate_buffers_for_parameters( expert_parallel_params, - expert_data_parallel_group, - gradient_scaling_factor=1.0 / data_parallel_world_size, + parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True), + gradient_scaling_factor=expert_gradient_scaling_factor, ) # Delete references to weight_tensor if they exist since we don't want two parameter copies # if we re-mapped parameters (which happens when we use the distributed optimizer). # This is a temporary workaround around a TE bug that is fixed with # https://github.com/NVIDIA/TransformerEngine/pull/719. - if self.use_distributed_optimizer: + if self.ddp_config.use_distributed_optimizer: @torch.no_grad() def unmap_weight_tensor(m): diff --git a/mindspeed_llm/core/distributed/param_and_grad_buffer.py b/mindspeed_llm/core/distributed/param_and_grad_buffer.py index 00e8d13591f6782b294d97ada4124ec78d8d45d3..083f582a15095be5d41ab41adac53aa56dbc176d 100644 --- a/mindspeed_llm/core/distributed/param_and_grad_buffer.py +++ b/mindspeed_llm/core/distributed/param_and_grad_buffer.py @@ -1,50 +1,15 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. -import os from functools import wraps -import torch -from megatron.training import get_args - - -def start_grad_sync(self): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operation - for this bucket. - - When overlap_grad_reduce is set to True, dispatches an asynchronous - communication call. When overlap_grad_reduce is set to False, makes - synchronous call. - """ - assert ( - self.communication_handle is None and not self.communication_issued - ), 'Should not have multiple communication calls in flight at once' - - # Make sure norm of grads in bucket are not NaN - # prior to data-parallel all-reduce / reduce-scatter. - if self.check_for_nan_in_grad: - global_rank = torch.distributed.get_rank() - norm = self.grad_data.norm(p=2) - assert not norm.isnan(), ( - f'Rank {global_rank}: found NaN in local grad norm in ' - f'backward pass before data-parallel communication collective. ' - f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' - ) - - self.grad_data *= self.gradient_scaling_factor - # Use async_op only when overlap_grad_reduce is True. 
- self.communication_handle = torch.distributed.all_reduce( - self.grad_data, group=self.data_parallel_group, async_op=self.overlap_grad_reduce - ) - self.communication_issued = True def start_grad_sync_wrapper(fn): @wraps(fn) - def wrapper(self): - args = get_args() - if args.enable_high_availability: - start_grad_sync(self) - else: - fn(self) - return wrapper \ No newline at end of file + def wrapper(self, *args, **kwargs): + use_distributed_optimizer_tmp = self.ddp_config.use_distributed_optimizer + self.ddp_config.use_distributed_optimizer = False + fn(self, *args, **kwargs) + self.ddp_config.use_distributed_optimizer = use_distributed_optimizer_tmp + + return wrapper diff --git a/mindspeed_llm/core/optimizer/__init__.py b/mindspeed_llm/core/optimizer/__init__.py index 1ae16701dd7f40e15432bf1a3302614e4044dcf3..d1fedaded2b3b594457d65141bf6789f3ed0aae5 100644 --- a/mindspeed_llm/core/optimizer/__init__.py +++ b/mindspeed_llm/core/optimizer/__init__.py @@ -31,6 +31,7 @@ def get_megatron_optimizer_based_on_param_groups( config: OptimizerConfig, param_groups: List, per_model_buffers: Optional[Dict[int, List[ParamAndGradBuffer]]] = None, + model_parallel_group: Optional[torch.distributed.ProcessGroup] = None, data_parallel_group: Optional[torch.distributed.ProcessGroup] = None, data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None, data_parallel_group_idx: Optional[int] = None, @@ -130,12 +131,15 @@ def get_megatron_optimizer_based_on_param_groups( ) else: optimizer = TTPFP16ReplicaOptimizer(*optimizer_args, ori_dp_group=ori_dp_group) + setattr(optimizer, 'model_parallel_group', model_parallel_group) + else: + # FP32 optimizer. + from mindio_ttp.adaptor import TTPFP32ReplicaOptimizer + optimizer = TTPFP32ReplicaOptimizer(optimizer, config, init_state_fn, ori_dp_group=ori_dp_group) + setattr(optimizer, 'model_parallel_group', model_parallel_group) - return optimizer + return optimizer - # FP32. - from mindio_ttp.adaptor import TTPFP32ReplicaOptimizer - return TTPFP32ReplicaOptimizer(optimizer, config, init_state_fn, ori_dp_group=ori_dp_group) def get_megatron_optimizer( @@ -163,9 +167,7 @@ def get_megatron_optimizer( Instance of MegatronOptimizer. """ - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - logger.info(f'Setting up optimizer with {config}') - + logger.info(f'Setting up optimizer with {config}') # Collect param groups. 
param_groups = _get_param_groups( model_chunks, @@ -206,6 +208,7 @@ def get_megatron_optimizer( config, param_groups=dense_param_groups, per_model_buffers=per_model_buffers, + model_parallel_group=mpu.get_model_parallel_group(), data_parallel_group=ttp_get_dp_cp_replica_group(), data_parallel_group_gloo=ttp_get_dp_cp_replica_group_gloo(), ori_dp_group=mpu.get_data_parallel_group(with_context_parallel=True), @@ -220,6 +223,7 @@ def get_megatron_optimizer( config, param_groups=moe_param_groups, per_model_buffers=per_model_ep_buffers, + model_parallel_group=mpu.get_model_parallel_group(with_expert_parallel=True), data_parallel_group=ttp_get_dp_ep_replica_group(), data_parallel_group_gloo=ttp_get_dp_ep_replica_group_gloo(), ori_dp_group=mpu.get_data_modulo_expert_parallel_group(), diff --git a/mindspeed_llm/core/optimizer/clip_grads.py b/mindspeed_llm/core/optimizer/clip_grads.py index a47c1b11b680ecb8bcebbcffaa5672848d5df4dd..f641e1bd577ddb7bef3d0bcd3968b6baa91c21a9 100644 --- a/mindspeed_llm/core/optimizer/clip_grads.py +++ b/mindspeed_llm/core/optimizer/clip_grads.py @@ -6,56 +6,72 @@ import sys from functools import wraps from typing import List, Optional, Union -import amp_C import torch -from apex.multi_tensor_apply import multi_tensor_applier from torch import inf from megatron.training import get_args +try: + from transformer_engine.pytorch.optimizers import ( + multi_tensor_applier, + multi_tensor_l2norm, + multi_tensor_scale, + ) + + l2_norm_impl = multi_tensor_l2norm + multi_tensor_scale_impl = multi_tensor_scale +except ImportError: + try: + import amp_C + from apex.multi_tensor_apply import multi_tensor_applier + + l2_norm_impl = amp_C.multi_tensor_l2norm + multi_tensor_scale_impl = amp_C.multi_tensor_scale + except ImportError: + import warnings + + warnings.warn( + f'Transformer Engine and Apex are not installed. ' + 'Falling back to local implementations of multi_tensor_applier, ' + 'multi_tensor_l2norm, and multi_tensor_scale' + ) + + from megatron.core.utils import ( + local_multi_tensor_applier, + local_multi_tensor_l2_norm, + local_multi_tensor_scale, + ) + + multi_tensor_applier = local_multi_tensor_applier + l2_norm_impl = local_multi_tensor_l2_norm + multi_tensor_scale_impl = local_multi_tensor_scale -def clip_grad_norm_fp32( - parameters: Union[List[torch.Tensor], torch.Tensor], + +def get_grad_norm_fp32( grads_for_norm: Union[List[torch.Tensor], torch.Tensor], - max_norm: Union[int, float], norm_type: Union[int, float] = 2, model_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> float: - """Clips gradient norm of an iterable of parameters whose gradients - are in fp32. + """Calculate the norm of gradients in fp32. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. + added functionality to handle model parallel parameters. - Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized. - grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single + Arguments: + grads_for_norm (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will be used for calculating the grad norm. - max_norm (float or int): max norm of the gradients. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. 
- model_parallel_group (torch.distributed.ProcessGroup, optional): model-parallel - group over which grad norm needs to be aggregated. + model_parallel_group (group): given the nature of the distributed + optimizer, this is passed as an argument. Returns: Total norm of the parameters (viewed as a single vector). """ - if isinstance(parameters, torch.Tensor): - parameters = [parameters] if isinstance(grads_for_norm, torch.Tensor): grads_for_norm = [grads_for_norm] - # Grads. - grads = [] - for param in parameters: - if param.grad is not None: - assert param.grad.type() == 'torch.cuda.FloatTensor' - grads.append(param.grad.detach()) - # Norm parameters. - max_norm = float(max_norm) norm_type = float(norm_type) total_norm = 0.0 @@ -77,7 +93,7 @@ def clip_grad_norm_fp32( # and performs the operation on that list all in one kernel. if grads_for_norm: grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, + l2_norm_impl, dummy_overflow_buf, [grads_for_norm], False, # no per-parameter norm @@ -93,35 +109,26 @@ def clip_grad_norm_fp32( grad_norm = torch.norm(grad, norm_type) total_norm += grad_norm ** norm_type + args = get_args() + if args.use_distributed_optimizer: + from mindio_ttp.adaptor import ttp_get_replica_dp_num + total_norm = total_norm / ttp_get_replica_dp_num() + # Sum across all model-parallel GPUs. torch.distributed.all_reduce( total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group ) - - if get_args().use_distributed_optimizer: - from mindio_ttp.adaptor import ttp_get_replica_dp_num - total_norm = total_norm / ttp_get_replica_dp_num() - total_norm = total_norm.item() ** (1.0 / norm_type) - - # Scale. - clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') - multi_tensor_applier( - amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff - ) - return total_norm -def clip_grad_norm_fp32_wrapper(fn): +def get_grad_norm_fp32_wrapper(fn): @wraps(fn) def wrapper(*args, **kwargs): argument = get_args() if argument.enable_high_availability: - return clip_grad_norm_fp32(*args, **kwargs) + return get_grad_norm_fp32(*args, **kwargs) else: return fn(*args, **kwargs) return wrapper \ No newline at end of file diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 2fcabb54d5f5b9bc532d7906d94cdbe312d7985d..0f672e79c20e78c0f3c98e245ff4b901171f181a 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -724,7 +724,8 @@ class LegacyAdaptation(MegatronAdaptationABC): def patch_high_availability_feature(self): args = MegatronAdaptation.get_args() from ..training import setup_model_and_optimizer_wrapper - from ..core import (get_megatron_optimizer_wrapper, clip_grad_norm_fp32_wrapper, + from ..training.initialize import initialize_distributed_wrapper + from ..core import (get_megatron_optimizer_wrapper, get_grad_norm_fp32_wrapper, distributed_optimizer_init_wrapper, start_grad_sync_wrapper, distributed_data_parallel_init_wrapper, distributed_optimizer_init_for_reuse_fp32_wrapper, @@ -736,10 +737,12 @@ class LegacyAdaptation(MegatronAdaptationABC): distributed_data_parallel_init_wrapper) MegatronAdaptation.register('megatron.core.distributed.param_and_grad_buffer.Bucket.start_grad_sync', start_grad_sync_wrapper) - MegatronAdaptation.register('megatron.training.training.get_megatron_optimizer', + MegatronAdaptation.register('megatron.core.optimizer.get_megatron_optimizer', 
get_megatron_optimizer_wrapper) - MegatronAdaptation.register('megatron.core.optimizer.optimizer.clip_grad_norm_fp32', - clip_grad_norm_fp32_wrapper) + MegatronAdaptation.register('megatron.core.optimizer.clip_grads.get_grad_norm_fp32', + get_grad_norm_fp32_wrapper) + MegatronAdaptation.register('megatron.training.initialize._initialize_distributed', + initialize_distributed_wrapper) MegatronAdaptation.register('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__', distributed_optimizer_init_wrapper) MegatronAdaptation.register('megatron.training.training.setup_model_and_optimizer', @@ -758,6 +761,14 @@ class LegacyAdaptation(MegatronAdaptationABC): distributed_optimizer_init_for_reuse_fp32_wrapper) MegatronAdaptation.register('mindio_ttp.adaptor.TTPReplicaOptimizer.get_parameter_state_dp_zero_for_ttp', get_parameter_state_dp_zero_with_high_availability_wrapper) + if args.enable_worker_reboot: + from ..training.training import build_train_valid_test_data_iterators_wrapper + from ..training.initialize import communication_wrapper, new_group_wrapper + MegatronAdaptation.register('megatron.training.training.build_train_valid_test_data_iterators', + build_train_valid_test_data_iterators_wrapper) + for communication in ['barrier', 'all_reduce', '_all_gather_base', 'broadcast', 'all_gather_into_tensor']: + MegatronAdaptation.register('torch.distributed.distributed_c10d.' + communication, communication_wrapper) + MegatronAdaptation.register('torch.distributed.distributed_c10d.new_group', new_group_wrapper) def patch_model(self): from mindspeed.core.fusions.fused_layer_norm import (FusedLayerNormAffineFunction, FastLayerNormFN) diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index 161ffc841243fd88eed26e3e4ff3b04cbe9b23f2..b5e307d572014a1943fad0a4d3704517f7c0214f 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -984,6 +984,14 @@ def _add_high_availability_args(parser): action='store_true', help='high availability feature, enable hbmfault repair') + group.add_argument('--enable-worker-reboot', + action='store_true', + help='high availability feature, enable ARF') + + group.add_argument('--distributed-optimizer-no-replica', + action='store_true', + help='high availability feature, enable ARF') + return parser @@ -1085,6 +1093,12 @@ def _validate_high_availability(args): if args.enable_hbmfault_repair and not args.enable_high_availability: raise AssertionError( 'switch of the enable hbmfault repair is unsupported, please enable high availability feature first.') + if args.enable_worker_reboot and not args.enable_high_availability: + raise AssertionError( + 'switch of the enable worker reboot(ARF) is unsupported, please enable high availability feature first.') + if args.distributed_optimizer_no_replica and not args.enable_high_availability: + raise AssertionError( + 'switch of the no replica is unsupported, please enable high availability feature first.') if args.enable_high_availability and args.use_dist_ckpt: raise AssertionError('switch of the high availability feature is unsupported') diff --git a/mindspeed_llm/training/initialize.py b/mindspeed_llm/training/initialize.py index 497c0b8f3025b0ce3d834851390c388e209c2bee..3c30850cea2c39a378b06129a8cd3fa4afae06c5 100644 --- a/mindspeed_llm/training/initialize.py +++ b/mindspeed_llm/training/initialize.py @@ -160,4 +160,54 @@ def coc_registration_wrapper(fn): initialize_coc_from_cfg(args) return res - return wrapper \ No newline at end of file + 
return wrapper + + +def initialize_distributed_wrapper(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + device_count = torch.cuda.device_count() + device = get_args().rank % device_count + torch.cuda.set_device(device) + from mindio_ttp.adaptor import tft_init_controller_processor, ttp_initialize_replica_dp_group + tft_init_controller_processor(enable_tls=False, tls_option_top_path='') + fn(*args, **kwargs) + world_size: int = torch.distributed.get_world_size() + args = megatron.training.get_args() + ttp_initialize_replica_dp_group( + args.pipeline_model_parallel_size, + args.tensor_model_parallel_size, + args.context_parallel_size, + args.expert_model_parallel_size, + world_size + ) + + return wrapper + + +def communication_wrapper(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + from mindio_ttp.adaptor import tft_is_arf_reboot_node + if tft_is_arf_reboot_node(): + return None + return fn(*args, **kwargs) + + return wrapper + + +def new_group_wrapper(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + backend = kwargs.get('backend', None) + from mindio_ttp.adaptor import tft_is_arf_reboot_node + if tft_is_arf_reboot_node() and isinstance(backend, str) and 'gloo' in backend: + return None + + if (backend is None) or torch.distributed.distributed_c10d._is_barrier_after_init(): + kwargs['use_local_synchronization'] = True + + res = fn(*args, **kwargs) + return res + + return wrapper diff --git a/mindspeed_llm/training/training.py b/mindspeed_llm/training/training.py index d042c8f9c73ab9a84983fbe5c9b4e4c0c0182aa1..2a6fb7fef068678bf6baa9cb9c244b743cc634e8 100644 --- a/mindspeed_llm/training/training.py +++ b/mindspeed_llm/training/training.py @@ -340,9 +340,6 @@ def pretrain(train_valid_test_dataset_provider, args = get_args() timers = get_timers() - if args.enable_high_availability: - raise AssertionError("High availability feature do not support core_r0.8.0") - if args.log_progress: append_to_progress_log("Starting job") @@ -391,8 +388,7 @@ def pretrain(train_valid_test_dataset_provider, iteration = 0 if args.do_train and args.train_iters > 0: if args.enable_high_availability: - from mindio_ttp.adaptor import tft_init_controller_processor, tft_register_processor, tft_train - tft_init_controller_processor(enable_tls=False, tls_option_top_path='') + from mindio_ttp.adaptor import tft_register_processor, tft_train tft_register_processor(train_valid_test_dataset_provider, model_provider, model_type) iteration, num_floating_point_operations_so_far = tft_train(train_args, test_data_iterator_list) else: @@ -697,4 +693,16 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, if exit: sys.exit() - return iteration, num_floating_point_operations_so_far \ No newline at end of file + return iteration, num_floating_point_operations_so_far + + +def build_train_valid_test_data_iterators_wrapper(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + res = fn(*args, **kwargs) + from mindio_ttp.adaptor import tft_is_arf_reboot_node + if tft_is_arf_reboot_node(): + get_args().do_train = True + return res + + return wrapper \ No newline at end of file
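
For reference, a minimal sketch (not part of the patch) of how the reworked DistributedDataParallel.__init__ above resolves the effective bucket size. The helper name and the plain scalar arguments are hypothetical; the real code reads them from ddp_config and parallel_state.

# Illustrative sketch only; mirrors the bucket-size logic added in
# mindspeed_llm/core/distributed/distributed_data_parallel.py.
from typing import Optional

def resolve_bucket_size(configured_bucket_size: Optional[int],
                        overlap_grad_reduce: bool,
                        data_parallel_world_size: int,
                        pipeline_parallel_rank: int,
                        disable_bucketing: bool) -> Optional[int]:
    # Default: grow buckets with DP size so NCCL ring-reduce chunks stay
    # bandwidth-bound rather than latency-bound.
    bucket_size = configured_bucket_size
    if bucket_size is None:
        bucket_size = max(40000000, 1000000 * data_parallel_world_size)
    # Without overlapped grad reduce there is no reason to bucket at all.
    if not overlap_grad_reduce:
        bucket_size = None
    # Bucketing is also turned off on non-first pipeline stages (their
    # data-parallel communication is not on the critical path) and when the
    # caller disables it explicitly.
    if pipeline_parallel_rank > 0 or disable_bucketing:
        bucket_size = None
    return bucket_size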
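
Likewise, a sketch of how the same __init__ chooses the dense and expert gradient scaling factors. The function and argument names are illustrative stand-ins for config.calculate_per_token_loss, ddp_config.average_in_collective, and the parallel_state world sizes.

# Illustrative sketch only; mirrors the gradient_scaling_factor selection
# added in distributed_data_parallel_init.
def choose_gradient_scaling_factors(calculate_per_token_loss: bool,
                                    average_in_collective: bool,
                                    data_parallel_world_size: int,
                                    expert_model_parallel_world_size: int):
    """Return (dense_factor, expert_factor) as in the patched DDP init."""
    if calculate_per_token_loss:
        # Per-token loss carries its own normalization, so grads are not pre-scaled.
        return 1.0, 1.0
    if average_in_collective:
        # The collective already averages over the data-parallel group; expert
        # grads still divide out the expert-model-parallel replication.
        return 1.0, 1.0 / expert_model_parallel_world_size
    # Otherwise pre-scale by 1/DP (with context parallel) so the subsequent
    # sum-reduce yields a mean.
    return 1.0 / data_parallel_world_size, 1.0 / data_parallel_world_size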
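
Finally, a sketch of the reduction the patched get_grad_norm_fp32 performs for norm_type == 2, after the multi-tensor kernel has produced the local norm. aggregate_grad_norm and its arguments are hypothetical, with replica_dp_num standing in for the value returned by mindio_ttp.adaptor.ttp_get_replica_dp_num().

# Illustrative sketch only; mirrors the aggregation in
# mindspeed_llm/core/optimizer/clip_grads.py.
import torch

def aggregate_grad_norm(local_sq_norm: torch.Tensor,
                        model_parallel_group: torch.distributed.ProcessGroup,
                        use_distributed_optimizer: bool,
                        replica_dp_num: int = 1) -> float:
    total = local_sq_norm.clone()
    if use_distributed_optimizer:
        # The patched code divides by the replica count before the reduction,
        # presumably so gradient shards held on replica ranks are not counted
        # more than once in the summed squares.
        total = total / replica_dp_num
    torch.distributed.all_reduce(total, op=torch.distributed.ReduceOp.SUM,
                                 group=model_parallel_group)
    return total.item() ** 0.5  # ** (1 / norm_type) with norm_type == 2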