diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index fb6195b633e49b53c6c45286815e1351d35f1c6c..4974dd9d9693f7ca10e40c82ed7e082bda715fea 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -17,6 +17,7 @@ from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
 from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
 from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
 from megatron.core.datasets.utils import get_blend_from_list
+from megatron.core.rerun_state_machine import get_rerun_state_machine
 import megatron.legacy.model
 from megatron.core.models.gpt import GPTModel
 from mindspeed_llm.training import pretrain
@@ -128,46 +129,76 @@ def get_batch(data_iterator):
     return batch.values()
 
 
+# define spiky loss as a loss that's 10x the max loss observed
+SPIKY_LOSS_FACTOR = 10
+
+
 def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     """Loss function.
 
     Args:
         loss_mask (torch.Tensor): Used to mask out some portions of the loss
         output_tensor (torch.Tensor): The tensor with the losses
-    """
+
+    Returns:
+        the loss scalar for this micro-batch
+        the number of non-padded tokens in this microbatch
+        a dict containing reporting metrics on the loss and number of tokens across
+            the data parallel ranks
+    """
     args = get_args()
 
     losses = output_tensor.float()
     loss_mask = loss_mask.view(-1).float()
+    total_tokens = loss_mask.sum()
+    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+
     if args.context_parallel_size > 1:
-        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
         torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
-        loss = loss[0] / loss[1]
-    else:
-        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
 
     # Check individual rank losses are not NaN prior to DP all-reduce.
+    rerun_state_machine = get_rerun_state_machine()
     if args.check_for_nan_in_loss_and_grad:
-        global_rank = torch.distributed.get_rank()
-        if loss.isnan():
-            raise ValueError(f'Rank {global_rank}: found NaN in local forward loss calculation. '
-                             f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}')
-
-    if args.async_log_allreduce:
-        # Reduce loss for logging, which is different from megatron pretrain_gpt.py.
-        reporting_loss = loss.clone().detach()
-
-        allreduce_handle = torch.distributed.all_reduce(
-            reporting_loss, group=mpu.get_data_parallel_group(), async_op=True
+        rerun_state_machine.validate_result(
+            result=loss[0],
+            rejection_func=torch.isnan,
+            message="found NaN in local forward loss calculation",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=True,
         )
-
-        return loss * args.context_parallel_size, ({"lm loss": (reporting_loss)}, allreduce_handle)
-    else:
-
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-
-        return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]}
+        rerun_state_machine.validate_result(
+            result=loss[0],
+            rejection_func=torch.isinf,
+            message="found Inf in local forward loss calculation",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=True,
+        )
+    # Check for spiky loss
+    if args.check_for_spiky_loss:
+        rerun_state_machine.validate_result(
+            result=loss[0],
+            rejection_func=partial(
+                rerun_state_machine.is_unexpectedly_large,
+                threshold=SPIKY_LOSS_FACTOR,
+                context="loss",
+            ),
+            message="Spiky loss",
+            tolerance=0.0,  # forward pass calculations are deterministic
+            fatal=False,
+        )
+    # Reduce loss for logging.
+    reporting_loss = loss.clone().detach()
+    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+
+    # loss[0] is a view of loss, so it has ._base not None, which triggers assert error
+    # in core/pipeline_parallel/schedules.py::deallocate_output_tensor, calling .clone()
+    # on loss[0] fixes this
+    local_num_tokens = loss[1].clone().detach().to(torch.int)
+    return (
+        loss[0].clone(),
+        local_num_tokens,
+        {'lm loss': (reporting_loss[0], reporting_loss[1])},
+    )
 
 
 def forward_step(data_iterator, model: GPTModel):
diff --git a/tests/st/baseline_results/llama2_tp2_cp4_general_double_ring.json b/tests/st/baseline_results/llama2_tp2_cp4_general_double_ring.json
index 44d19e62714cc55c9a4664004528b09da4c3ee38..ee213e8940cd6ca6dbfac8a7b5b0a6e4542e9cf2 100644
--- a/tests/st/baseline_results/llama2_tp2_cp4_general_double_ring.json
+++ b/tests/st/baseline_results/llama2_tp2_cp4_general_double_ring.json
@@ -16,6 +16,22 @@
         1.550481,
         1.514153
     ],
+    "grad norm": [
+        27.05,
+        26.351,
+        26.57,
+        26.895,
+        23.497,
+        22.75,
+        22.254,
+        23.012,
+        17.275,
+        16.61,
+        17.031,
+        17.574,
+        16.649,
+        16.239
+    ],
     "throughput": [
         2.1,
         29.7,
diff --git a/tests/test_tools/acquire_json.py b/tests/test_tools/acquire_json.py
index f0f2053562d6f3ee4a06bbeca7944a4c7ab8b7ae..527b085ea519d5d4fb3fb7f5f9d55255c7ec1f3f 100644
--- a/tests/test_tools/acquire_json.py
+++ b/tests/test_tools/acquire_json.py
@@ -23,7 +23,7 @@ def transfer_logs_as_json(log_file, output_json_file):
     """
 
     log_pattern = re.compile(
-        r"elapsed time per iteration \(ms\):\s+([0-9.]+)\s+\| throughput per GPU \(TFLOP/s/GPU\):\s+([0-9.]+)\s+\| .*?lm loss:\s+([0-9.]+E[+-][0-9]+) | .* actor/pg_loss : ([0-9.]+)"
+        r"elapsed time per iteration \(ms\):\s+([0-9.]+)\s+\| throughput per GPU \(TFLOP/s/GPU\):\s+([0-9.]+)\s+\| .*?lm loss:\s+([0-9.]+E[+-][0-9]+)\s+\| .*?grad norm:\s+([0-9.]+) | .* actor/pg_loss : ([0-9.]+)"
     )
 
     if "trl_ppo" in log_file:
@@ -37,6 +37,7 @@
 
     data = {
         "lm loss": [],
+        "grad norm": [],
         "time info": [],
         "throughput": [],
         "memo info": [],
@@ -50,10 +51,11 @@
     if log_matches:
         if log_matches[0][1][2] != "":
             data["lm loss"] = [float(match[2]) for match in log_matches]
+            data["grad norm"] = [float(match[3]) for match in log_matches]
             data["time info"] = [float(match[0]) for match in log_matches]
             data["throughput"] = [float(match[1]) for match in log_matches]
         else:
-            data["lm loss"] = [float(match[3]) for match in log_matches]
+            data["lm loss"] = [float(match[4]) for match in log_matches]
 
     if memory_matches:
         memo_info = [
diff --git a/tests/test_tools/test_ci_st.py b/tests/test_tools/test_ci_st.py
index 66de3a0bfe6b16fa0467a45c8d0612f7d9dd2cd2..7ab969e6c538db985022165a75cec0cbe17aa2ee 100644
--- a/tests/test_tools/test_ci_st.py
+++ b/tests/test_tools/test_ci_st.py
@@ -7,6 +7,7 @@
 MEMO_INFO = "memo info"
 THROUGHPUT = "throughput"
 LOSS = "lm loss"
 TIME_INFO = "time info"
+GRAD_NORM_INFO = "grad norm"
 
 WARM_UP = 5
@@ -17,6 +18,7 @@ class TestMargin:
     memory = 0.1
     throughput = 0.05
     time = 0.05
+    grad_norm = 0.1
 
     @classmethod
     def refresh_margin_from_json(cls, json_obj):
@@ -24,11 +26,13 @@ class TestMargin:
         cls.memory = json_obj.get(MEMO_INFO + cls._MARGIN_NAME, cls.memory)
         cls.throughput = json_obj.get(THROUGHPUT + cls._MARGIN_NAME, cls.throughput)
         cls.time = json_obj.get(TIME_INFO + cls._MARGIN_NAME, cls.time)
+        cls.grad_norm = json_obj.get(GRAD_NORM_INFO + cls._MARGIN_NAME, cls.grad_norm)
 
 
 class TestCIST:
-    margin_loss = 0.02
+    margin_loss = 0.02
+    margin_grad_norm = 0.1
     margin_throughput_percent = 0.05
     margin_memory_percent = 0.1
     margin_time_percent = 0.05
@@ -65,11 +69,18 @@
         comparison_time = {
             TIME_INFO: self._compare_time,
         }
+
+        comparison_grad_norm = {
+            GRAD_NORM_INFO: self._compare_grad_norm,
+        }
 
         if "time info" in self.expected:
             comparison_selection = {**comparison_base, **comparison_time}
         else:
             comparison_selection = {**comparison_base, **comparison_throughput}
+
+        if "grad norm" in self.expected:
+            comparison_selection = {**comparison_selection, **comparison_grad_norm}
 
         if test_obj in comparison_selection:
             expected_list = self.expected[test_obj]
@@ -99,6 +110,16 @@
             print(f"Checking step {step + 1} for lm loss")
             assert actual_val == pytest.approx(expected=expected_val, rel=TestMargin.loss),\
                 f"The loss at step {step} should be approximate to {expected_val} but it is {actual_val}."
+
+
+    def _compare_grad_norm(self, expected_list, actual_list):
+        # Enabling "deterministic computation" would hurt the throughput we also test,
+        # so grad norm is only compared approximately rather than exactly.
+        for step, (expected_val, actual_val) in enumerate(zip(expected_list, actual_list)):
+            print(f"Checking step {step + 1} for grad norm")
+            assert actual_val == pytest.approx(expected=expected_val, rel=TestMargin.grad_norm),\
+                f"The grad norm at step {step} should be approximate to {expected_val} but it is {actual_val}."
+
 
     def _compare_throughput(self, expected_list, actual_list):
         # First few iterations might take a little longer. So we take the last 70 percent of the timings
@@ -111,7 +132,8 @@
         assert actual_avg_throughput >= expected_avg_throughput or \
             abs(actual_avg_throughput - expected_avg_throughput) / expected_avg_throughput <= TestMargin.throughput, \
             f"The actual avg throughput {actual_avg_throughput} degradate expected avg throughput {expected_avg_throughput}"
-
+
+
     def _compare_time(self, expected_list, actual_list):
         try:
             expected_avg_time = sum(expected_list[WARM_UP:]) / (len(expected_list) - WARM_UP)
@@ -134,22 +156,33 @@
                 abs(actual_val["max allocated memory"] - expected_val["max allocated memory"]) / expected_val["max allocated memory"] <= TestMargin.memory, \
                 f'The actual max memory {actual_val["max allocated memory"]} seems to be abnormal compare to expected {expected_val["max allocated memory"]}.'
 
+
     def test_lm_loss_approx(self, baseline_json, generate_log, generate_json):
         # expected training loss curve at different global steps.
         self._get_baseline(baseline_json)
         self._get_actual(generate_log, generate_json)
         self._test_helper("lm loss")
-
+
+
+    def test_grad_norm_approx(self, baseline_json, generate_log, generate_json):
+        # expected grad norm curve at different global steps.
+        self._get_baseline(baseline_json)
+        self._get_actual(generate_log, generate_json)
+        self._test_helper("grad norm")
+
+
     def test_througpout(self, baseline_json, generate_log, generate_json):
         self._get_baseline(baseline_json)
         self._get_actual(generate_log, generate_json)
         self._test_helper("throughput")
-
+
+
     def test_time(self, baseline_json, generate_log, generate_json):
         self._get_baseline(baseline_json)
         self._get_actual(generate_log, generate_json)
         self._test_helper("time info")
 
+
     def test_allocated_memory(self, baseline_json, generate_log, generate_json):
         self._get_baseline(baseline_json)
         self._get_actual(generate_log, generate_json)
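Note on the loss_func change above: it now returns the masked loss sum together with the number of non-padded tokens (plus a reporting dict) and leaves normalization to the caller, instead of returning a locally averaged loss. A minimal sketch of why a (sum, token count) contract is useful, using made-up numbers rather than anything from the patch: reducing the sums first and dividing once gives the exact token-weighted average, whereas averaging per-rank averages weights every rank equally no matter how many tokens it holds.

# Two ranks holding different numbers of non-padded tokens; values are illustrative only.
rank_partials = [(12.0, 8), (4.0, 2)]  # (masked loss sum, token count) per rank

# Reduce the sums, then divide once: exact token-weighted average.
total_loss = sum(s for s, _ in rank_partials)    # 16.0
total_tokens = sum(t for _, t in rank_partials)  # 10
print(total_loss / total_tokens)                 # 1.6

# Averaging the per-rank averages gives a different number (each rank weighted equally).
print(sum(s / t for s, t in rank_partials) / len(rank_partials))  # 1.75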
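On the test-tooling side, the updated log_pattern in acquire_json.py captures grad norm as a fourth group and shifts actor/pg_loss to the fifth, which is why the non-PPO branch now reads match[3] for grad norm and the PPO branch moves from match[3] to match[4]. A small self-contained sketch of that parsing, reusing the same regex; the sample line below is illustrative, shaped to satisfy the pattern rather than copied from a real training log.

import re

# Same pattern as the updated acquire_json.py: group 1 = iteration time, 2 = throughput,
# 3 = lm loss, 4 = grad norm, 5 = actor/pg_loss (second alternative, PPO logs).
log_pattern = re.compile(
    r"elapsed time per iteration \(ms\):\s+([0-9.]+)\s+\| throughput per GPU \(TFLOP/s/GPU\):\s+([0-9.]+)"
    r"\s+\| .*?lm loss:\s+([0-9.]+E[+-][0-9]+)\s+\| .*?grad norm:\s+([0-9.]+) | .* actor/pg_loss : ([0-9.]+)"
)

# Illustrative line shaped to match the pattern above, not taken from a real log.
sample = ("elapsed time per iteration (ms): 1234.5 | throughput per GPU (TFLOP/s/GPU): 29.7 "
          "| lm loss: 1.514153E+00 | grad norm: 16.239 | ")

matches = log_pattern.findall(sample)
if matches and matches[0][2] != "":
    print("lm loss:", float(matches[0][2]))    # 1.514153
    print("grad norm:", float(matches[0][3]))  # 16.239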