diff --git a/docs/pytorch/features/high_availability.md b/docs/pytorch/features/high_availability.md index 6a521494abc1e772465b2c7a208822f15cc15b29..43a813d55dd0037f684cbbcd336a42b924346ee0 100644 --- a/docs/pytorch/features/high_availability.md +++ b/docs/pytorch/features/high_availability.md @@ -12,6 +12,11 @@ 昇腾芯片支持NPU卡内存发生UCE故障(内存不可修复)的实时检测,检测到UCE故障后,基于优化器状态副本机制并完成故障卡的在线修复并继续训练,减少训练损失。 +### 弹性训练功能 + +在训练过程中发生故障后,在训练集群中没有空闲资源可替换时,基于优化器状态副本机制缩掉部分节点继续训练;当训练集群中有空闲资源可使用时,再基于优化器状态副本机制扩容回原有规模继续训练。 +当前阶段仅支持Data Parallel级别的弹性训练,即按照Data Parallel粒度缩掉部分数据并行域进行扩容或缩容。 + ### 原理说明 megatron原生的分布式优化器数据流及工作原理如下图: @@ -33,20 +38,24 @@ megatron原生的分布式优化器数据流及工作原理如下图: ### 环境准备 -MindIO的功能以whl包的形式提供 +MindIO的功能以whl包的形式提供,其中弹性训练功能依赖mindio_ttp和taskd两个whl包 mindio_ttp下载地址:[MindIO TTP 下载软件包-昇腾社区](https://www.hiascend.com/document/detail/zh/mindx-dl/600/clusterscheduling/ref/mindioacp/mindioacp009.html) +taskd下载地址:[TaskD 下载软件包-昇腾社区](https://www.hiascend.com/document/detail/zh/mindcluster/71RC1/clustersched/dlug/dlug_installation_009.html) + ### 启动脚本中添加启动参数 `--enable-high-availability` # 使能开启高可用功能的总开关,并使能TTP临终遗言功能,保存checkpoint时要求全局至少存在一份完整的优化器数据; `--enable-hbmfault-repair` # 使能进行片上内存故障,Step级重计算功能的开关;本功能将在线进行worker级修复,修复时要求全局至少存在一个故障卡的副本卡。 -`--enable-worker-reboot` # 使能空中加油功能,配合支持相关功能的 MindX DL 组件共同使能后,在发生一般性故障时,进行进程级重启修复,继续训练。本功能会将故障卡所在节点进行重启,修复时要求未故障节点中至少存在一份完整的优化器数据。 +`--enable-worker-reboot` # 使能空中加油功能,配合支持相关功能的 Mind Cluster 组件共同使能后,在发生一般性故障时,进行进程级重启修复,继续训练。本功能会将故障卡所在节点进行重启,修复时要求未故障节点中至少存在一份完整的优化器数据。 `--distributed-optimizer-no-replica` # 不使用副本优化器而使用CKPT文件进行重计算和空中加油修复,需要在故障时存在CKPT文件。 +`--enable-elastic-training` # 使能弹性训练功能,配合支持相关功能的 Mind Cluster组件共同使能后,在发生一般性故障且无空闲芯片资源时,缩掉部分节点后继续训练,待有可用芯片资源时扩容回原有规模继续训练。本功能会将故障卡所在Data Parallel域对应节点剔除,修复时要求未故障节点中至少存在一份完整的优化器数据。 + ### 启动脚本中添加环境变量 为避免在结合mindx使用时需配置多个组件的开关,添加环境变量,环境变量优先级高于args,设置环境变量会被优先使用。 @@ -57,8 +66,21 @@ mindio_ttp下载地址:[MindIO TTP 下载软件包-昇腾社区](https://www.h `export HIGH_AVAILABILITY=recover` 启用 `--enable-high-availability` `--enable-worker-reboot` +`export HIGH_AVAILABILITY=elastic-training` 启用 `--enable-high-availability` `--enable-elastic-training` + ## 使用约束 由于原理限制,为了保证故障发生后,有完整的优化器状态数据,需要在ptd切分时保障Data Parallel Size大于1,在使用MoE特性时还要求稠密层与稀疏层的Data Parallel Size均大于1,在使用长序列并行特性时还要求dp_cp_size大于1。 +### 弹性训练功能使用约束 +除上述使用约束外,针对弹性训练功能还需遵守以下使用约束: + +1、当前仅支持开启enable-high-availability、use-distributed-optimizer + +2、当前仅支持不开启use-custom-fsdp、reuse-fp32-param的场景 + +3、当前仅支持Data Parallel、Tensor Parallel、Pipeline Parallel并行 + +4、当前缩容后不可再次缩容,扩容仅支持直接扩容回原有规模 + 详见:[MindIO TTP 约束限制-昇腾社区](https://www.hiascend.com/document/detail/zh/mindx-dl/600/clusterscheduling/ref/mindiottp/mindiotft005.html) \ No newline at end of file diff --git a/mindspeed_llm/core/distributed/param_and_grad_buffer.py b/mindspeed_llm/core/distributed/param_and_grad_buffer.py index cb548c77d1a4274f79b7406d71cbab2a6d5a159b..f4e5ca8d1f7818c9b1dbb60a8687f9ab4dbe2700 100644 --- a/mindspeed_llm/core/distributed/param_and_grad_buffer.py +++ b/mindspeed_llm/core/distributed/param_and_grad_buffer.py @@ -7,22 +7,49 @@ from megatron.training import get_args from megatron.core.distributed.param_and_grad_buffer import (shard_buffer, dist_all_gather_func) - def start_grad_sync_wrapper(fn): @wraps(fn) def wrapper(self, *args, **kwargs): self.ddp_config.use_distributed_optimizer, use_distributed_optimizer_tmp = False, self.ddp_config.use_distributed_optimizer + gradient_scaling_factors = [] + arguments = get_args() + for bucket in self.buckets: + gradient_scaling_factors.append(bucket.gradient_scaling_factor) try: if 
use_distributed_optimizer_tmp: self.data_parallel_group = self.intra_distributed_optimizer_instance_group + if arguments.enable_elastic_training: + # let gradient_scaling_factor be divided by num_micro_batches more, + # because it wasn't divided during the loss calculation in the forward_step function. + from taskd.python.adaptor.elastic_training import common + if common.zit_scale_in_running_state(): + for bucket in self.buckets: + bucket.gradient_scaling_factor = 1.0 / ( + arguments.global_batch_size / arguments.micro_batch_size) fn(self, *args, **kwargs) finally: if use_distributed_optimizer_tmp: self.data_parallel_group = None self.ddp_config.use_distributed_optimizer = use_distributed_optimizer_tmp + if arguments.enable_elastic_training: + recover_gradient_scaling_factors(self, gradient_scaling_factors) return wrapper +def recover_gradient_scaling_factors(self, gradient_scaling_factors): + """ + Restore the modified parameter 'gradient_scaling_factor'. + """ + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return + index = 0 + for bucket in self.buckets: + if index < len(gradient_scaling_factors): + bucket.gradient_scaling_factor = gradient_scaling_factors[index] + index += 1 + + def start_param_sync(self, force_sync: bool = False): assert self.ddp_config.use_distributed_optimizer assert self.intra_distributed_optimizer_instance_group_for_tft @@ -36,6 +63,40 @@ def start_param_sync(self, force_sync: bool = False): assert self.param_gather_handle is None async_op = self.ddp_config.overlap_param_gather and not force_sync + deal_param_gather_handle_default(self, async_op) + arguments = get_args() + if arguments.enable_elastic_training: + deal_param_gather_handle_scale_in_running(self, async_op) + self.param_gather_dispatched = True + + +def deal_param_gather_handle_scale_in_running(self, async_op): + """ + In scale-in training state, the replica ranks of fault ranks need to do an addition gather operation. + """ + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return + if not common.zit_fault_rank_in_dp_cp_replica_group() and common.zit_is_fault_replica_rank(): + instance_group = common.SCALE_IN_DP_CP_REPLICA_GROUP + instance_rank = torch.distributed.get_rank( + group=instance_group + ) + instance_size = torch.distributed.get_world_size( + group=instance_group) + for bucket in self.buckets: + local_data_view = shard_buffer( + bucket.param_data, instance_size + )[instance_rank] + dist_all_gather_func( + bucket.param_data, + local_data_view, + group=instance_group, + async_op=async_op, + ) + + +def deal_param_gather_handle_default(self, async_op): self.param_gather_handle = [] # Coalesce communication kernels across buckets in the bucket group. 
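+        # Each rank owns a 1/instance_size shard of every bucket's param_data: shard_buffer
+        # slices the buffer into views and the loop below all-gathers each rank's view so the
+        # whole intra-distributed-optimizer instance group ends up with the full parameters.
+        # The gather is asynchronous only when overlap_param_gather is on and force_sync is
+        # off. During scale-in, deal_param_gather_handle_scale_in_running above additionally
+        # repeats the gather over the scale-in dp/cp replica group on the replica ranks of
+        # the faulted ranks, so their parameters stay consistent as well.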
instance_group = self.intra_distributed_optimizer_instance_group_for_tft() @@ -43,7 +104,7 @@ def start_param_sync(self, force_sync: bool = False): group=instance_group ) instance_size = torch.distributed.get_world_size( - group=instance_group) + group=instance_group) for bucket in self.buckets: local_data_view = shard_buffer( bucket.param_data, instance_size @@ -58,7 +119,6 @@ def start_param_sync(self, force_sync: bool = False): if not async_op: self.param_gather_handle = None - self.param_gather_dispatched = True def param_and_grad_bucket_group_init_wrapper(fn): diff --git a/mindspeed_llm/core/optimizer/__init__.py b/mindspeed_llm/core/optimizer/__init__.py index d86d6e77261e077ea6db6872b1804c2dd9912cec..6b62e3848bc1386171ef10965f997f6ffa5188df 100644 --- a/mindspeed_llm/core/optimizer/__init__.py +++ b/mindspeed_llm/core/optimizer/__init__.py @@ -184,7 +184,20 @@ def get_megatron_optimizer_based_on_param_groups( ] from mindio_ttp.adaptor import TTPReplicaOptimizer, TTPFP16ReplicaOptimizer - if config.use_distributed_optimizer: + args = get_args() + if config.use_distributed_optimizer and args.enable_elastic_training: + from taskd.python.adaptor.elastic_training.optimizer import TTPElasticTrainingReplicaOptimizer + optimizer = TTPElasticTrainingReplicaOptimizer( + *optimizer_args, + model_chunks=model_chunks, + per_model_buffers=per_model_buffers, + data_parallel_group=data_parallel_group, + data_parallel_group_gloo=data_parallel_group_gloo, + data_parallel_group_idx=data_parallel_group_idx, + distributed_optimizer_instance_id=distributed_optimizer_instance_id, + ori_dp_group=ori_dp_group + ) + elif config.use_distributed_optimizer: optimizer = TTPReplicaOptimizer( *optimizer_args, model_chunks=model_chunks, diff --git a/mindspeed_llm/core/optimizer/clip_grads.py b/mindspeed_llm/core/optimizer/clip_grads.py index 4b9f2f14ddc30d859a417a63cb656fe7fc60f198..f36bb36de41d378f911049447cfaf749f4eda1b0 100644 --- a/mindspeed_llm/core/optimizer/clip_grads.py +++ b/mindspeed_llm/core/optimizer/clip_grads.py @@ -2,7 +2,10 @@ # Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. 
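+# Note on the correction applied in get_grad_norm_fp32_default below: with the replica
+# optimizer every gradient shard is assumed to be counted ttp_get_replica_dp_num() times
+# in the global reduction, so the raw p-norm comes back as
+#     (replica_num * sum_i n_i**p) ** (1/p) = replica_num ** (1/p) * true_norm,
+# i.e. inflated by replica_num ** (1/norm_type); dividing by that factor restores the
+# true norm (for the L2 norm and two replicas the raw value is sqrt(2) times too large).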
from functools import wraps + +import torch from megatron.training import get_args +from megatron.core import mpu def get_grad_norm_fp32_wrapper(fn): @@ -10,11 +13,62 @@ def get_grad_norm_fp32_wrapper(fn): def wrapper(*args, **kwargs): argument = get_args() if argument.use_distributed_optimizer: - from mindio_ttp.adaptor import ttp_get_replica_dp_num - norm_type = kwargs.get('norm_type', 2.0) - if len(args) > 1: - norm_type = float(args[1]) - return fn(*args, **kwargs) / (ttp_get_replica_dp_num() ** (1.0 / norm_type)) + return get_grad_norm_fp32(fn, *args, **kwargs) else: return fn(*args, **kwargs) - return wrapper \ No newline at end of file + return wrapper + + +def get_grad_norm_fp32(fn, *args, **kwargs): + try: + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return get_grad_norm_fp32_default(fn, *args, **kwargs) + return get_grad_norm_fp32_scale_in_running(fn, *args, **kwargs) + except ImportError: + return get_grad_norm_fp32_default(fn, *args, **kwargs) + + +def get_grad_norm_fp32_default(fn, *args, **kwargs): + from mindio_ttp.adaptor import ttp_get_replica_dp_num + norm_type = kwargs.get('norm_type', 2.0) + if len(args) > 1: + norm_type = float(args[1]) + return fn(*args, **kwargs) / (ttp_get_replica_dp_num() ** (1.0 / norm_type)) + + +def get_grad_norm_fp32_scale_in_running(fn, *args, **kwargs): + """ + In the context of scale-in training scenarios, change the way of get_grad_norm_fp32 result. + First, do all-reduce in the model parallel group. + Then do all-reduce in the data parallel and context parallel replica group. + """ + norm_type = kwargs.get('norm_type', 2.0) + if len(args) > 1: + norm_type = float(args[1]) + grad_stats_parallel_group_arg_index = 2 + new_args = args + # change teh all reduce group to the model parallel group + if len(args) > grad_stats_parallel_group_arg_index and args[grad_stats_parallel_group_arg_index] is None: + args_list = list(args) + args_list[grad_stats_parallel_group_arg_index] = mpu.get_model_parallel_group() + new_args = tuple(args_list) + elif len(args) <= grad_stats_parallel_group_arg_index and kwargs.get('grad_stats_parallel_group', None) is None: + kwargs['grad_stats_parallel_group'] = mpu.get_model_parallel_group() + # Get the result of summation within the model parallel group first. + # Then perform an all-reduce operation within the data_parallel_and_context_parallel_replica group to obtain + # the world group sum of the native function. 
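+    # p-norms combine by summing p-th powers: the group-local result below is raised to
+    # norm_type, the powers are summed with an all-reduce over the dp/cp replica group
+    # (or the scale-in replica group for ranks that share the fault), and the final
+    # (1.0 / norm_type) root gives total_norm = (sum_g norm_g ** p) ** (1 / p).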
+ total_norm = fn(*new_args, **kwargs) ** norm_type + total_norm_tensor = torch.tensor([float(total_norm)], dtype=torch.float, device='cuda') + replica_total_norm_tensor = total_norm_tensor.clone() + from mindio_ttp.adaptor import ttp_get_dp_cp_replica_group + from taskd.python.adaptor.elastic_training import common + group = ttp_get_dp_cp_replica_group() + if common.zit_fault_rank_in_dp_cp_replica_group(): + group = common.zit_get_scale_in_dp_cp_replica_group() + torch.distributed.all_reduce(total_norm_tensor, op=torch.distributed.ReduceOp.SUM, group=group) + if not common.zit_fault_rank_in_dp_cp_replica_group() and common.zit_is_fault_replica_rank(): + total_norm_tensor = replica_total_norm_tensor + torch.distributed.all_reduce(total_norm_tensor, op=torch.distributed.ReduceOp.SUM, + group=common.zit_get_scale_in_dp_cp_replica_group()) + return total_norm_tensor.item() ** (1.0 / norm_type) \ No newline at end of file diff --git a/mindspeed_llm/core/optimizer/distrib_optimizer.py b/mindspeed_llm/core/optimizer/distrib_optimizer.py index bf7fbfd8d5ad64538613632c113c73ab52385f29..4c86662f1212132c5af7765b78f927246a4168c9 100644 --- a/mindspeed_llm/core/optimizer/distrib_optimizer.py +++ b/mindspeed_llm/core/optimizer/distrib_optimizer.py @@ -211,3 +211,25 @@ def get_parameter_state_dp_zero_with_high_availability_wrapper(func): state['shard_main_param_res'] = buffer_res_full_shard return state return wrapper + + +def get_parameter_state_dp_zero_wrapper(fn): + """ + In the context of scale-in training scenarios, have the replica rank with the fault perform + an addition gather operation. + """ + @wraps(fn) + def wrapper(self): + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return fn(self) + state = None + if not common.zit_fault_rank_in_dp_cp_replica_group(): + state = fn(self) + if common.zit_fault_rank_in_dp_cp_replica_group() or common.zit_is_fault_replica_rank(): + dp_group_gloo = self.data_parallel_group_gloo + self.data_parallel_group_gloo = common.zit_get_scale_in_dp_cp_replica_group_gloo() + state = fn(self) + self.data_parallel_group_gloo = dp_group_gloo + return state + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/core/optimizer_param_scheduler.py b/mindspeed_llm/core/optimizer_param_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..52dc59ea41179652ed202da957a6c1f722a31957 --- /dev/null +++ b/mindspeed_llm/core/optimizer_param_scheduler.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. + +from functools import wraps + + +def optimizer_param_scheduler_step_wrapper(fn): + """ + In the context of scale-in training scenarios, change the parameter 'increment' + to get_args().global_batch_size. Because every data parallel's num_micro_bathes + may be different. 
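+    For example (illustrative numbers): with global_batch_size=64 and micro_batch_size=4,
+    the 16 micro batches are no longer split evenly across the remaining data parallel
+    replicas after scale-in, so the locally derived increment can undercount; stepping by
+    the full global_batch_size keeps the learning-rate schedule aligned with the number
+    of samples actually consumed per iteration.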
+ """ + @wraps(fn) + def wrapper(self, increment: int): + from taskd.python.adaptor.elastic_training import common + if common.zit_scale_in_running_state(): + from megatron.training import get_args + increment = get_args().global_batch_size + return fn(self, increment) + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/core/pipeline_parallel/schedules.py b/mindspeed_llm/core/pipeline_parallel/schedules.py index f55ed69c165a79db5f074cbebaf8709eeda8dded..28ef8039ca4033d29516a14dae43a7a1d4b9e765 100644 --- a/mindspeed_llm/core/pipeline_parallel/schedules.py +++ b/mindspeed_llm/core/pipeline_parallel/schedules.py @@ -63,4 +63,70 @@ def forward_backward_pipelining_with_interleaving_wrapper(fn): if args_.virtual_pipeline_model_parallel_size is not None and args_.stage == "orm": kwargs['micro_batch_size'] = args_.micro_batch_size * 2 return fn(*args, **kwargs) + return wrapper + + +def forward_step_wrapper(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + """ + In the context of a scaling-in operation, modify the input parameter num_microbatches to 1. + The purpose of this modification is to ensure that during the loss calculation within this function, + averaging across the num_microbatches dimension is not performed. Instead, averaging will be uniformly + applied across the data_parallel_size*num_microbatches dimensions at the final stage. + """ + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return fn(*args, **kwargs) + new_args = args + num_microbatches_index = 3 + if len(args) >= num_microbatches_index + 1: + args_list = list(args) + args_list[num_microbatches_index] = 1 + new_args = tuple(args_list) + else: + kwargs['num_microbatches'] = 1 + return fn(*new_args, **kwargs) + return wrapper + + +def elastic_training_get_forward_backward_func_wrapper(fn): + """ + In the context of scale-in training scenarios, perform an all-reduce operation on the sum + of the 'lm loss' values for all micro batches within the data parallel and context parallel + replica group. Because it wasn't done in the 'loss_func' function. 
+ """ + @wraps(fn) + def wrapper(): + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return fn() + forward_backward_func = fn() + + def scale_in_forward_backward_func(*args, **kwargs): + losses_reduced = forward_backward_func(*args, **kwargs) + from megatron.core import mpu + if not mpu.is_pipeline_last_stage(ignore_virtual=True): + return losses_reduced + new_losses_reduced = [] + loss_reduced = {} + for key in losses_reduced[0].keys(): + numerator = 0 + denominator = 0 + for x in losses_reduced: + val = x[key] + if isinstance(val, tuple) or isinstance(val, list): + numerator += val[0] + denominator += val[1] + else: + numerator += val + denominator += 1 + value_tensor = torch.tensor([numerator, denominator], device="cuda") + torch.distributed.all_reduce(value_tensor, group=mpu.get_data_parallel_group()) + loss_reduced[key] = (value_tensor[0].item(), value_tensor[1].item()) + new_losses_reduced.append(loss_reduced) + return new_losses_reduced + + return scale_in_forward_backward_func + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/core/timers.py b/mindspeed_llm/core/timers.py new file mode 100644 index 0000000000000000000000000000000000000000..7be12597ec18cefb29f40e75a66acc33fef36b01 --- /dev/null +++ b/mindspeed_llm/core/timers.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. + +from functools import wraps + +import torch + + +def patch_world_size_func_wrapper(fn): + """ + In the context of scale-in training scenarios, change the world group to scale-in world group when + get world size and rank. + """ + @wraps(fn) + def wrapper(*args, **kwargs): + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return fn(*args, **kwargs) + get_world_size_func = torch.distributed.get_world_size + get_rank_func = torch.distributed.get_rank + + def patch_get_world_size(group=None): + if group is not None or not common.zit_scale_in_running_state(): + return get_world_size_func(group) + return get_world_size_func(group=common.zit_get_scale_in_world_group()) + + def patch_get_rank(group=None): + if group is not None or not common.zit_scale_in_running_state(): + return get_rank_func(group) + return get_rank_func(group=common.zit_get_scale_in_world_group()) + + torch.distributed.get_world_size = patch_get_world_size + torch.distributed.get_rank = patch_get_rank + result = fn(*args, **kwargs) + torch.distributed.get_world_size = get_world_size_func + torch.distributed.get_rank = get_rank_func + return result + + return wrapper + + +def log_wrapper(fn): + """ + In the context of scale-in training scenarios, change the parameter 'rank' + to the last rank of scale-in world group when the rank passed in is 'None'. 
+ """ + @wraps(fn) + def wrapper(*args, **kwargs): + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return fn(*args, **kwargs) + rank_index = 2 + scale_in_word_ranks = torch.distributed.get_process_group_ranks( + group=common.zit_get_scale_in_world_group()) + need_change_rank, change_str = is_need_change_rank(*args, **kwargs) + if need_change_rank and change_str == 'args': + args_list = list(args) + args_list[rank_index] = scale_in_word_ranks[-1] + new_args = tuple(args_list) + return fn(*new_args, **kwargs) + elif need_change_rank and change_str == 'kwargs': + kwargs['rank'] = scale_in_word_ranks[-1] + return fn(*args, **kwargs) + return fn(*args, **kwargs) + return wrapper + + +def is_need_change_rank(*args, **kwargs): + """ + Check whether the parameter 'rank' passed in is 'None'. + """ + rank_index = 2 + if len(args) <= rank_index and kwargs.get('group', None) is None: + return True, 'kwargs' + if len(args) > rank_index and args[rank_index] is None: + return True, 'args' + return False, "" \ No newline at end of file diff --git a/mindspeed_llm/features_manager/high_availability/communication_patch.py b/mindspeed_llm/features_manager/high_availability/communication_patch.py index 05e8677ef7c217de7b73ba55ff41072350b161f9..73134fbdad19e71dd3057df9fd2566b50fff02a3 100644 --- a/mindspeed_llm/features_manager/high_availability/communication_patch.py +++ b/mindspeed_llm/features_manager/high_availability/communication_patch.py @@ -7,10 +7,29 @@ def communication_wrapper(fn): @wraps(fn) def wrapper(*args, **kwargs): from megatron.training import get_args - if get_args().enable_high_availability: + arguments = get_args() + if arguments.enable_high_availability: from mindio_ttp.adaptor import tft_is_arf_reboot_node if tft_is_arf_reboot_node(): return None + if arguments.enable_elastic_training: + group_index = 2 + return torch_wrapper(fn, group_index, *args, **kwargs) + return fn(*args, **kwargs) + return wrapper + + +def barrier_wrapper(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + from megatron.training import get_args + arguments = get_args() + if arguments.enable_high_availability: + from mindio_ttp.adaptor import tft_is_arf_reboot_node + if tft_is_arf_reboot_node(): + return None + if arguments.enable_elastic_training: + return torch_wrapper(fn, 0, *args, **kwargs) return fn(*args, **kwargs) return wrapper @@ -28,3 +47,85 @@ def new_group_wrapper(fn): res = fn(*args, **kwargs) return res return wrapper + + +def is_need_change_group(group_index, *args, **kwargs): + """ + Check whether the 'group' parameter passed in is 'None' to determine if the value of 'group' + parameter needs to be changed in the scenario of scale-in training, and whether to modify 'args' + or 'kwargs'. + """ + if group_index < 0: + return False, "" + if len(args) <= group_index and kwargs.get('group', None) is None: + return True, 'kwargs' + if len(args) > group_index and args[group_index] is None: + return True, 'args' + return False, "" + + +def group_index_two_torch_wrapper(fn): + """ + In the context of scale-in training scenarios, if the 'group' parameter passed in is 'None', + change it to the scale-in world group. 
+ """ + @wraps(fn) + def wrapper(*args, **kwargs): + from megatron.training import get_args + if not get_args().enable_elastic_training: + return fn(*args, **kwargs) + group_index = 2 + return torch_wrapper(fn, group_index, *args, **kwargs) + return wrapper + + +def group_index_three_torch_wrapper(fn): + """ + In the context of scale-in training scenarios, if the 'group' parameter passed in is 'None', + change it to the scale-in world group. + """ + @wraps(fn) + def wrapper(*args, **kwargs): + from megatron.training import get_args + if not get_args().enable_elastic_training: + return fn(*args, **kwargs) + group_index = 3 + return torch_wrapper(fn, group_index, *args, **kwargs) + return wrapper + + +def all_to_all_single_wrapper(fn): + """ + In the context of scale-in training scenarios, if the 'group' parameter passed in is 'None', + change it to the scale-in world group. + """ + @wraps(fn) + def wrapper(*args, **kwargs): + from megatron.training import get_args + if not get_args().enable_elastic_training: + return fn(*args, **kwargs) + group_index = 4 + return torch_wrapper(fn, group_index, *args, **kwargs) + return wrapper + + +def torch_wrapper(fn, group_index, *args, **kwargs): + """ + In the context of scale-in training scenarios, if the 'group' parameter passed in is 'None', + change it to the scale-in world group. + """ + from mindio_ttp.adaptor.tft_arf_group_repair import tft_is_arf_reboot_node + from taskd.python.adaptor.elastic_training import common + if tft_is_arf_reboot_node(): + return None + if common.zit_scale_in_running_state(): + need_change_group, change_str = is_need_change_group(group_index, *args, **kwargs) + if need_change_group and change_str == 'args': + args_list = list(args) + args_list[group_index] = common.zit_get_scale_in_world_group() + new_args = tuple(args_list) + return fn(*new_args, **kwargs) + if need_change_group and change_str == 'kwargs': + kwargs['group'] = common.zit_get_scale_in_world_group() + return fn(*args, **kwargs) + return fn(*args, **kwargs) diff --git a/mindspeed_llm/features_manager/high_availability/high_availability.py b/mindspeed_llm/features_manager/high_availability/high_availability.py index c1a73692ac2da98651fa40bb5d6b0a4547662a48..a6bd80fc19e279c3393b195aed738dbfc97a6503 100644 --- a/mindspeed_llm/features_manager/high_availability/high_availability.py +++ b/mindspeed_llm/features_manager/high_availability/high_availability.py @@ -17,6 +17,8 @@ class HighAvailabilityFeature(MindSpeedFeature): help='high availability feature, enable worker reboot') group.add_argument('--distributed-optimizer-no-replica', action='store_true', help='high availability feature, repair from ckpt and disable replica optimizer') + group.add_argument('--enable-elastic-training', action='store_true', + help='high availability feature, enable elastic training') def pre_validate_args(self, args): from .high_availability_helper import get_env_args @@ -36,13 +38,50 @@ class HighAvailabilityFeature(MindSpeedFeature): raise AssertionError('switch of the high availability feature is unsupported') if args.swap_optimizer and args.enable_high_availability: raise AssertionError('switch of the high availability feature is unsupported') + if args.enable_elastic_training: + try: + import taskd.python.adaptor.elastic_training + except ModuleNotFoundError as e: + raise AssertionError( + f"enable elastic training requires the taskd.python.adaptor.elastic_training package" + f" but it is not installed.") from e + if args.enable_elastic_training and not 
args.enable_high_availability: + raise AssertionError( + 'switch of the enable elastic training is unsupported, please enable high availability feature first.') + if args.enable_elastic_training and not args.use_distributed_optimizer: + raise AssertionError( + 'switch of the enable elastic training is unsupported, please enable use-distributed-optimizer first.') + if args.enable_elastic_training and args.use_custom_fsdp: + raise AssertionError( + 'switch of the enable elastic training is unsupported when reuse-fp32-param is enabled.') + if args.enable_elastic_training and args.reuse_fp32_param: + raise AssertionError( + 'switch of the enable elastic training is unsupported when reuse-fp32-param is enabled.') + if args.enable_elastic_training and (args.expert_model_parallel_size > 1 or args.context_parallel_size > 1): + raise AssertionError( + 'switch of the enable elastic training is unsupported when expert-model-parallel-size, context ' + 'parallel size is set.') def pre_register_patches(self, patch_manager, args): - from .communication_patch import communication_wrapper + from .communication_patch import communication_wrapper, barrier_wrapper from .high_availability_helper import skip_reuse_register_patches - for communication in ['barrier', 'all_reduce', '_all_gather_base', 'broadcast', 'all_gather_into_tensor']: + patch_manager.register_patch('torch.distributed.barrier', + barrier_wrapper) + for communication in ['all_reduce', '_all_gather_base', 'broadcast', 'all_gather_into_tensor']: patch_manager.register_patch('torch.distributed.distributed_c10d.' + communication, communication_wrapper) + from .communication_patch import (group_index_two_torch_wrapper, + group_index_three_torch_wrapper, all_to_all_single_wrapper) + patch_manager.register_patch('torch.distributed.all_to_all_single', + all_to_all_single_wrapper) + for communication in ['all_gather', 'all_to_all', 'all_reduce_coalesced', 'all_gather_object', + 'broadcast_object_list', 'all_gather_coalesced']: + patch_manager.register_patch('torch.distributed.' + communication, + group_index_two_torch_wrapper) + for communication in ['gather', 'scatter', 'reduce', 'reduce_scatter', 'gather_object', + 'scatter_object_list', 'reduce_scatter_tensor', '_reduce_scatter_base']: + patch_manager.register_patch('torch.distributed.' 
+ communication, + group_index_three_torch_wrapper) from mindspeed.features_manager import ReuseFP32Param ReuseFP32Param.register_patches = skip_reuse_register_patches(ReuseFP32Param.register_patches, args) @@ -87,10 +126,48 @@ class HighAvailabilityFeature(MindSpeedFeature): distributed_optimizer_init_for_reuse_fp32_wrapper) patch_manager.register_patch('mindio_ttp.adaptor.TTPReplicaOptimizer.get_parameter_state_dp_zero_for_ttp', get_parameter_state_dp_zero_with_high_availability_wrapper) - if args.enable_worker_reboot: + if args.enable_worker_reboot or args.enable_elastic_training: from .initialize_patch import build_train_valid_test_data_iterators_wrapper from mindspeed_llm.features_manager.high_availability.communication_patch import new_group_wrapper patch_manager.register_patch('megatron.training.training.build_train_valid_test_data_iterators', build_train_valid_test_data_iterators_wrapper) patch_manager.register_patch('torch.distributed.distributed_c10d.new_group', new_group_wrapper) + if args.enable_elastic_training: + from mindspeed_llm.core.pipeline_parallel.schedules import forward_step_wrapper + from mindspeed_llm.core.optimizer.distrib_optimizer import get_parameter_state_dp_zero_wrapper + from mindspeed_llm.core.timers import patch_world_size_func_wrapper, log_wrapper + from mindspeed_llm.training.utils import is_last_rank_wrapper, print_rank_last_wrapper + from mindspeed_llm.core.optimizer_param_scheduler import optimizer_param_scheduler_step_wrapper + from mindspeed_llm.core.pipeline_parallel.schedules import ( + elastic_training_get_forward_backward_func_wrapper) + from mindspeed_llm.training.training import num_floating_point_operations_wrapper + from mindspeed_llm.training.one_logger_utils import track_app_tag_wrapper + patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_step', + forward_step_wrapper) + patch_manager.register_patch( + 'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.get_parameter_state_dp_zero', + get_parameter_state_dp_zero_wrapper) + patch_manager.register_patch('megatron.core.timers.Timers._get_elapsed_time_all_ranks', + patch_world_size_func_wrapper) + patch_manager.register_patch('megatron.core.timers.Timers._get_all_ranks_time_string', + patch_world_size_func_wrapper) + patch_manager.register_patch('megatron.core.timers.Timers.log', + log_wrapper) + patch_manager.register_patch('megatron.training.utils.is_last_rank', + is_last_rank_wrapper) + patch_manager.register_patch('megatron.core.optimizer_param_scheduler.OptimizerParamScheduler.step', + optimizer_param_scheduler_step_wrapper) + patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_forward_backward_func', + elastic_training_get_forward_backward_func_wrapper) + patch_manager.register_patch('megatron.training.one_logger_utils.track_app_tag', + track_app_tag_wrapper) + patch_manager.register_patch('megatron.training.training.num_floating_point_operations', + num_floating_point_operations_wrapper) + patch_manager.register_patch('megatron.training.utils.print_rank_last', + print_rank_last_wrapper) + + + + + diff --git a/mindspeed_llm/features_manager/high_availability/high_availability_helper.py b/mindspeed_llm/features_manager/high_availability/high_availability_helper.py index 45799039513df19625528806d4ba4e3919fb2367..5a8d4e2b991588f5a9e7adc34e8001550ae8edfe 100644 --- a/mindspeed_llm/features_manager/high_availability/high_availability_helper.py +++ b/mindspeed_llm/features_manager/high_availability/high_availability_helper.py @@ 
-8,7 +8,7 @@ def get_env_args(args): if not env: return args for strategy in env.split(','): - if strategy.lower() in ('dump', 'recover', 'retry'): + if strategy.lower() in ('dump', 'recover', 'retry', 'elastic-training'): if not getattr(args, 'enable_high_availability', False): warnings.warn( "HIGH_AVAILABILITY environment variables enabled and args.enable_high_availability inactive" @@ -18,6 +18,8 @@ def get_env_args(args): args.enable_worker_reboot = True if strategy.lower() == 'retry': args.enable_hbmfault_repair = True + if strategy.lower() == 'elastic-training': + args.enable_elastic_training = True return args diff --git a/mindspeed_llm/training/one_logger_utils.py b/mindspeed_llm/training/one_logger_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dc64eb841604294da7c7661695bc392ea2c55ea3 --- /dev/null +++ b/mindspeed_llm/training/one_logger_utils.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. + +from functools import wraps + + +def track_app_tag_wrapper(fn): + """ + In the context of scale-in training scenarios, change the parameter 'batch_size' + to get_args().global_batch_size. Because every data parallel's num_micro_bathes + may be different. + """ + @wraps(fn) + def wrapper(batch_size, world_size, seq_length): + from taskd.python.adaptor.elastic_training import common + if common.zit_scale_in_running_state(): + from megatron.training import get_args + batch_size = get_args().global_batch_size + return fn(batch_size, world_size, seq_length) + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/training/training.py b/mindspeed_llm/training/training.py index f2f55a599bc8e54630a68bc7903cf5c541ca8d70..f20979f59bdb832f2b998d49b21e6fe15b869199 100644 --- a/mindspeed_llm/training/training.py +++ b/mindspeed_llm/training/training.py @@ -442,6 +442,9 @@ def pretrain(train_valid_test_dataset_provider, if args.enable_high_availability: from mindio_ttp.adaptor import tft_register_processor, tft_train tft_register_processor(train_valid_test_dataset_provider, model_provider, model_type) + if args.enable_elastic_training: + from taskd.python.adaptor.elastic_training import register_callbacks + register_callbacks() iteration, num_floating_point_operations_so_far = tft_train(train_args, test_data_iterator_list) else: iteration, num_floating_point_operations_so_far = train(*train_args) @@ -787,4 +790,18 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, def should_disable_forward_pre_hook(args): """Block forward pre-hook for certain configurations.""" - return not args.use_custom_fsdp and args.use_distributed_optimizer and args.overlap_param_gather \ No newline at end of file + return not args.use_custom_fsdp and args.use_distributed_optimizer and args.overlap_param_gather + + +def num_floating_point_operations_wrapper(fn): + """ + In the context of scale-in training scenarios, change the parameter 'batch_size' + to 'get_args().global_batch_size'. 
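+    Illustrative numbers: if the shrunken cluster locally derives batch_size as
+    micro_batch_size * num_microbatches * data_parallel_size = 48 while
+    get_args().global_batch_size is 64, the per-step FLOPs accounting uses 64, matching
+    the number of samples actually processed across the remaining replicas.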
+ """ + @wraps(fn) + def wrapper(args, batch_size): + from taskd.python.adaptor.elastic_training import common + if common.zit_scale_in_running_state(): + batch_size = get_args().global_batch_size + return fn(args, batch_size) + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py index d751b3abf3e92e9eae3366a54acfff714945f6c5..d3cbe1db9ec055099e62ac5c5145a6c359dbf2df 100644 --- a/mindspeed_llm/training/utils.py +++ b/mindspeed_llm/training/utils.py @@ -774,3 +774,40 @@ def _get_batch_on_this_cp_rank_in_ulysses_cp(batch): batch[key] = val return batch + + +def is_last_rank_wrapper(fn): + @wraps(fn) + def wrapper(): + """ + In the context of scale-in training scenarios, use the scale-in world group to determine + if it is the last rank. + """ + from taskd.python.adaptor.elastic_training import common + if not common.zit_scale_in_running_state(): + return fn() + else: + return torch.distributed.get_rank() == torch.distributed.get_process_group_ranks( + group=common.zit_get_scale_in_world_group())[-1] + return wrapper + + +def print_rank_last_wrapper(fn): + @wraps(fn) + def wrapper(message): + """ + In the context of scale-in training scenarios, use the get_args().global_batch_size to + replace the batch_size. + """ + from taskd.python.adaptor.elastic_training import common + if common.zit_scale_in_running_state(): + args = get_args() + from megatron.core.num_microbatches_calculator import get_num_microbatches + batch_size = args.micro_batch_size * args.data_parallel_size * \ + get_num_microbatches() + src_str = f' global batch size: {batch_size:5d} |' + batch_size = get_args().global_batch_size + dest_str = f' global batch size: {batch_size:5d} |' + message = message.replace(src_str, dest_str) + return fn(message) + return wrapper \ No newline at end of file diff --git a/pretrain_gpt.py b/pretrain_gpt.py index d687a2607914d274c43f137ed7f5d8457e473a43..a54ca3436e27b94290905b6caa9a97a9d61478b8 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -196,7 +196,12 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): ) # Reduce loss for logging. reporting_loss = loss.clone().detach() - torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + try: + from taskd.python.adaptor.elastic_training import common + if not args.enable_elastic_training or not common.zit_scale_in_running_state(): + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + except ImportError: + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) # loss[0] is a view of loss, so it has ._base not None, which triggers assert error # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()