diff --git a/PyTorch/dev/cv/image_classification/PASSRnet_ID0986_for_PyTorch/models.py b/PyTorch/dev/cv/image_classification/PASSRnet_ID0986_for_PyTorch/models.py index 77821070468ddd4b4114171dbaf53752a6a97738..e3c5191a5f5c2d1f8deae3ffe537f51eab270aa0 100644 --- a/PyTorch/dev/cv/image_classification/PASSRnet_ID0986_for_PyTorch/models.py +++ b/PyTorch/dev/cv/image_classification/PASSRnet_ID0986_for_PyTorch/models.py @@ -181,6 +181,7 @@ class PAM(nn.Module): ### fusion buffer = self.b3(x_right).permute(0,2,3,1).contiguous().view(-1, w, c) # (B*H) * W * C buffer = torch.bmm(M_right_to_left, buffer).contiguous().view(b, h, w, c).permute(0,3,1,2) # B * C * H * W + V_left_to_right = V_left_to_right.to(x_left.dtype) out = self.fusion(torch.cat((buffer, x_left, V_left_to_right), 1)) ## output