diff --git a/test/test_custom_ops/test_amp_foreach_non_finite_check_and_unscale.py b/test/test_custom_ops/test_amp_foreach_non_finite_check_and_unscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..820bb119a9e4d5e418151048a78cfb5ffd0aac2d
--- /dev/null
+++ b/test/test_custom_ops/test_amp_foreach_non_finite_check_and_unscale.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2022, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.decorator import Dtypes, instantiate_tests
+
+@instantiate_tests
+class TestAmpForeachNonFiniteCheckAndUnscale(TestCase):
+
+    @Dtypes(torch.float32, torch.float16)
+    def test_grad_scaling_unscale(self, dtype, device="npu"):
+
+        def _clear_float_status():
+            float_status = torch.zeros(8).npu()
+            result = torch.npu_clear_float_status(float_status)
+
+        def _get_float_status():
+            float_status = torch.zeros(8).npu()
+            result = torch.npu_get_float_status(float_status)
+
+        print('------------------start------------------')
+        inv_scale = torch.full((1,), 0.25, dtype=torch.float, device=device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device)
+        print('inv_scale: ', inv_scale)
+
+        size = 6
+        g = torch.full((size, size), 4.0, dtype=dtype, device=device)
+
+        cases = (
+            ([g.clone(), g.clone()], False),
+            ([g.clone(), g.clone().t()], True),
+            ([g.clone(), g.clone()[:, :5]], False),
+            ([g.clone()[:, :5], g.clone()[:, :5]], False),
+            ([g.clone(), g.clone()], True),
+            ([g.clone(), g.clone().t()], False),
+        )
+
+        for grads, has_inf in cases:
+            print('----------------------------')
+            found_inf.zero_()
+            _clear_float_status()
+
+            if has_inf:
+                ginf = g.clone()
+                ginf[2, 2].mul_(torch.finfo(dtype).max)
+                torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
+                _get_float_status()
+                self.assertEqual(found_inf, 1.0)
+            else:
+                torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
+                _get_float_status()
+                self.assertEqual(found_inf, 0.0)
+                for grad in grads:
+                    self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7))
+
+        _clear_float_status()
+
+        # Passing lists with mismatched devices or dtypes to a raw
+        # _amp_foreach_non_finite_check_and_unscale_ call should raise errors.
+        with self.assertRaisesRegex(RuntimeError, r"must have the same dtype"):
+            if dtype == torch.float16:
+                torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(dtype=torch.float32)],
+                                                                 found_inf,
+                                                                 inv_scale)
+            else:
+                torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(dtype=torch.float16)],
+                                                                 found_inf,
+                                                                 inv_scale)
+
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 2be97a676055c5bc8886689146e5caf71e42c536..e23750f94930b2b9b2705aafc75abcb8ecc4c7d9 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1877,6 +1877,7 @@ custom:
     variants: method, function
   - func: fast_gelu_backward(Tensor grad, Tensor self) -> Tensor
     variants: function, method
+  - func: _amp_foreach_non_finite_check_(Tensor scaled_grad, Tensor found_inf) -> bool
   - func: npu_bert_apply_adam(Scalar lr, Scalar beta1, Scalar beta2, Scalar epsilon, Tensor grad, Scalar max_grad_norm, Scalar global_grad_norm, Scalar weight_decay, Scalar? step_size=None, int adam_mode=0) -> (Tensor var, Tensor m, Tensor v)
   - func: npu_bert_apply_adam.out(Scalar lr, Scalar beta1, Scalar beta2, Scalar epsilon, Tensor grad, Scalar max_grad_norm, Scalar global_grad_norm, Scalar weight_decay, Scalar? step_size=None, int adam_mode=0, *, Tensor(a!) var, Tensor(b!) m, Tensor(c!) v) -> (Tensor(a!), Tensor(b!), Tensor(c!))
   - func: npu_conv_transpose2d_backward(Tensor input, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
diff --git a/torch_npu/csrc/aten/ops/AmpForeachNonFiniteCheckAndUnscaleKernelNpu.cpp b/torch_npu/csrc/aten/ops/AmpForeachNonFiniteCheckAndUnscaleKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ba5488e43633d29051cea4cc1301379d502cf770
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/AmpForeachNonFiniteCheckAndUnscaleKernelNpu.cpp
@@ -0,0 +1,75 @@
+// Copyright (c) 2022 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+const int FLOAT_STATUS_OP_DIMS_SIZE = 8;
+
+bool NPUNativeFunctions::_amp_foreach_non_finite_check_(const at::Tensor& scaled_grad,
+                                                        const at::Tensor& found_inf) {
+  TORCH_WARN_ONCE("Non finite check on NPU device!");
+  TORCH_CHECK(at_npu::key::isDeviceTensor(found_inf), "found_inf must be an NPU tensor");
+  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor");
+  TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor");
+
+  auto options = at::TensorOptions(at_npu::key::NativeDeviceType).dtype(at::kFloat);
+  at::Tensor float_status = at::zeros({FLOAT_STATUS_OP_DIMS_SIZE}, options);
+  at::Tensor result = NPUNativeFunctions::npu_get_float_status(float_status);
+
+  if (float_status[0].item().toFloat() != 0) {
+    found_inf.add_(1);
+    return true;
+  } else {
+    return false;
+  }
+}
+
+void NPUNativeFunctions::_amp_foreach_non_finite_check_and_unscale_(at::TensorList scaled_grads,
+                                                                    at::Tensor& found_inf,
+                                                                    const at::Tensor& inv_scale) {
+  TORCH_WARN_ONCE("Non finite check and unscale on NPU device!");
+  TORCH_CHECK(at_npu::key::isDeviceTensor(inv_scale), "inv_scale must be an NPU tensor");
+  TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor");
+  TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor");
+
+  if (scaled_grads.size() == 0) {
+    return;
+  }
+
+  if (!NPUNativeFunctions::_amp_foreach_non_finite_check_(scaled_grads[0], found_inf)) {
+    auto expected_device = scaled_grads[0].device();
+    auto expected_dtype = scaled_grads[0].dtype();
+    for (const auto& t : scaled_grads) {
+      TORCH_CHECK(at_npu::key::isDeviceTensor(t), "one of scaled_grads was not an NPU tensor.");
+      TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device.");
+      TORCH_CHECK(t.dtype() == expected_dtype, "scaled_grads must have the same dtype.");
+      TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor.");
+
+      t.mul_(inv_scale);
+    }
+  }
+
+  auto options = at::TensorOptions(at_npu::key::NativeDeviceType).dtype(at::kFloat);
+  at::Tensor float_status = at::zeros({FLOAT_STATUS_OP_DIMS_SIZE}, options);
+  at::Tensor result = NPUNativeFunctions::npu_clear_float_status(float_status);
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
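
Context for reviewers (not part of the patch): both the test helpers and the new kernel detect non-finite values through the Ascend float-status register rather than by inspecting gradient elements, which is why the test only needs to provoke an overflow and then read found_inf. A minimal, hypothetical Python sketch of that pattern, using only the ops already exercised above (torch.npu_clear_float_status / torch.npu_get_float_status) and assuming an NPU device is available:

    import torch
    import torch_npu

    status = torch.zeros(8).npu()          # 8-element float-status buffer, as in the test helpers
    torch.npu_clear_float_status(status)   # reset the hardware overflow flag

    x = torch.full((4,), 65504.0, dtype=torch.float16, device="npu")
    x = x * 2                              # fp16 overflow; expected to set the status flag,
                                           # the same assumption the has_inf cases rely on

    torch.npu_get_float_status(status)     # read the flag back into `status`
    print("overflow detected:", status[0].item() != 0)

AmpForeachNonFiniteCheckAndUnscaleKernelNpu.cpp follows the same sequence on each call: query the status once, add 1 to found_inf if it is non-zero, multiply every gradient by inv_scale only when no overflow was seen, and clear the status before returning.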