From 2c6d3c85217f99746cf6344347c42df450bd6e47 Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Fri, 11 Feb 2022 10:36:14 +0800 Subject: [PATCH] ctc_loss_backward Operator --- test/test_network_ops/test_clamp.py | 160 +++++++++++++++++ test/test_network_ops/test_clamp_max.py | 160 +++++++++++++++++ test/test_network_ops/test_clamp_min.py | 159 +++++++++++++++++ test/test_network_ops/test_ctc_loss.py | 94 ++++++++++ .../test_ctc_loss_backward.py | 102 +++++++++++ test/test_network_ops/test_where.py | 123 +++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 3 - torch_npu/csrc/aten/ops/ClampKernelNpu.cpp | 165 ++++++++++++++++++ .../aten/ops/CtcLossBackwardKernelNpu.cpp | 87 +++++++++ torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp | 140 +++++++++++++++ torch_npu/csrc/aten/ops/WhereKernelNpu.cpp | 123 +++++++++++++ 11 files changed, 1313 insertions(+), 3 deletions(-) create mode 100644 test/test_network_ops/test_clamp.py create mode 100644 test/test_network_ops/test_clamp_max.py create mode 100644 test/test_network_ops/test_clamp_min.py create mode 100644 test/test_network_ops/test_ctc_loss.py create mode 100644 test/test_network_ops/test_ctc_loss_backward.py create mode 100644 test/test_network_ops/test_where.py create mode 100644 torch_npu/csrc/aten/ops/ClampKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/CtcLossBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/WhereKernelNpu.cpp diff --git a/test/test_network_ops/test_clamp.py b/test/test_network_ops/test_clamp.py new file mode 100644 index 00000000000..89f28ee6505 --- /dev/null +++ b/test/test_network_ops/test_clamp.py @@ -0,0 +1,160 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
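+
+# Reference semantics exercised below (a note on the expected CPU behaviour):
+# torch.clamp limits each element of the input to the closed interval
+# [min, max]; for example
+#     torch.clamp(torch.tensor([1., 50., 99.]), 40, 60)
+# returns tensor([40., 50., 60.]). The cases below compare the NPU kernel
+# against this for the out-of-place, in-place (clamp_), out= and
+# non-contiguous variants.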
+ +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestClamp(TestCase): + def generate_data(self, data): + input1 = np.random.uniform(data[0], data[1], data[2]).astype(data[3]) + + #modify from numpy.ndarray to torch.tensor + input1 = torch.from_numpy(input1) + + return input1 + + def npu_op_exec(self, input1, min_val, max_val): + input1 = input1.to("npu") + output = torch.clamp(input1, min_val, max_val) + output = output.to("cpu") + output = output.numpy() + + return output + + def cpu_op_exec(self, input1, min_val, max_val): + output = torch.clamp(input1, min_val,max_val) + output = output.numpy() + + return output + + def cpu_op_exec_float16(self, input1, min_val, max_val): + input1 = input1.to(torch.float32) + output = torch.clamp(input1, min_val, max_val).to(torch.float16) + output = output.numpy() + + return output + + def npu_inp_op_exec(self, input1, min_val, max_val): + input1 = input1.to("npu") + output = torch.clamp_(input1, min_val, max_val) + output = input1.to("cpu") + output = output.numpy() + + return output + + def cpu_inp_op_exec(self, input1, min_val, max_val): + output = torch.clamp_(input1, min_val, max_val) + output = output.numpy() + + return output + + def cpu_inp_op_exec_float16(self, input1, min_val, max_val): + input1 = input1.to(torch.float32) + output = torch.clamp_(input1, min_val, max_val).to(torch.float16) + output = output.numpy() + + return output + + def npu_op_exec_out(self, input1, min_val, max_val, input2): + input1 = input1.to("npu") + output = input2.to("npu") + torch.clamp(input1, min_val, max_val, out=output) + output = output.to("cpu") + output = output.numpy() + + return output + + def npu_inp_uncon_op_exec(self, input1, min_val, max_val): + input1 = input1.to("npu") + input1 = input1.as_strided([2, 2], [1, 2], 2) + output = torch.clamp_(input1, min_val, max_val) + output = input1.to("cpu") + output = output.numpy() + + return output + + def cpu_inp_uncon_op_exec(self, input1, min_val, max_val): + input1 = input1.as_strided([2, 2], [1, 2], 2) + output = torch.clamp(input1, min_val, max_val) + output = output.numpy() + + return output + + def cpu_inp_uncon_op_exec_float16(self, input1, min_val, max_val): + input1 = input1.to(torch.float32).as_strided([2, 2], [1, 2], 2) + output = torch.clamp(input1, min_val, max_val).to(torch.float16) + output = output.numpy() + + return output + + def test_clamp_common(self, device): + shape_format = [ + [1, 100, (4, 3), np.float32], + [1, 100, (4, 3), np.int32], + ] + for item in shape_format: + input1 = self.generate_data(item) + + cpu_output = self.cpu_op_exec(input1, 40, 60) + npu_output = self.npu_op_exec(input1, 40, 60) + + cpu_inp_output = self.cpu_inp_op_exec(input1, 40, 60) + npu_inp_output = self.npu_inp_op_exec(input1, 40, 60) + + input2 = self.generate_data(item) + npu_out_output = self.npu_op_exec_out(input1, 40, 60, input2) + + cpu_inp_uncon_output = self.cpu_inp_uncon_op_exec(input1, 40, 60) + npu_inp_uncon_output = self.npu_inp_uncon_op_exec(input1, 40, 60) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_inp_output, npu_inp_output) + self.assertRtolEqual(cpu_output, npu_out_output) + self.assertRtolEqual(cpu_inp_uncon_output, npu_inp_uncon_output) + + def test_clamp_float16(self, device): + shape_format = [ + [1, 100, (4, 3), np.float16], + ] + for item 
in shape_format: + input1 = self.generate_data(item) + + cpu_output = self.cpu_op_exec_float16(input1, 40, 60) + npu_output = self.npu_op_exec(input1, 40, 60) + + cpu_inp_output = self.cpu_inp_op_exec_float16(input1, 40, 60) + npu_inp_output = self.npu_inp_op_exec(input1, 40, 60) + + input2 = self.generate_data(item) + npu_out_output = self.npu_op_exec_out(input1, 40, 60, input2) + + cpu_inp_uncon_output = self.cpu_inp_uncon_op_exec_float16(input1, 40, 60) + npu_inp_uncon_output = self.npu_inp_uncon_op_exec(input1, 40, 60) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_inp_output, npu_inp_output) + self.assertRtolEqual(cpu_output, npu_out_output) + self.assertRtolEqual(cpu_inp_uncon_output, npu_inp_uncon_output) + + +instantiate_device_type_tests(TestClamp, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_clamp_max.py b/test/test_network_ops/test_clamp_max.py new file mode 100644 index 00000000000..88213e8b6fe --- /dev/null +++ b/test/test_network_ops/test_clamp_max.py @@ -0,0 +1,160 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestClampMax(TestCase): + def generate_data(self, data): + input1 = np.random.uniform(data[0], data[1], data[2]).astype(data[3]) + + #modify from numpy.ndarray to torch.tensor + input1 = torch.from_numpy(input1) + + return input1 + + def npu_op_exec(self, input1, max_val): + input1 = input1.to("npu") + output = torch.clamp_max(input1, max_val) + output = output.to("cpu") + output = output.numpy() + + return output + + def cpu_op_exec(self, input1, max_val): + output = torch.clamp_max(input1, max_val) + output = output.numpy() + + return output + + def cpu_op_exec_float16(self, input1, max_val): + input1 = input1.to(torch.float32) + output = torch.clamp_max(input1, max_val).to(torch.float16) + output = output.numpy() + + return output + + def npu_inp_op_exec(self, input1, max_val): + input1 = input1.to("npu") + output = torch.clamp_max_(input1, max_val) + output = input1.to("cpu") + output = output.numpy() + + return output + + def cpu_inp_op_exec(self, input1, max_val): + output = torch.clamp_max_(input1, max_val) + output = output.numpy() + + return output + + def cpu_inp_op_exec_float16(self, input1, max_val): + input1 = input1.to(torch.float32) + output = torch.clamp_max_(input1, max_val).to(torch.float16) + output = output.numpy() + + return output + + def npu_op_exec_out(self, input1, max_val, input2): + input1 = input1.to("npu") + output = input2.to("npu") + torch.clamp_max(input1, max_val, out=output) + output = output.to("cpu") + output = output.numpy() + + return output + + def 
npu_inp_uncon_op_exec(self, input1, max_val): + input1 = input1.to("npu") + input1 = input1.as_strided([2, 2], [1, 2], 2) + output = torch.clamp_max_(input1, max_val) + output = input1.to("cpu") + output = output.numpy() + + return output + + def cpu_inp_uncon_op_exec(self, input1, max_val): + input1 = input1.as_strided([2, 2], [1, 2], 2) + output = torch.clamp_max(input1, max_val) + output = output.numpy() + + return output + + def cpu_inp_uncon_op_exec_float16(self, input1, max_val): + input1 = input1.to(torch.float32).as_strided([2, 2], [1, 2], 2) + output = torch.clamp_max(input1, max_val).to(torch.float16) + output = output.numpy() + + return output + + def test_clamp_max_common(self, device): + shape_format = [ + [1, 100, (4, 3), np.float32], + [1, 100, (4, 3), np.int32], + ] + for item in shape_format: + input1 = self.generate_data(item) + + cpu_output = self.cpu_op_exec(input1, 50) + npu_output = self.npu_op_exec(input1, 50) + + cpu_inp_output = self.cpu_inp_op_exec(input1, 50) + npu_inp_output = self.npu_inp_op_exec(input1, 50) + + input2 = self.generate_data(item) + npu_out_output = self.npu_op_exec_out(input1, 50, input2) + + cpu_inp_uncon_output = self.cpu_inp_uncon_op_exec(input1, 50) + npu_inp_uncon_output = self.npu_inp_uncon_op_exec(input1, 50) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_inp_output, npu_inp_output) + self.assertRtolEqual(cpu_output, npu_out_output) + self.assertRtolEqual(cpu_inp_uncon_output, npu_inp_uncon_output) + + def test_clamp_max_float16(self, device): + shape_format = [ + [1, 100, (4, 3), np.float16], + ] + for item in shape_format: + input1 = self.generate_data(item) + + cpu_output = self.cpu_op_exec_float16(input1, 50) + npu_output = self.npu_op_exec(input1, 50) + + cpu_inp_output = self.cpu_inp_op_exec_float16(input1, 50) + npu_inp_output = self.npu_inp_op_exec(input1, 50) + + input2 = self.generate_data(item) + npu_out_output = self.npu_op_exec_out(input1, 50, input2) + + cpu_inp_uncon_output = self.cpu_inp_uncon_op_exec_float16(input1, 50) + npu_inp_uncon_output = self.npu_inp_uncon_op_exec(input1, 50) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_inp_output, npu_inp_output) + self.assertRtolEqual(cpu_output, npu_out_output) + self.assertRtolEqual(cpu_inp_uncon_output, npu_inp_uncon_output) + + +instantiate_device_type_tests(TestClampMax, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_clamp_min.py b/test/test_network_ops/test_clamp_min.py new file mode 100644 index 00000000000..10baef73b89 --- /dev/null +++ b/test/test_network_ops/test_clamp_min.py @@ -0,0 +1,159 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
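+
+# Reference semantics exercised below: torch.clamp_min(x, m) behaves like
+# torch.clamp(x, min=m) with no upper bound; for example
+#     torch.clamp_min(torch.tensor([1., 50., 99.]), 50)
+# returns tensor([50., 50., 99.]).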
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestClampMin(TestCase):
+    def generate_data(self, data):
+        input1 = np.random.uniform(data[0], data[1], data[2]).astype(data[3])
+
+        #modify from numpy.ndarray to torch.tensor
+        input1 = torch.from_numpy(input1)
+
+        return input1
+
+    def npu_op_exec(self, input1, min_val):
+        input1 = input1.to("npu")
+        output = torch.clamp_min(input1, min_val)
+        output = output.to("cpu")
+        output = output.numpy()
+
+        return output
+
+    def cpu_op_exec(self, input1, min_val):
+        output = torch.clamp_min(input1, min_val)
+        output = output.numpy()
+
+        return output
+
+    def cpu_op_exec_float16(self, input1, min_val):
+        input1 = input1.to(torch.float32)
+        output = torch.clamp_min(input1, min_val).to(torch.float16)
+        output = output.numpy()
+
+        return output
+
+    def npu_inp_op_exec(self, input1, min_val):
+        input1 = input1.to("npu")
+        output = torch.clamp_min_(input1, min_val)
+        output = input1.to("cpu")
+        output = output.numpy()
+
+        return output
+
+    def cpu_inp_op_exec(self, input1, min_val):
+        output = torch.clamp_min_(input1, min_val)
+        output = output.numpy()
+
+        return output
+
+    def cpu_inp_op_exec_float16(self, input1, min_val):
+        input1 = input1.to(torch.float32)
+        output = torch.clamp_min_(input1, min_val).to(torch.float16)
+        output = output.numpy()
+
+        return output
+
+    def npu_op_exec_out(self, input1, min_val, input2):
+        input1 = input1.to("npu")
+        output = input2.to("npu")
+        torch.clamp_min(input1, min_val, out=output)
+        output = output.to("cpu")
+        output = output.numpy()
+
+        return output
+
+    def npu_inp_uncon_op_exec(self, input1, min_val):
+        input1 = input1.to("npu")
+        input1 = input1.as_strided([2, 2], [1, 2], 2)
+        output = torch.clamp_min_(input1, min_val)
+        output = input1.to("cpu")
+        output = output.numpy()
+
+        return output
+
+    def cpu_inp_uncon_op_exec(self, input1, min_val):
+        input1 = input1.as_strided([2, 2], [1, 2], 2)
+        output = torch.clamp_min(input1, min_val)
+        output = output.numpy()
+
+        return output
+
+    def cpu_inp_uncon_op_exec_float16(self, input1, min_val):
+        input1 = input1.to(torch.float32).as_strided([2, 2], [1, 2], 2)
+        output = torch.clamp_min(input1, min_val).to(torch.float16)
+        output = output.numpy()
+
+        return output
+
+    def test_clamp_min_common(self, device):
+        shape_format2 = [
+            [1, 100, (4, 3), np.float32],
+            [1, 100, (4, 3), np.int32],
+        ]
+        for item in shape_format2:
+            input3 = self.generate_data(item)
+
+            cpu_output = self.cpu_op_exec(input3, 50)
+            npu_output = self.npu_op_exec(input3, 50)
+
+            cpu_inp_output = self.cpu_inp_op_exec(input3, 50)
+            npu_inp_output = self.npu_inp_op_exec(input3, 50)
+
+            input4 = self.generate_data(item)
+            npu_out_output = self.npu_op_exec_out(input3, 50, input4)
+
+            cpu_inp_uncon_output = self.cpu_inp_uncon_op_exec(input3, 50)
+            npu_inp_uncon_output = self.npu_inp_uncon_op_exec(input3, 50)
+
+            self.assertRtolEqual(cpu_output, npu_output)
+            self.assertRtolEqual(cpu_inp_output, npu_inp_output)
+            self.assertRtolEqual(cpu_output, npu_out_output)
+            self.assertRtolEqual(cpu_inp_uncon_output, npu_inp_uncon_output)
+
+    def test_clamp_min_float16(self, device):
+        shape_format3 = [
+            [1, 100, (4, 3), np.float16],
+        ]
+        for item in shape_format3:
+            input3 = self.generate_data(item)
+
+            cpu_output = self.cpu_op_exec_float16(input3, 50)
+            npu_output = self.npu_op_exec(input3, 50)
+
+ cpu_inp_output = self.cpu_inp_op_exec_float16(input3, 50) + npu_inp_output = self.npu_inp_op_exec(input3, 50) + + input4 = self.generate_data(item) + npu_out_output = self.npu_op_exec_out(input3, 50, input4) + + cpu_inp_uncon_output = self.cpu_inp_uncon_op_exec_float16(input3, 50) + npu_inp_uncon_output = self.npu_inp_uncon_op_exec(input3, 50) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_inp_output, npu_inp_output) + self.assertRtolEqual(cpu_output, npu_out_output) + self.assertRtolEqual(cpu_inp_uncon_output, npu_inp_uncon_output) + +instantiate_device_type_tests(TestClampMin, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_ctc_loss.py b/test/test_network_ops/test_ctc_loss.py new file mode 100644 index 00000000000..ef8db2dbd68 --- /dev/null +++ b/test/test_network_ops/test_ctc_loss.py @@ -0,0 +1,94 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestCtcLoss(TestCase): + def generate_data(self, item): + T = item[0][0] + C = item[0][1] + N = item[0][2] + S = item[0][3] + S_min = item[0][4] + dtype = item[1] + reduction_str = item[2] + blk = item[3] + + log_probs = np.random.uniform(-10, 10, (T, N, C)).astype(dtype) + targets = torch.randint(1, C, (N, S), dtype = torch.long) + input_lengths = torch.full((N,), T, dtype=torch.long) + target_lengths = torch.randint(S_min, S, (N,), dtype=torch.long) + + # modify from numpy.ndarray to torch.tensor + log_probs = torch.from_numpy(log_probs) + + ctc_loss = torch.nn.CTCLoss(blank= blk, zero_infinity=True, reduction=reduction_str) + + list1 = [ctc_loss, log_probs, targets, input_lengths, target_lengths] + + return list1 + + def cpu_op_exec(self, ctc_loss, log_probs, targets, input_lengths, target_lengths): + if log_probs.dtype == torch.float16: + log_probs = log_probs.to(torch.float32) + + neg_log_likelihood = ctc_loss(log_probs.log_softmax(1), targets, input_lengths, target_lengths) + + neg_log_likelihood = neg_log_likelihood.numpy() + + return neg_log_likelihood + + def npu_op_exec(self, ctc_loss, log_probs, targets, input_lengths, target_lengths): + log_probs = log_probs.npu() + targets = targets.npu() + input_lengths = input_lengths.npu() + target_lengths = target_lengths.npu() + + neg_log_likelihood = ctc_loss(log_probs.log_softmax(1), targets, input_lengths, target_lengths) + + if neg_log_likelihood.dtype == torch.float16: + neg_log_likelihood = neg_log_likelihood.to(torch.float32) + + neg_log_likelihood = neg_log_likelihood.cpu().numpy() + + return neg_log_likelihood + + def test_ctc_loss(self, device): + sizes_list = [[50, 20, 16, 30, 10], [26, 37, 256, 18, 10]] + 
para_reduction = ["sum", "mean", "none"] + dtype = [np.float32, np.float16] + blank = [0, 9] + shape_format = [ + [i, j, k, l] for i in sizes_list for j in dtype for k in para_reduction for l in blank + ] + + for item in shape_format: + getlist1 = self.generate_data(item) + + neg_log_likelihood_cpu = self.cpu_op_exec(getlist1[0], getlist1[1], getlist1[2], getlist1[3], getlist1[4]) + neg_log_likelihood_npu = self.npu_op_exec(getlist1[0], getlist1[1], getlist1[2], getlist1[3], getlist1[4]) + + self.assertRtolEqual(neg_log_likelihood_cpu, neg_log_likelihood_npu, 1e-3) + +instantiate_device_type_tests(TestCtcLoss, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_ctc_loss_backward.py b/test/test_network_ops/test_ctc_loss_backward.py new file mode 100644 index 00000000000..97df8d1b88c --- /dev/null +++ b/test/test_network_ops/test_ctc_loss_backward.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestCtcLossBackward(TestCase): + def generate_data(self, item): + T = item[0][0] + C = item[0][1] + N = item[0][2] + S = item[0][3] + S_min = item[0][4] + dtype = item[1] + reduction_str = item[2] + + log_probs = np.random.uniform(-10, 10, (T, N, C)).astype(dtype) + targets = torch.randint(1, C, (N, S), dtype = torch.long) + input_lengths = torch.full((N,), T, dtype=torch.long) + target_lengths = torch.randint(S_min, S, (N,), dtype=torch.long) + + # modify from numpy.ndarray to torch.tensor + log_probs = torch.from_numpy(log_probs) + + ctc_loss = torch.nn.CTCLoss(zero_infinity=True, reduction=reduction_str) + + list1 = [ctc_loss, log_probs, targets, input_lengths, target_lengths] + + return list1 + + def cpu_op_exec(self, ctc_loss, log_probs, targets, input_lengths, target_lengths): + if log_probs.dtype == torch.float16: + log_probs = log_probs.to(torch.float32) + + log_probs.requires_grad_(True) + log_probs.retain_grad() + + neg_log_likelihood = ctc_loss(log_probs.log_softmax(1), targets, input_lengths, target_lengths) + neg_log_likelihood.backward() + grad = log_probs.grad + + grad = grad.numpy() + + return grad + + def npu_op_exec(self, ctc_loss, log_probs, targets, input_lengths, target_lengths): + log_probs = copy.deepcopy(log_probs).npu() + targets = targets.npu() + log_probs.requires_grad_(True) + log_probs.retain_grad() + + neg_log_likelihood = ctc_loss(log_probs.log_softmax(1), targets, input_lengths.npu(), target_lengths.npu()) + neg_log_likelihood.backward() + grad = log_probs.grad + + if grad.dtype == torch.float16: + grad = grad.to(torch.float32) + + grad = grad.cpu().numpy() + + return grad + + def 
test_ctc_loss_backward(self, device):
+        sizes_list = [[50, 20, 16, 30, 10], [26, 37, 2560, 18, 10]]
+        para_reduction = ["sum", "mean"]
+        dtype = [np.float32]  # fp16 is skipped: insufficient accuracy when using fp16 data
+        shape_format = [
+            [i, j, k] for i in sizes_list for j in dtype for k in para_reduction
+        ]
+
+        for item in shape_format:
+            getlist1 = self.generate_data(item)
+
+            grad_cpu = self.cpu_op_exec(getlist1[0], getlist1[1], getlist1[2], getlist1[3], getlist1[4])
+            grad_npu = self.npu_op_exec(getlist1[0], getlist1[1], getlist1[2], getlist1[3], getlist1[4])
+
+            self.assertRtolEqual(grad_cpu, grad_npu, 1e-3)
+
+
+instantiate_device_type_tests(TestCtcLossBackward, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_where.py b/test/test_network_ops/test_where.py
new file mode 100644
index 00000000000..cc58c37f7f2
--- /dev/null
+++ b/test/test_network_ops/test_where.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestWhere(TestCase):
+    def cpu_op_exec(self, input1):
+        output = torch.where(input1)
+        output = list(output)
+        output[0] = output[0].numpy().astype(np.int32)
+        return output
+
+    def npu_op_exec(self, input1):
+        output = torch.where(input1)
+        output = list(output)
+        output[0] = output[0].to("cpu").numpy().astype(np.int32)
+        return output
+
+    def cpu_op_exec_condition(self, input1, ones):
+        output = torch.where(input1 > 0, input1, ones)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_condition(self, input1, ones):
+        output = torch.where(input1 > 0, input1, ones)
+        output = output.to("cpu").numpy()
+        return output
+
+    def cpu_op_exec_s(self, input1, ones):
+        output = torch._s_where(input1 > 0, input1, ones)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_s(self, input1, ones):
+        output = torch._s_where(input1 > 0, input1, ones)
+        output = output.to("cpu").numpy()
+        return output
+
+    def where_result(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item, -100, 100)
+            cpu_ones = torch.ones_like(cpu_input1)
+            npu_ones = cpu_ones.to("npu")
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+                cpu_ones = cpu_ones.to(torch.float32)
+
+            cpu_output = self.cpu_op_exec(cpu_input1)
+            npu_output = self.npu_op_exec(npu_input1)
+
+            cpu_output_cond = self.cpu_op_exec_condition(cpu_input1, cpu_ones)
+            npu_output_cond = self.npu_op_exec_condition(npu_input1, npu_ones)
+            cpu_output_cond = cpu_output_cond.astype(npu_output_cond.dtype)
+
+            cpu_output_s = self.cpu_op_exec_s(cpu_input1, cpu_ones)
+            npu_output_s = self.npu_op_exec_s(npu_input1, npu_ones)
+            cpu_output_s = cpu_output_s.astype(npu_output_s.dtype)
+
+ cpu_output[0] = cpu_output[0].astype(npu_output[0].dtype) + self.assertRtolEqual(cpu_output[0], npu_output[0]) + self.assertRtolEqual(cpu_output_cond, npu_output_cond) + self.assertRtolEqual(cpu_output_s, npu_output_s) + + def test_where_shape_format_fp32_1d(self, device): + format_list = [0, 3] + shape_format = [[np.float32, i, [18]] for i in format_list] + self.where_result(shape_format) + + def test_where_shape_format_fp32_2d(self, device): + format_list = [0] + shape_format = [[np.float32, i, [5, 256]] for i in format_list] + self.where_result(shape_format) + + def test_where_shape_format_fp32_3d(self, device): + format_list = [0] + shape_format = [[np.float32, i, [32, 3, 3]] for i in format_list] + self.where_result(shape_format) + + def test_where_shape_format_fp32_4d(self, device): + format_list = [0, 3] + shape_format = [[np.float32, i, [64, 112, 7, 7]] for i in format_list] + self.where_result(shape_format) + + def test_where_shape_format_fp16_1d(self, device): + format_list = [0, 3] + shape_format = [[np.float16, i, [18]] for i in format_list] + self.where_result(shape_format) + + def test_where_shape_format_fp16_2d(self, device): + format_list = [0, 3, 4, 29] + shape_format = [[np.float16, i, [5, 256]] for i in format_list] + self.where_result(shape_format) + + def test_where_shape_format_fp16_3d(self, device): + format_list = [0, 3, 4, 29] + shape_format = [[np.float16, i, [32, 3, 3]] for i in format_list] + self.where_result(shape_format) + + def test_where_shape_format_fp16_4d(self, device): + format_list = [0, 3, 4, 29] + shape_format = [[np.float16, i, [64, 112, 7, 7]] for i in format_list] + self.where_result(shape_format) + +instantiate_device_type_tests(TestWhere, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 1f08a26654f..7fe3eed7d5b 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -898,9 +898,6 @@ supported: - var_mean.names_dim - view_as - where.self - - where.ScalarSelf - - where.ScalarOther - - where.Scalar - where - _s_where - norm_except_dim diff --git a/torch_npu/csrc/aten/ops/ClampKernelNpu.cpp b/torch_npu/csrc/aten/ops/ClampKernelNpu.cpp new file mode 100644 index 00000000000..81a839bd134 --- /dev/null +++ b/torch_npu/csrc/aten/ops/ClampKernelNpu.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
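+
+// Implementation note: there is no dedicated clamp-min/clamp-max kernel on
+// the NPU, so every entry point below lowers to the CANN "ClipByValue"
+// operator; when only one bound is given, the other is filled with the
+// extreme representable value of the input dtype (INT_MAX/INT_MIN,
+// FLT_MAX/-FLT_MAX, or the float16 limits).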
+
+#include <climits>
+#include <cfloat>
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& clamp_min_out_npu_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    at::Scalar min) {
+  // Set max according to self.dtype()
+  at::Scalar max;
+  if (self.dtype() == at::kInt) {
+    max = INT_MAX;
+  } else if (self.dtype() == at::kFloat) {
+    max = FLT_MAX;
+  } else {
+    max = NPU_HALF_MAX;
+  }
+  OpCommand cmd;
+  cmd.Name("ClipByValue")
+      .Input(self)
+      .Input(min, self.scalar_type())
+      .Input(max, self.scalar_type())
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::clamp_min_out(
+    const at::Tensor& self,
+    at::Scalar min,
+    at::Tensor& result) {
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      self);
+  OpPipeWithDefinedOut pipe;
+  return pipe.CheckMemory({self}, {result})
+      .Func([&self, &min](at::Tensor& result){clamp_min_out_npu_nocheck(result, self, min);})
+      .Call(result);
+}
+
+at::Tensor& NPUNativeFunctions::clamp_max_out(const at::Tensor& self, at::Scalar max, at::Tensor& result) {
+  // Set min according to self.dtype()
+  at::Scalar min;
+  if (self.dtype() == at::kInt) {
+    min = INT_MIN;
+  } else if (self.dtype() == at::kFloat) {
+    min = -FLT_MAX;
+  } else {
+    min = NPU_HALF_MIN;
+  }
+  OpCommand cmd;
+  cmd.Name("ClipByValue")
+      .Input(self)
+      .Input(min, self.scalar_type())
+      .Input(max, self.scalar_type())
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& clamp_out_npu_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    c10::optional<at::Scalar> min,
+    c10::optional<at::Scalar> max) {
+  TORCH_CHECK(min.has_value() || max.has_value(), "torch.clamp: At least one of 'min' or 'max' must not be None");
+  if (!min.has_value()) {
+    at::Scalar maxScalar = max.value();
+    NPUNativeFunctions::clamp_max_out(self, maxScalar, result);
+  } else if (!max.has_value()) {
+    at::Scalar minScalar = min.value();
+    NPUNativeFunctions::clamp_min_out(self, minScalar, result);
+  } else {
+    OpCommand cmd;
+    cmd.Name("ClipByValue")
+        .Input(self)
+        .Input(min.value(), self.scalar_type())
+        .Input(max.value(), self.scalar_type())
+        .Output(result)
+        .Run();
+  }
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::clamp_out(
+    const at::Tensor& self,
+    c10::optional<at::Scalar> min,
+    c10::optional<at::Scalar> max,
+    at::Tensor& result) {
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      self);
+  OpPipeWithDefinedOut pipe;
+  return pipe.CheckMemory({self}, {result})
+      .Func([&self, &min, &max](at::Tensor& result){clamp_out_npu_nocheck(result, self, min, max);})
+      .Call(result);
+}
+
+at::Tensor NPUNativeFunctions::clamp_min(const at::Tensor& self, at::Scalar min) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  clamp_min_out_npu_nocheck(result, self, min);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::clamp_min_(at::Tensor& self, at::Scalar min) {
+  NPUNativeFunctions::clamp_min_out(self, min, self);
+  return self;
+}
+
+at::Tensor NPUNativeFunctions::clamp_max(const at::Tensor& self, at::Scalar max) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  NPUNativeFunctions::clamp_max_out(self, max, result);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::clamp_max_(at::Tensor& self, at::Scalar max) {
+  OpPreparation::CheckMemory({self}, {self});
+  if (!NpuUtils::check_match(&self)) {
+    at::Tensor contiguousSelf = NpuUtils::format_contiguous(self);
+    at::Tensor result = NPUNativeFunctions::clamp_max_out(contiguousSelf, max, contiguousSelf);
+    NpuUtils::format_fresh_view(self, result);
+  } else {
+    NPUNativeFunctions::clamp_max_out(self, max, self);
+  }
+  return self;
+}
+
+at::Tensor NPUNativeFunctions::clamp(
+    const at::Tensor& self,
+    c10::optional<at::Scalar> min,
+    c10::optional<at::Scalar> max) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  clamp_out_npu_nocheck(result, self, min, max);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::clamp_(at::Tensor& self, c10::optional<at::Scalar> min, c10::optional<at::Scalar> max) {
+  NPUNativeFunctions::clamp_out(self, min, max, self);
+  return self;
+}
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/CtcLossBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/CtcLossBackwardKernelNpu.cpp
new file mode 100644
index 00000000000..df2b0b3d263
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/CtcLossBackwardKernelNpu.cpp
@@ -0,0 +1,87 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::_ctc_loss_backward(
+    const at::Tensor& gradOut,
+    const at::Tensor& logProbs,
+    const at::Tensor& targets,
+    at::IntArrayRef inputLengths,
+    at::IntArrayRef targetLengths,
+    const at::Tensor& negLogLikelihood,
+    const at::Tensor& logAlpha,
+    int64_t blank,
+    bool zeroInfinity) {
+  // The NPU kernel computes in float; cast any fp16 inputs up front
+  at::Tensor gradOutNeed = gradOut;
+  if (gradOut.scalar_type() == at::ScalarType::Half) {
+    gradOutNeed = gradOutNeed.to(at::ScalarType::Float);
+  }
+
+  at::Tensor logProbsNeed = logProbs;
+  if (logProbs.scalar_type() == at::ScalarType::Half) {
+    logProbsNeed = logProbsNeed.to(at::ScalarType::Float);
+  }
+
+  at::Tensor negLogLikelihoodNeed = negLogLikelihood;
+  if (negLogLikelihood.scalar_type() == at::ScalarType::Half) {
+    negLogLikelihoodNeed = negLogLikelihoodNeed.to(at::ScalarType::Float);
+  }
+
+  at::Tensor logAlphaNeed = logAlpha;
+  if (logAlpha.scalar_type() == at::ScalarType::Half) {
+    logAlphaNeed = logAlphaNeed.to(at::ScalarType::Float);
+  }
+
+  // AiCore supports only int32 targets
+  at::Tensor targetsCast = targets;
+  if (targets.scalar_type() == at::ScalarType::Long) {
+    targetsCast = targetsCast.to(at::ScalarType::Int);
+  }
+
+  auto inputLengthsTensor = at::tensor(inputLengths, targetsCast.options().dtype(at::kInt));
+  auto targetLengthsTensor = at::tensor(targetLengths, targetsCast.options().dtype(at::kInt));
+
+  auto outputSize = input_same_output_size(logProbs);
+
+  // construct the output tensor of the NPU
+  at::Tensor grad = OpPreparation::ApplyTensor(logProbsNeed, outputSize);
+  // calculate the output result of the NPU
+  OpCommand cmd;
+  cmd.Name("CTCLossV2Grad")
+      .Input(gradOutNeed)
+      .Input(logProbsNeed)
+      .Input(targetsCast)
+      .Input(inputLengthsTensor)
+      .Input(targetLengthsTensor)
+      .Input(negLogLikelihoodNeed)
+      .Input(logAlphaNeed)
+      .Output(grad)
+      .Attr("blank", blank)
+      .Attr("zero_infinity", zeroInfinity)
+      .Run();
+
+  if (gradOut.scalar_type() == at::ScalarType::Half) {
+    grad = grad.to(at::ScalarType::Half);
+  }
+
+  return grad;
+}
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp b/torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp
new file mode 100644
index 00000000000..b0df7b80c40
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp
@@ -0,0 +1,140 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::_ctc_loss(
+    const at::Tensor& logProbs,
+    const at::Tensor& targets,
+    at::IntArrayRef inputLengths,
+    at::IntArrayRef targetLengths,
+    int64_t blank,
+    bool zeroInfinity) {
+  at::Tensor logProbsNeed = logProbs;
+  if (logProbs.scalar_type() == at::ScalarType::Half) {
+    logProbsNeed = logProbsNeed.to(at::ScalarType::Float);
+  }
+
+  // AiCore supports only the int type
+  at::Tensor targetsCast = targets;
+  if (targets.scalar_type() == at::ScalarType::Long) {
+    targetsCast = targetsCast.to(at::ScalarType::Int);
+  }
+
+  // IntArrayRef to Tensor
+  auto inputLengthsTensor = at::tensor(inputLengths, targetsCast.options());
+  auto targetLengthsTensor = at::tensor(targetLengths, targetsCast.options());
+
+  // calculate the output size
+  auto outputSizes = ctc_loss_npu_output_size(logProbs, targetsCast, targetLengths);
+
+  // construct the output tensor of the NPU
+  at::Tensor negLogLikelihood = OpPreparation::ApplyTensorWithSizes(
+      std::get<0>(outputSizes),
+      logProbsNeed.options());
+
+  at::Tensor logAlpha = OpPreparation::ApplyTensorWithSizes(
+      std::get<1>(outputSizes),
+      logProbsNeed.options());
+
+  // calculate the output result of the NPU
+  OpCommand cmd;
+  cmd.Name("CTCLossV2")
+      .Input(logProbsNeed)
+      .Input(targetsCast)
+      .Input(inputLengthsTensor)
+      .Input(targetLengthsTensor)
+      .Output(negLogLikelihood)
+      .Output(logAlpha)
+      .Attr("blank", blank)
+      .Attr("zero_infinity", zeroInfinity)
+      .Run();
+
+  if (logProbs.scalar_type() == at::ScalarType::Half) {
+    negLogLikelihood = NPUNativeFunctions::npu_dtype_cast(negLogLikelihood, at::ScalarType::Half);
+    logAlpha = logAlpha.to(at::ScalarType::Half);
+  }
+
+  return std::tuple<at::Tensor, at::Tensor>(negLogLikelihood, logAlpha);
+}
+
+at::Tensor NPUNativeFunctions::ctc_loss(
+    const at::Tensor& logProbs,
+    const at::Tensor& targets,
+    at::IntArrayRef inputLengths,
+    at::IntArrayRef targetLengths,
+    int64_t blank,
+    int64_t reduction,
+    bool zeroInfinity) {
+  at::Tensor res = std::get<0>(at::_ctc_loss(
+      logProbs,
+      targets,
+      inputLengths,
+      targetLengths,
+      blank,
+      zeroInfinity));
+
+  if (zeroInfinity) {
+    res = at::where(
+        res == at::Scalar(std::numeric_limits<double>::infinity()),
+        at::zeros({}, res.options()),
+        res);
+  }
+
+  if (reduction == at::Reduction::Mean) {
+    // normalize each sample's loss by its target length before averaging
+    std::vector<int64_t> targetLengthsVector = targetLengths.vec();
+
+    auto targetLengthsTensor = CalcuOpUtil::copy_tensor_host_to_device(
+        at::from_blob(targetLengthsVector.data(), {static_cast<int64_t>(targetLengthsVector.size())}, at::kLong)).clamp_min(1);
+
+    at::Tensor targetLengthsTensor_ = targetLengthsTensor.to(res.dtype());
+
+    return (res / targetLengthsTensor_).mean();
+  } else if (reduction == at::Reduction::Sum) {
+    return res.sum();
+  }
+
+  return res;
+}
+
+at::Tensor NPUNativeFunctions::ctc_loss(
+    const at::Tensor& logProbs,
+    const at::Tensor& targets,
+    const at::Tensor& inputLengths,
+    const at::Tensor& targetLengths,
+    int64_t blank,
+    int64_t reduction,
+    bool zeroInfinity) {
+  TORCH_CHECK(isIntegralType(inputLengths.scalar_type(), false), "input_lengths must be integral");
+  TORCH_CHECK(isIntegralType(targetLengths.scalar_type(), false), "target_lengths must be integral");
+
+  at::Tensor inputLengthsTensor = inputLengths.to(at::Device(at::kCPU), at::kLong).contiguous();
+  at::Tensor targetLengthsTensor = targetLengths.to(at::Device(at::kCPU), at::kLong).contiguous();
+
+  at::IntArrayRef inputLengthsList(inputLengthsTensor.data_ptr<int64_t>(), inputLengthsTensor.numel());
+  at::IntArrayRef targetLengthsList(targetLengthsTensor.data_ptr<int64_t>(), targetLengthsTensor.numel());
+
+  return at::ctc_loss(logProbs, targets, inputLengthsList, targetLengthsList, blank, reduction, zeroInfinity);
+}
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/WhereKernelNpu.cpp b/torch_npu/csrc/aten/ops/WhereKernelNpu.cpp
new file mode 100644
index 00000000000..149b23b25c2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/WhereKernelNpu.cpp
@@ -0,0 +1,123 @@
+// Copyright (c) 2020, Huawei Technologies.
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_expand_outplace(
+    const at::Tensor &to_expand1,
+    const at::Tensor &to_expand2,
+    const at::Tensor &to_expand3,
+    const char *api_name) {
+  for (auto& t : {to_expand1, to_expand2, to_expand3}) {
+    if (!t.defined()) {
+      AT_ERROR(api_name, "(...) called with an undefined Tensor");
+    }
+  }
+
+  if (to_expand1.sizes().equals(to_expand2.sizes()) && to_expand1.sizes().equals(to_expand3.sizes())) {
+    return std::make_tuple(to_expand1, to_expand2, to_expand3);
+  }
+
+  auto expanded_size12 = broadcast_ops_npu_output_size(to_expand1, to_expand2);
+  auto expanded_size = broadcast_ops_npu_output_size(expanded_size12, to_expand3.sizes());
+
+  return std::make_tuple(
+      to_expand1.expand(expanded_size, true),
+      to_expand2.expand(expanded_size, true),
+      to_expand3.expand(expanded_size, true));
+}
+
+at::Tensor NPUNativeFunctions::_s_where(
+    const at::Tensor& condition,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+
+  OpCommand cmd;
+  cmd.Name("Select")
+      .Input(condition)
+      .Input(self)
+      .Input(other)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::where(
+    const at::Tensor& condition,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
+      "expected condition, x and y to be on the same device, but condition is on ",
+      condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
+      " respectively");
+  if (condition.scalar_type() != at::ScalarType::Byte && condition.scalar_type() != at::ScalarType::Bool) {
+    AT_ERROR("Expected condition to have ScalarType Byte, but got ScalarType ",
+        toString(condition.scalar_type()));
+  }
+  at::Tensor b_condition, b_self, b_other;
+  std::tie(b_condition, b_self, b_other) = npu_expand_outplace(condition, self, other, "where_npu");
+  return at::_s_where(b_condition, b_self, b_other);
+}
+
+c10::SmallVector<int64_t, SIZE> where_npu_output_size(const at::Tensor& condition) {
+  int64_t dim = condition.dim();
+  at::Tensor boolSelf = NPUNativeFunctions::npu_dtype_cast(condition, at::ScalarType::Bool);
+  at::Tensor intSelf = NPUNativeFunctions::npu_dtype_cast(boolSelf, at::ScalarType::Int);
+  at::Tensor countNonzeroSelf = at::sum(intSelf, at::ScalarType::Int);
+  int64_t nonzeroNum = countNonzeroSelf.item().toInt();
+  c10::SmallVector<int64_t, SIZE> outputSize = {nonzeroNum, dim};
+  return outputSize;
+}
+
+std::vector<at::Tensor> NPUNativeFunctions::where(const at::Tensor& condition) {
+  at::Tensor formatCastOfCondition = condition;
+  if (condition.storage().unsafeGetStorageImpl()->npu_desc_.npu_format_ !=
+      ACL_FORMAT_ND) {
+    formatCastOfCondition = NPUNativeFunctions::npu_format_cast(formatCastOfCondition, ACL_FORMAT_ND);
+  }
+  if (condition.scalar_type() == at::ScalarType::Half) {
+    formatCastOfCondition = NPUNativeFunctions::npu_dtype_cast(formatCastOfCondition, at::ScalarType::Float);
+  }
+
+  // calculate the output size
+  auto outputSize = where_npu_output_size(formatCastOfCondition);
+
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(
+      outputSize, formatCastOfCondition.options().dtype(at::kLong), ACL_FORMAT_ND);
+
+  // NonZero yields an (nonzeroNum, dim) matrix of coordinates; transpose it
+  // and split into one index tensor per dimension, matching torch.where(condition)
+  OpCommand cmd;
+  cmd.Name("NonZero")
+      .Input(formatCastOfCondition)
+      .Output(result)
+      .Run();
+  result = result.transpose(1, 0);
+  std::vector<at::Tensor> chunkResult = result.chunk(result.size(0), 0);
+  std::vector<at::Tensor> squeezeResult;
+  for (int64_t i = 0; i < chunkResult.size(); i++) {
+    squeezeResult.push_back(chunkResult[i].squeeze(0));
+  }
+
+  return squeezeResult;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee