diff --git a/research/cv/dlinknet/src/dinknet.py b/research/cv/dlinknet/src/dinknet.py index 0fd2dcfad7c331566f2cada17c20896662c1cb0b..9c87b8c13470c85bc07dba8ca2cc536f705aa298 100644 --- a/research/cv/dlinknet/src/dinknet.py +++ b/research/cv/dlinknet/src/dinknet.py @@ -82,16 +82,12 @@ class DecoderBlock(nn.Cell): weight_init=Tensor(conv_weight_init((in_channels // 4, in_channels, 1, 1)))) self.norm1 = nn.BatchNorm2d(in_channels // 4) self.relu1 = relu - self.expand_dims = P.ExpandDims() - self.deconv3 = nn.Conv3dTranspose(in_channels // 4, in_channels // 4, - kernel_size=(1, 3, 3), - stride=(1, 2, 2), - padding=(0, 0, 1, 1, 1, 1), - output_padding=(0, 1, 1), - pad_mode='pad', - has_bias=True, - weight_init=Tensor(conv_weight_init(( - in_channels // 4, in_channels // 4, 1, 3, 3)))) + self.deconv2 = nn.Conv2dTranspose(in_channels // 4, in_channels // 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + pad_mode='pad') self.norm2 = nn.BatchNorm2d(in_channels // 4) self.relu2 = relu self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1, has_bias=True, @@ -103,9 +99,7 @@ class DecoderBlock(nn.Cell): x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) - x = self.expand_dims(x, 2) # ExpandDims - x = self.deconv3(x) # Conv3dTranspose - x = x.squeeze(2) # squeeze + x = self.deconv2(x) # Conv2dTranspose x = self.norm2(x) x = self.relu2(x) x = self.conv3(x)