diff --git a/test/nn/test_linear_functions.py b/test/nn/test_linear_functions.py
index 8fd8be8a4ce49bdb43b4addaf3d3cbaa7d77712f..ec7795483a5ad30d4f8f7c575e818ff77647f887 100644
--- a/test/nn/test_linear_functions.py
+++ b/test/nn/test_linear_functions.py
@@ -4,36 +4,54 @@
 import unittest
 import torch
 import torch.nn.functional as F
 import torch_npu
+import numpy as np
 
 from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
 
+DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
 class TestLinearFunctions(TestCase):
-    @unittest.skip("skip test_linear now")
     def test_linear(self):
-        input1 = torch.randn(2, 3, 4)
-        weight = torch.randn(3, 4)
-        npu_input = copy.deepcopy(input1).npu()
-        npu_weight = copy.deepcopy(weight).npu()
+        cpu_input, npu_input = create_common_tensor([np.float16, 2, [2, 3, 4]], 0, 1)
+        cpu_weight, npu_weight = create_common_tensor([np.float16, 2, [3, 4]], 0, 1)
 
-        cpu_output = F.linear(input1, weight)
+        cpu_output = F.linear(cpu_input, cpu_weight)
+        npu_output = F.linear(npu_input, npu_weight)
+
+        self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy())
+
+    @unittest.skipIf(DEVICE_NAME == 'Ascend910A',
+        "fp32 is not supported on 910A, skip this ut for this device type!")
+    def test_linear_32(self):
+        cpu_input, npu_input = create_common_tensor([np.float32, 2, [2, 3, 4]], 0, 1)
+        cpu_weight, npu_weight = create_common_tensor([np.float32, 2, [3, 4]], 0, 1)
+
+        cpu_output = F.linear(cpu_input, cpu_weight)
         npu_output = F.linear(npu_input, npu_weight)
 
         self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy())
 
-    @unittest.skip("skip test_bilinear now")
     def test_bilinear(self):
-        input1 = torch.randn(10, 30)
-        input2 = torch.randn(10, 40)
-        weight = torch.randn(5, 30, 40)
-        bias = torch.randn(5)
+        cpu_input1, npu_input1 = create_common_tensor([np.float16, 2, [10, 30]], 0, 1)
+        cpu_input2, npu_input2 = create_common_tensor([np.float16, 2, [10, 40]], 0, 1)
+        cpu_weight, npu_weight = create_common_tensor([np.float16, 2, [5, 30, 40]], 0, 1)
+        cpu_bias, npu_bias = create_common_tensor([np.float16, 2, [5]], 0, 1)
+
+        cpu_output = F.bilinear(cpu_input1, cpu_input2, cpu_weight, cpu_bias)
+        npu_output = F.bilinear(npu_input1, npu_input2, npu_weight, npu_bias)
+
+        self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy())
 
-        npu_input1 = copy.deepcopy(input1).npu()
-        npu_input2 = copy.deepcopy(input2).npu()
-        npu_weight = copy.deepcopy(weight).npu()
-        npu_bias = copy.deepcopy(bias).npu()
+    @unittest.skipIf(DEVICE_NAME == 'Ascend910A',
+        "fp32 is not supported on 910A, skip this ut for this device type!")
+    def test_bilinear_32(self):
+        cpu_input1, npu_input1 = create_common_tensor([np.float32, 2, [10, 30]], 0, 1)
+        cpu_input2, npu_input2 = create_common_tensor([np.float32, 2, [10, 40]], 0, 1)
+        cpu_weight, npu_weight = create_common_tensor([np.float32, 2, [5, 30, 40]], 0, 1)
+        cpu_bias, npu_bias = create_common_tensor([np.float32, 2, [5]], 0, 1)
 
-        cpu_output = F.bilinear(input1, input2, weight, bias)
+        cpu_output = F.bilinear(cpu_input1, cpu_input2, cpu_weight, cpu_bias)
         npu_output = F.bilinear(npu_input1, npu_input2, npu_weight, npu_bias)
 
         self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy())