diff --git a/torch_npu/npu/amp/grad_scaler.py b/torch_npu/npu/amp/grad_scaler.py
index 4aaf69f841954b4393514c3cd7de5eb92e274117..e2aa937dcbf89eab31da15798a2bace2f2cab921 100644
--- a/torch_npu/npu/amp/grad_scaler.py
+++ b/torch_npu/npu/amp/grad_scaler.py
@@ -209,6 +209,7 @@ class GradScaler(Cuda_GradScaler):
 
     def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
         per_device_found_inf = _NpuMultiDeviceReplicator(found_inf)
+        per_device_inv_scale = _NpuMultiDeviceReplicator(inv_scale)
 
         # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
         # There could be hundreds of grads, so we'd like to iterate through them just once.
@@ -216,35 +217,41 @@ class GradScaler(Cuda_GradScaler):
         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
         # Google says mypy struggles with defaultdicts type annotations.
 
-        def unscale(group, inv_scale, allow_fp16):
-            for param in group["params"]:
-                if param.grad is None:
-                    continue
-                if (not allow_fp16) and param.grad.dtype == torch.float16:
-                    raise ValueError("Attempting to unscale FP16 gradients.")
-                if param.grad.is_sparse:
-                    # is_coalesced() == False means the sparse grad has values with duplicate indices.
-                    # coalesce() deduplicates indices and adds all values that have the same index.
-                    # For scaled fp16 values, there's a good chance coalescing will cause overflow,
-                    # so we should check the coalesced _values().
-                    if param.grad.dtype is torch.float16:
-                        param.grad = param.grad.coalesce()
-                    to_unscale = param.grad._values()
-                else:
-                    to_unscale = param.grad
-                to_unscale.mul_(inv_scale)
-
+
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
         with torch.no_grad():
             if self._dynamic:
-                self._has_overflow = get_npu_overflow_flag()
+                for group in optimizer.param_groups:
+                    for param in group["params"]:
+                        if param.grad is None:
+                            continue
+                        if (not allow_fp16) and param.grad.dtype == torch.float16:
+                            raise ValueError("Attempting to unscale FP16 gradients.")
+                        if param.grad.is_sparse:
+                            # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                            # coalesce() deduplicates indices and adds all values that have the same index.
+                            # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                            # so we should check the coalesced _values().
+                            if param.grad.dtype == torch.float16:
+                                param.grad = param.grad.coalesce()
+                            to_unscale = param.grad._values()
+                        else:
+                            to_unscale = param.grad
+
+                        # TODO: is there a way to split by device and dtype without appending in the inner loop?
+                        per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
+
+                for device, per_dtype_grads in per_device_and_dtype_grads.items():
+                    for grads in per_dtype_grads.values():
+                        torch._amp_foreach_non_finite_check_and_unscale_(grads,
+                                                                         per_device_found_inf.get(device),
+                                                                         per_device_inv_scale.get(device))
+                        if per_device_found_inf.get(device)[0].item() > 0:
+                            self._has_overflow = True
+
                 self._sync_dist_overflow_count()
-                per_device_found_inf_tensor = per_device_found_inf.get(found_inf.device)
                 if self._has_overflow:
-                    per_device_found_inf_tensor.add_(1)
-                return per_device_found_inf._per_device_tensors
-
-        for group in optimizer.param_groups:
-            unscale(group, inv_scale, allow_fp16)
+                    per_device_found_inf.get(found_inf.device).add_(1)
 
         return per_device_found_inf._per_device_tensors
 
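For context, here is a minimal, self-contained sketch of the bucketing strategy the new code uses: collect gradients into a nested defaultdict keyed by device and then dtype, unscale each bucket with one foreach call, and record a per-device found-inf flag. This is not part of the patch; the helper name `unscale_grads_sketch` and the CPU-friendly fallback (`torch._foreach_mul_` plus `torch.isfinite`, standing in for the fused `torch._amp_foreach_non_finite_check_and_unscale_` call used above) are illustrative assumptions, not torch_npu API.

```python
# Hypothetical sketch of the per-device/per-dtype bucketing used in the patch.
from collections import defaultdict

import torch


def unscale_grads_sketch(params, inv_scale):
    """Group grads by (device, dtype), unscale each bucket in one pass.

    Returns a dict mapping device -> found_inf tensor, analogous to
    per_device_found_inf._per_device_tensors in the patch.
    """
    # Nested defaultdict: outer key is the device, inner key is the dtype.
    per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
    per_device_found_inf = {}
    inv = float(inv_scale)

    with torch.no_grad():
        # Single pass over all params to build the buckets.
        for param in params:
            if param.grad is None:
                continue
            grad = param.grad
            per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

        # One foreach call per (device, dtype) bucket instead of one op per grad.
        for device, per_dtype_grads in per_device_and_dtype_grads.items():
            found_inf = torch.zeros(1, device=device)
            for grads in per_dtype_grads.values():
                # The patch fuses unscaling and the inf/nan check into
                # torch._amp_foreach_non_finite_check_and_unscale_; here we
                # emulate it with a foreach multiply and an explicit check.
                torch._foreach_mul_(grads, inv)
                if any(not torch.isfinite(g).all() for g in grads):
                    found_inf.add_(1)
            per_device_found_inf[device] = found_inf
    return per_device_found_inf


if __name__ == "__main__":
    lin = torch.nn.Linear(4, 4)
    loss = lin(torch.randn(2, 4)).sum() * 1024.0  # pretend the loss was scaled
    loss.backward()
    inv_scale = torch.full((1,), 1.0 / 1024.0)
    print(unscale_grads_sketch(lin.parameters(), inv_scale))
```

The design point is the same one the patch makes: with hundreds of gradients, iterating once to split them by device and dtype and then issuing one fused kernel per bucket is cheaper than launching an unscale op per parameter, and it lets the overflow flag be tracked per device via the `_NpuMultiDeviceReplicator` instances.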