From e78d65ac85e4f7b92fa675d87f5708d0c721dc13 Mon Sep 17 00:00:00 2001 From: fengjiawei1 Date: Mon, 15 Jan 2024 17:01:51 +0800 Subject: [PATCH] meta_Ca --- torch_npu/meta/meta_registrations.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index 8cf3af22891..4f0bb2be106 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -242,3 +242,22 @@ def npu_trans_quant_param_meta(scale, offset=None): dim_offset = offset.size(0) dim_max = max(dim_max, dim_offset) return scale.new_empty((dim_max), dtype=torch.int64) + +@impl(m, "npu_max_pool2d_with_indices") +def npu_max_pool2d_with_indices_meta(self, kernel_size, stride, padding, dilation, ceil_mode): + height = self.size(-2) + width = self.size(-1) + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + stride_h = stride[0] + stride_w = stride[1] + padding_h = padding[0] + padding_w = padding[1] + dilation_h = dilation[0] + dilation_w = dilation[1] + + out_height = (height + 2 - dilation_h * (kernel_h - 1) - 1) / stride_h + 1 + out_width = (width + 2 - dilation_w * (kernel_w - 1) - 1) / stride_w + 1 + shape = (1, self.size(-3), int(out_height), int(out_width)) + + return (self.new_empty(shape), self.new_empty(shape, dtype = torch.int32)) \ No newline at end of file -- Gitee