From c24bd882b8e3f2611d3cca4887191e82bccad58c Mon Sep 17 00:00:00 2001 From: pipihugh Date: Tue, 22 Mar 2022 16:37:26 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dgrid=5Fsampler=5F2d?= =?UTF-8?q?=E7=AE=97=E5=AD=90=E7=B2=BE=E5=BA=A6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_grid_sampler_2d.py | 3 +-- torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp | 8 ++------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/test/test_network_ops/test_grid_sampler_2d.py b/test/test_network_ops/test_grid_sampler_2d.py index 9ea285d514a..eba45367282 100644 --- a/test/test_network_ops/test_grid_sampler_2d.py +++ b/test/test_network_ops/test_grid_sampler_2d.py @@ -65,12 +65,11 @@ class TestGridSampler2D(TestCase): ] for item in shape_format: cpu_input, npu_input = create_common_tensor(item[0], 1, 100) - cpu_grid, npu_grid = create_common_tensor(item[1], -1, 1) + cpu_grid, npu_grid = create_common_tensor(item[1], -3, 3) cpu_output = self.cpu_op_fp16_exec(cpu_input, cpu_grid) npu_output = self.npu_op_exec(npu_input, npu_grid) self.assertRtolEqual(cpu_output, npu_output) - if __name__ == "__main__": run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp index 674479dbe23..38e32044c88 100644 --- a/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp @@ -49,16 +49,13 @@ at::Tensor NPUNativeFunctions::grid_sampler_2d( at::Tensor result = OpPreparation::ApplyTensorWithFormat(dtypeCastOfSelf, outputSize, ACL_FORMAT_ND); - c10::SmallVectorinterMode = {"bilinear", "nearest", "bicubic"}; - c10::SmallVectorpaddingMode = {"zeros", "border", "reflection"}; - OpCommand cmd; cmd.Name("GridSampler2D") .Input(dtypeCastOfSelf) .Input(dtypeCastOfGrid) .Output(result) - .Attr("interpolation_mode", interMode[interpolation_mode]) - .Attr("padding_mode", paddingMode[padding_mode]) + .Attr("interpolation_mode", interpolation_mode) + .Attr("padding_mode", padding_mode) .Attr("align_corners", align_corners) .Run(); @@ -68,6 +65,5 @@ at::Tensor NPUNativeFunctions::grid_sampler_2d( } return result; } - } // namespace native } // namespace at_npu -- Gitee From e87e5f97c411b49e770fbcaf4206ab002fa1263b Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Tue, 22 Mar 2022 17:15:56 +0800 Subject: [PATCH 2/8] =?UTF-8?q?TensorOptions=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../csrc/aten/common/TensorFactories.cpp | 70 ++++++++----------- torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp | 10 ++- torch_npu/csrc/aten/ops/FullKernelNpu.cpp | 10 ++- torch_npu/csrc/aten/ops/RangeKernelNpu.cpp | 20 +++--- 4 files changed, 44 insertions(+), 66 deletions(-) diff --git a/torch_npu/csrc/aten/common/TensorFactories.cpp b/torch_npu/csrc/aten/common/TensorFactories.cpp index 75601864502..b71f65b9c41 100644 --- a/torch_npu/csrc/aten/common/TensorFactories.cpp +++ b/torch_npu/csrc/aten/common/TensorFactories.cpp @@ -309,12 +309,10 @@ namespace at_npu c10::optional optional_memory_format) { - c10::TensorOptions options; - auto device = device_or_default(device_opt); - options = options.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions options = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); return at_npu::native::empty_like_npu(self, options, optional_memory_format); } @@ -419,12 +417,10 @@ namespace at_npu int64_t dst_format) { caffe2::TypeMeta dtype = c10::scalarTypeToTypeMeta(dtype_or_default(dtype_opt)); - c10::TensorOptions options; - auto device = device_or_default(device_opt); - options = options.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions options = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); at::Tensor result = OpPreparation::ApplyTensorWithFormat(size, options, dst_format); if (names.has_value()) { @@ -514,12 +510,10 @@ namespace at_npu c10::optional pin_memory_opt) { - c10::TensorOptions options; - auto device = device_or_default(device_opt); - options = options.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions options = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); window_function_checks("blackman_window", options, window_length); if (window_length == 0) @@ -558,12 +552,10 @@ namespace at_npu c10::optional pin_memory_opt) { - c10::TensorOptions options; - auto device = device_or_default(device_opt); - options = options.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions options = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); window_function_checks("bartlett_window", options, window_length); if (window_length == 0) @@ -604,12 +596,10 @@ namespace at_npu c10::optional pin_memory_opt) { - c10::TensorOptions options; - auto device = device_or_default(device_opt); - options = options.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions options = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); window_function_checks("hann_window", options, window_length); return at::hamming_window(window_length, periodic, 0.5, 0.5, options); @@ -636,12 +626,10 @@ namespace at_npu c10::optional pin_memory_opt) { - c10::TensorOptions options; - auto device = device_or_default(device_opt); - options = options.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions options = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); window_function_checks("hamming_window", options, window_length); if (window_length == 0) @@ -753,12 +741,10 @@ namespace at_npu c10::optional device_opt, c10::optional pin_memory_opt) { - c10::TensorOptions options; - auto device = device_or_default(device_opt); - options = options.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions options = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); TORCH_CHECK( options.layout() != at::kSparse, "full(...) is not implemented for sparse layout"); diff --git a/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp index 494594327d4..31e6acafc9a 100644 --- a/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp @@ -58,12 +58,10 @@ at::Tensor NPUNativeFunctions::arange( c10::optional device_opt, c10::optional pin_memory_opt) { - auto device = device_or_default(device_opt); - at::TensorOptions option; - option = option.dtype(dtype_opt) - .layout(layout_opt) - .device(device) - .pinned_memory(pin_memory_opt); + c10::TensorOptions option = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); float start_value = CalcuOpUtil::get_scalar_float_value(start); float end_value = CalcuOpUtil::get_scalar_float_value(end); diff --git a/torch_npu/csrc/aten/ops/FullKernelNpu.cpp b/torch_npu/csrc/aten/ops/FullKernelNpu.cpp index 8f644cbc0f9..f2fe8936c32 100644 --- a/torch_npu/csrc/aten/ops/FullKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/FullKernelNpu.cpp @@ -37,12 +37,10 @@ at::Tensor NPUNativeFunctions::full( c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { - c10::TensorOptions option; - auto device = device_or_default(device_opt); - option = option.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions option = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); at::Tensor result = OpPreparation::ApplyTensorWithSizes(size, option); return result.fill_(fill_value); } diff --git a/torch_npu/csrc/aten/ops/RangeKernelNpu.cpp b/torch_npu/csrc/aten/ops/RangeKernelNpu.cpp index df65f925d39..c8ea4f2e408 100644 --- a/torch_npu/csrc/aten/ops/RangeKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/RangeKernelNpu.cpp @@ -55,12 +55,10 @@ at::Tensor NPUNativeFunctions::range( c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { - auto device = device_or_default(device_opt); - c10::TensorOptions option; - option = option.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions option = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); return at::range(start, end, 1, option); } @@ -72,12 +70,10 @@ at::Tensor NPUNativeFunctions::range( c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { - auto device = device_or_default(device_opt); - c10::TensorOptions option; - option = option.dtype(dtype_opt) - .device(device) - .layout(layout_opt) - .pinned_memory(pin_memory_opt); + c10::TensorOptions option = c10::TensorOptions().dtype(dtype_opt) + .device(device_opt) + .layout(layout_opt) + .pinned_memory(pin_memory_opt); float start_value = CalcuOpUtil::get_scalar_float_value(start); float end_value = CalcuOpUtil::get_scalar_float_value(end); -- Gitee From 5993205cbaeeff196c2df62faf689c66ba5f0533 Mon Sep 17 00:00:00 2001 From: "jiangrongqiang@huawei.com" Date: Tue, 22 Mar 2022 17:41:41 +0800 Subject: [PATCH 3/8] delete release queue --- torch_npu/csrc/framework/OpParamMaker.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index 304138f6642..38c0e6eb619 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -375,7 +375,16 @@ namespace at_npu } } - void ReleaseFunc(void* ptr, c10::npu::ReleaseQueue& releaseQueue) + void ReleaseFunc(void* ptr, c10::npu::ReleaseQueue& releaseQueue) { + auto queueParam = static_cast(ptr); + auto type = queueParam->paramType; + if (type == COMPILE_AND_EXECUTE) { + auto cur_paras = static_cast(queueParam->paramVal); + cur_paras->Release(); + } + } + + void ReleaseFunc_(void* ptr, c10::npu::ReleaseQueue& releaseQueue) { releaseQueue.PushToReleaseQueue(ptr); } -- Gitee From b40c78fdc21dc7fa03055701bff728a8cca49494 Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Tue, 22 Mar 2022 17:23:35 +0800 Subject: [PATCH 4/8] modified take Operator --- torch_npu/csrc/aten/ops/TakeKernelNpu.cpp | 65 +++++++++++------------ 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp b/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp index f00ecc6d046..f088f621d8d 100644 --- a/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp @@ -19,42 +19,40 @@ namespace at_npu { namespace native { -c10::SmallVector take_npu_input( - const c10::SmallVector& inputTensor) { - at::Tensor contiguousTensor; - c10::SmallVector inputs; - - for (int i = 0; i < inputTensor.size(); i++) { - if (i == 0) { - int64_t input_size = 1; - at::Tensor input_tensor = inputTensor[i].reshape(-1); - contiguousTensor = NpuUtils::format_contiguous(input_tensor); - } else { - contiguousTensor = NpuUtils::format_contiguous(inputTensor[i]); - } - inputs.emplace_back(NPUTensorDesc(contiguousTensor)); - } - return inputs; -} - -c10::SmallVector take_npu_output(const c10::SmallVector& outputTensor) { - return CalcuOpUtil::create_npu_output_tensor_desc(outputTensor); -} +at::Tensor& take_out_nocheck(const at::Tensor& self, const at::Tensor& index, at::Tensor& result) { + at::Tensor input_tensor = self.reshape(-1); + at::Tensor contiguousSelf = NpuUtils::format_contiguous(input_tensor); + at::Tensor contiguousIndex = NpuUtils::format_contiguous(index); -c10::SmallVector take_npu_attr(const at::Tensor& self) { - NPUAttrDesc npuAttrValidateIndices = NPUAttrDesc("validate_indices", false); - c10::SmallVector attrs = {npuAttrValidateIndices}; - return attrs; + OpCommand cmd; + cmd.Name("Gather") + .Input(contiguousSelf) + .Input(contiguousIndex) + .Output(result) + .Attr("validate_indices", false) + .Run(); + + return result; } at::Tensor& NPUNativeFunctions::take_out(const at::Tensor& self, const at::Tensor& index, at::Tensor& result) { - // constructs the input and output NPUTensorDesc - auto inputs = take_npu_input({self,index}); - auto outputs = take_npu_output({result}); - // constructs the attr of the NPUAttrDesc - auto attrs = take_npu_attr(self); - // executing the NPU operator - CalcuOpUtil::execute_npu_operate("Gather", inputs, outputs, attrs); + // calculate the output size + auto outputSize = input_same_output_size(index); + + OpPreparation::CheckOut( + {self, index}, + result, + self, + outputSize); + + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + take_out_nocheck(self, index, contiguousResult); + NpuUtils::format_fresh_view(result, contiguousResult); + } else { + take_out_nocheck(self, index, result); + } + return result; } @@ -67,7 +65,8 @@ at::Tensor NPUNativeFunctions::take(const at::Tensor& self, const at::Tensor& in outputSize, self.options()); - NPUNativeFunctions::take_out(self, index, result); + take_out_nocheck(self, index, result); + return result; } } // namespace native -- Gitee From dd573dc9befd6117302cb536bd024496eeac1589 Mon Sep 17 00:00:00 2001 From: "jiangrongqiang@huawei.com" Date: Tue, 22 Mar 2022 18:15:05 +0800 Subject: [PATCH 5/8] delete release queue --- torch_npu/csrc/framework/OpParamMaker.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index 38c0e6eb619..967b9285991 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -376,9 +376,9 @@ namespace at_npu } void ReleaseFunc(void* ptr, c10::npu::ReleaseQueue& releaseQueue) { - auto queueParam = static_cast(ptr); + auto queueParam = static_cast(ptr); auto type = queueParam->paramType; - if (type == COMPILE_AND_EXECUTE) { + if (type == c10::npu::queue::COMPILE_AND_EXECUTE) { auto cur_paras = static_cast(queueParam->paramVal); cur_paras->Release(); } -- Gitee From cb17f984b1da9f3ddbe062a823b1abac3bbba8a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=A4=E5=AE=89=E5=8D=87?= Date: Tue, 22 Mar 2022 20:21:34 +0800 Subject: [PATCH 6/8] fix h2d error. --- torch_npu/csrc/aten/common/CopyKernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index 13191fa1335..03dd5cc7193 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -249,7 +249,7 @@ void copy_d2h_baseformat(at::Tensor& dst, const at::Tensor& src, bool non_blocki void copy_h2d(at::Tensor& self, const at::Tensor& src, bool non_blocking) { if (!FormatHelper::IsBaseFormatType(self)) { - at::Tensor dst = OpPreparation::ApplyTensor(self); + at::Tensor dst = OpPreparation::ApplyTensorWithSizes(self.sizes(), self.options()); copy_h2d_baseformat(dst, src, non_blocking, true); NPUNativeFunctions::npu_format_cast_(self, dst); return; -- Gitee From 75794374dce9361721709cfd8da2fdf686247c82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=A4=E5=AE=89=E5=8D=87?= Date: Tue, 22 Mar 2022 20:34:45 +0800 Subject: [PATCH 7/8] Add npu_confusion_transpose. --- torch_npu/utils/tensor_methods.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch_npu/utils/tensor_methods.py b/torch_npu/utils/tensor_methods.py index 0839be9d8e3..0bdbe7e6854 100644 --- a/torch_npu/utils/tensor_methods.py +++ b/torch_npu/utils/tensor_methods.py @@ -52,6 +52,10 @@ def one_(self): warnings.warn(warning_str.format("one_")) return torch_npu.one_(self) +def npu_confusion_transpose(self, perm, shape, transpose_first): + warnings.warn(warning_str.format("npu_confusion_transpose")) + return torch_npu.npu_confusion_transpose(self, perm, shape, transpose_first) + def add_tensor_methods(): torch.Tensor.npu_format_cast_ = npu_format_cast_ @@ -59,4 +63,5 @@ def add_tensor_methods(): torch.Tensor.npu_dtype_cast = npu_dtype_cast torch.Tensor.npu_dtype_cast_ = npu_dtype_cast_ torch.Tensor.copy_memory_ = copy_memory_ - torch.Tensor.one_ = one_ \ No newline at end of file + torch.Tensor.one_ = one_ + torch.Tensor.npu_confusion_transpose = npu_confusion_transpose \ No newline at end of file -- Gitee From e33763bce3bc4c95915b11e245398c409ee7a267 Mon Sep 17 00:00:00 2001 From: pipihugh Date: Tue, 22 Mar 2022 14:01:43 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dscatter=E7=AE=97=E5=AD=90?= =?UTF-8?q?src=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B=E4=B8=8D=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_scatter.py | 16 +++++++++++++++- torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp | 17 ++++++----------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/test/test_network_ops/test_scatter.py b/test/test_network_ops/test_scatter.py index 1936d3761b2..f377284bc0d 100644 --- a/test/test_network_ops/test_scatter.py +++ b/test/test_network_ops/test_scatter.py @@ -51,7 +51,7 @@ class TestScatter(TestCase): input1 = input1.cpu() return input1.numpy() - def test_scatter_shape_format(self, device="npu"): + def test_scatter_shape_format(self): shape_format = [ [0, [3, 5], [np.float32, 0, [2, 5]]], [0, [3, 5], [np.float32, 3, [2, 5]]], @@ -86,6 +86,20 @@ class TestScatter(TestCase): npu_output = self.npu_op_exec_inplace(item[1], item[0], index, 1.23, False) self.assertRtolEqual(cpu_output, npu_output) + def test_scatter_debug(self): + a = np.random.uniform(-2,2,(31, 43, 41, 97)).astype(np.float16) + b = np.random.uniform(0,30,(31, 43, 41, 97)).astype(np.int32) + c = np.random.uniform(-2,2,(31, 43, 41, 97)).astype(np.float16) + ca = torch.from_numpy(a) + cb = torch.from_numpy(b).long() + cc = torch.from_numpy(c) + na = ca.npu() + nb = cb.npu() + nc = cc.npu() + dim = 0 + cpu_output = torch.scatter(ca, dim, cb, cc) + npu_output = torch.scatter(na, dim, nb, nc) + self.assertRtolEqual(cpu_output, npu_output.cpu()) if __name__ == "__main__": run_tests() diff --git a/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp b/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp index 0fad28bcd3e..8a32d4da902 100644 --- a/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp @@ -41,7 +41,7 @@ at::Tensor& scatter_npu_src_impl( at::Tensor& self, int64_t dim, const at::Tensor& index_ex, - const at::Tensor& src) { + const at::Tensor& src_ex) { at::ScalarType selfType = self.scalar_type(); if (selfType == at::ScalarType::Half) { self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float); @@ -52,6 +52,11 @@ at::Tensor& scatter_npu_src_impl( index = NPUNativeFunctions::npu_dtype_cast(index, at::ScalarType::Float); } + at::Tensor src(src_ex); + if (src.scalar_type() != self.scalar_type()) { + src = NPUNativeFunctions::npu_dtype_cast(src, self.scalar_type()); + } + if (!NpuUtils::check_match(&self)) { at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); @@ -64,7 +69,6 @@ at::Tensor& scatter_npu_src_impl( if(self.scalar_type() != selfType){ self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Half); } - return self; } @@ -74,10 +78,6 @@ at::Tensor& NPUNativeFunctions::scatter_( const at::Tensor& index_ex, const at::Tensor& src_ex) { at::Tensor src(src_ex); - if (src.scalar_type() != self.scalar_type()) { - src = NPUNativeFunctions::npu_dtype_cast(src, self.scalar_type()); - } - scatter_npu_src_impl(self, dim, index_ex, src); return self; } @@ -90,11 +90,6 @@ at::Tensor& NPUNativeFunctions::scatter_( at::Tensor srcTensor = scalar_to_tensor(src).to(at::ScalarType::Float); srcTensor = CalcuOpUtil::copy_tensor_host_to_device(srcTensor); at::Tensor srcTensor_broadcast = NPUNativeFunctions::npu_broadcast(srcTensor, array_to_small_vector(index_ex.sizes())); - - if (srcTensor_broadcast.scalar_type() != self.scalar_type()) { - srcTensor_broadcast = NPUNativeFunctions::npu_dtype_cast(srcTensor_broadcast, self.scalar_type()); - } - scatter_npu_src_impl(self, dim, index_ex, srcTensor_broadcast); return self; } -- Gitee