From dbbea8f590f56dedf34f7f8b065269f88848ba66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=87=A7=E5=BA=86=E9=A6=99?= Date: Tue, 15 Oct 2024 10:53:07 +0800 Subject: [PATCH] fix dlinknet Conv3dTranspose error --- research/cv/dlinknet/src/dinknet.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/research/cv/dlinknet/src/dinknet.py b/research/cv/dlinknet/src/dinknet.py index 0fd2dcfad..9c87b8c13 100644 --- a/research/cv/dlinknet/src/dinknet.py +++ b/research/cv/dlinknet/src/dinknet.py @@ -82,16 +82,12 @@ class DecoderBlock(nn.Cell): weight_init=Tensor(conv_weight_init((in_channels // 4, in_channels, 1, 1)))) self.norm1 = nn.BatchNorm2d(in_channels // 4) self.relu1 = relu - self.expand_dims = P.ExpandDims() - self.deconv3 = nn.Conv3dTranspose(in_channels // 4, in_channels // 4, - kernel_size=(1, 3, 3), - stride=(1, 2, 2), - padding=(0, 0, 1, 1, 1, 1), - output_padding=(0, 1, 1), - pad_mode='pad', - has_bias=True, - weight_init=Tensor(conv_weight_init(( - in_channels // 4, in_channels // 4, 1, 3, 3)))) + self.deconv2 = nn.Conv2dTranspose(in_channels // 4, in_channels // 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + pad_mode='pad') self.norm2 = nn.BatchNorm2d(in_channels // 4) self.relu2 = relu self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1, has_bias=True, @@ -103,9 +99,7 @@ class DecoderBlock(nn.Cell): x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) - x = self.expand_dims(x, 2) # ExpandDims - x = self.deconv3(x) # Conv3dTranspose - x = x.squeeze(2) # squeeze + x = self.deconv2(x) # Conv2dTranspose x = self.norm2(x) x = self.relu2(x) x = self.conv3(x) -- Gitee