diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
index 4e55d38c30cd31e0cc0f0463cf1ae046c4ce4280..9307f72a7079c6c2662c0b24b57e4130affe0bbf 100644
--- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
+++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
@@ -24,6 +24,7 @@ class OptimizerMon(ABC):
 
     def __init__(self) -> None:
         self.fp16_to_fp32_param = {}
+        self.is_stage3 = False
 
     @classmethod
     def set_wrapped_optimizer(cls, wrapped_optimizer):
@@ -77,14 +78,35 @@ class OptimizerMon(ABC):
         update_dict = defaultdict()
         ratio_dict = defaultdict()
         param2name = defaultdict()
+        fp32_partitioned_groups_flat_grad = defaultdict()
         mix_prec_opt = OptimizerMon.wrapped_optimizer
+        partition_id = dist.get_rank()
+
+        def get_flatten_grad(self, optimizer, group_idx):
+            if self.is_stage3 or optimizer.cpu_offload:
+                return fp32_partitioned_groups_flat[group_idx].grad
+            elif fp32_partitioned_groups_flat[group_idx].grad is None:
+                if partition_id == dist.get_world_size() - 1:
+                    fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
+                        optimizer.averaged_gradients[group_idx],
+                        int(optimizer.partition_size[group_idx])
+                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
+                else:
+                    fp32_partitioned_groups_flat_grad = optimizer.flatten(
+                        optimizer.averaged_gradients[group_idx]
+                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
+                return fp32_partitioned_groups_flat_grad
+            else:
+                return fp32_partitioned_groups_flat[group_idx].grad
+        for group_idx in range(len(fp32_partitioned_groups_flat)):
+            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx)
 
         for name in params2name.values():
-            start_idx, end_idx = name2indices[name]
-            if start_idx > end_idx:
+            start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
+            if group_with_rank != partition_id and isinstance(group_with_rank, int):
                 continue
-            fp32_param = fp32_partitioned_groups_flat[start_idx: end_idx]
-            fp32_param.grad = fp32_partitioned_groups_flat.grad[start_idx: end_idx]
+            fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
+            fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
             param2name[fp32_param] = name
             state_param = list(mix_prec_opt.state.values())[0]
             exp_avg = state_param.get("exp_avg", None)
@@ -113,6 +135,7 @@ class OptimizerMon(ABC):
                 ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
                 monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
                 monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+        del fp32_partitioned_groups_flat_grad
         return MV_Grad_Result(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict, grad=param2name)
 
 
@@ -149,60 +172,86 @@ class MegatronFP32OptimizerMon(OptimizerMon):
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
+class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
 class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
     def get_param_index(self, params2name, name2index):
         mix_prec_opt = OptimizerMon.wrapped_optimizer
-        fp16_groups = mix_prec_opt.fp16_partitioned_groups[0]
+        fp16_groups = mix_prec_opt.fp16_partitioned_groups
         name2indices = defaultdict()
         index_length = defaultdict()
         index = 0
-        for idx, param in enumerate(fp16_groups):
-            index_length[idx] = (index, index + len(param))
-            index += len(param)
+        idx = 0
+        for group_idx, fp16_group in enumerate(fp16_groups):
+            for param in fp16_group:
+                param_length = len(param.flatten())
+                index_length[idx] = (index, index + param_length, group_idx)
+                index += param_length
+                idx += 1
         for _, name in params2name.items():
             idx = name2index[name]
-            start_idx, end_idx = index_length[idx]
-            name2indices[name] = (start_idx, end_idx)
+            start_idx, end_idx, group_idx = index_length[idx]
+            name2indices[name] = (start_idx, end_idx, group_idx, None)
         return name2indices
 
     def fetch_mv(self, monitor, torch_opt, params2name, name2indices):
+        self.is_stage3 = True
         mix_prec_opt = OptimizerMon.wrapped_optimizer
-        fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat[0]
+        fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
 
 
 class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
     def get_param_index(self, params2name, name2index):
         mix_prec_opt = OptimizerMon.wrapped_optimizer
-        bf16_groups = mix_prec_opt.bit16_groups[0]
-        flatten_partitioned_fp32_groups = mix_prec_opt.single_partition_of_fp32_groups[0]
-        padding = mix_prec_opt.groups_padding[0]
+        padding = mix_prec_opt.groups_padding
+        world_size = dist.get_world_size()
+        fp32_length = [0]
+        for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups):
+            fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
+
+        def get_group_index(fp32_length, world_size, index):
+            for i in range(len(fp32_length) - 1):
+                if fp32_length[i] <= index < fp32_length[i + 1]:
+                    interval_start = fp32_length[i]
+                    interval_length = fp32_length[i + 1] - fp32_length[i]
+                    sub_interval_length = interval_length // world_size
+                    sub_index = (index - interval_start) // sub_interval_length
+                    sub_interval_start = interval_start + sub_index * sub_interval_length
+                    return sub_interval_start, min(sub_index, world_size - 1)
+            return fp32_length[-1], 0
+
+        bf16_groups = []
         name2indices = defaultdict()
         index_length = defaultdict()
         index = 0
-        group_idx = dist.get_rank(group=mix_prec_opt.real_dp_process_group[0])
-        group_index = group_idx * len(flatten_partitioned_fp32_groups)
-        need_padding = True if group_idx == dist.get_world_size(group=mix_prec_opt.real_dp_process_group[0]) - 1 else False
-        def get_new_index(group_index, start_idx, end_idx, length):
+        idx = 0
+        for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups):
+            bf16_groups.extend(bf16_group)
+            for param in bf16_group:
+                param_length = len(param.flatten())
+                group_index, group_with_rank = get_group_index(fp32_length, world_size, index)
+                index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
+                index += param_length
+                idx += 1
+        group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups)
+        for _, name in params2name.items():
+            name_index = name2index[name]
+            start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
+            need_padding = True if group_with_rank == world_size - 1 else False
             new_start_idx = start_idx - group_index
             new_end_idx = end_idx - group_index
-            return max(new_start_idx, 0), min(new_end_idx, length)
-        for idx, param in enumerate(bf16_groups):
-            param_length = len(param.flatten())
-            index_length[idx] = (index, index + param_length)
-            index += param_length
-        for _, name in params2name.items():
-            idx = name2index[name]
-            start_idx, end_idx = index_length[idx]
-            new_start_idx, new_end_idx = get_new_index(group_index, start_idx, end_idx, len(flatten_partitioned_fp32_groups))
-            if need_padding and idx == len(bf16_groups) - 1:
-                new_end_idx -= padding
-            name2indices[name] = (new_start_idx, new_end_idx)
+            if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (group_length - 1) == 0:
+                new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
+            name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
         return name2indices
 
     def fetch_mv(self, monitor, torch_opt, params2name, name2indices):
         mix_prec_opt = OptimizerMon.wrapped_optimizer
-        fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups[0]
+        fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
 
 
@@ -216,6 +265,7 @@ class OptimizerMonFactory:
         "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
         "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
         "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
+        "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
         "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
         "unknown": DummyOptimizerMon
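Note (not part of the patch): a minimal, self-contained sketch of the interval arithmetic that the new get_group_index helper performs inside DeepSpeedZeroOptimizerStage1or2Mon.get_param_index, assuming each flattened bit16 group is split evenly across world_size ranks. The world_size and partition sizes below are illustrative values, not taken from the patch.

# Illustrative sketch (hypothetical sizes) of the get_group_index mapping above:
# each optimizer group spans partition_size * world_size flat elements, and a flat
# index is mapped to (start of the owning rank's sub-interval, owning rank).

def get_group_index(fp32_length, world_size, index):
    for i in range(len(fp32_length) - 1):
        if fp32_length[i] <= index < fp32_length[i + 1]:
            interval_start = fp32_length[i]
            interval_length = fp32_length[i + 1] - fp32_length[i]
            sub_interval_length = interval_length // world_size
            sub_index = (index - interval_start) // sub_interval_length
            return interval_start + sub_index * sub_interval_length, min(sub_index, world_size - 1)
    return fp32_length[-1], 0


if __name__ == "__main__":
    world_size = 2
    partition_sizes = [6, 4]  # hypothetical per-rank partition sizes for two groups
    fp32_length = [0]
    for group_idx, partition_size in enumerate(partition_sizes):
        fp32_length.append(partition_size * world_size + fp32_length[group_idx])
    # fp32_length == [0, 12, 20]: group 0 covers flat indices [0, 12), group 1 covers [12, 20)
    assert get_group_index(fp32_length, world_size, 3) == (0, 0)    # rank 0 owns [0, 6)
    assert get_group_index(fp32_length, world_size, 7) == (6, 1)    # rank 1 owns [6, 12)
    assert get_group_index(fp32_length, world_size, 15) == (12, 0)  # rank 0 owns [12, 16)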