diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md
index e468296b772d860e00a18e91ae6ad65302d1c54e..2cd61357548625d37842e54554045555a2423fb1 100644
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md
+++ b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md
@@ -3,7 +3,7 @@
- [Overview](#ZH-CN_TOPIC_0000001172161501)
- - [Input and Output Data](#ZH-CN_TOPIC_0000001126281702)
+ - [Input and Output Data](#section540883920406)
- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
@@ -27,12 +27,11 @@ UNet/UNet++ are semantic segmentation networks widely used in medical image processing.
- Reference implementation:
- ```shell
- UNet(https://github.com/milesial/Pytorch-UNet)
- branch=master
- UNet++(https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets)
- branch=master
- ```
+```
+url=https://github.com/milesial/Pytorch-UNet
+commit_id=6aa14cbbc445672d97190fec06d5568a0a004740
+model_name=UNet
+```
## Input and Output Data
@@ -40,7 +39,7 @@ UNet/UNet++ are semantic segmentation networks widely used in medical image processing.
| Input | Data Type | Size | Data Layout |
| -------- | -------- | ------------------------- | ------------ |
- | images | RGB_FP32 | batchsize x 3 x 512 x 512 | NCHW |
+ | images | RGB_FP32 | batchsize x 3 x 572 x 572 | NCHW |
# Inference Environment Setup
@@ -51,12 +50,12 @@ UNet/UNet++ are semantic segmentation networks widely used in medical image processing.
| Component | Version |
|-----------------------|-----------------|
-| CANN | 7.0.T3 | - |
+| CANN | 7.0.0 |
| Python | 3.9 |
| PyTorch | 2.0.1 |
| torchvision | 0.15.2 |
-| Ascend-cann-torch-aie | 7.0.T3
-| Ascend-cann-aie | 7.0.T3
+| Ascend-cann-torch-aie | >= 7.0.T3 |
+| Ascend-cann-aie | >= 7.0.T3 |
| Chip Type | Ascend310P3 |
# Quick Start
@@ -155,11 +154,11 @@ bash uninstall.sh
The dataset is organized as follows
```
test
- |-- images
+ |-- val_images
| |-- fecea3036c59_01.jpg
| |-- fecea3036c59_02.jpg
| |-- ...
- `-- masks
+ `-- val_masks
|-- fecea3036c59_01_mask.gif
|-- fecea3036c59_02_mask.gif
`-- ...
@@ -167,14 +166,23 @@ bash uninstall.sh
## Model Inference
-1. Obtain the weight file.
-```
-(U-Net) https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0
-(U-Net++) random weights are used
-```
-2. Run the Python script
+
+1. Obtain the source code
+ ```
+ git clone https://github.com/milesial/Pytorch-UNet.git
+ cd Pytorch-UNet
+ git reset --hard 6aa14cb
+ # copy the unet package into the directory that contains sample.py
+ cp -r ./unet path/to/UnetUnet++/
+ ```
+
+2. Obtain the weight file.
+ ```
+ (U-Net) https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/model/1_PyTorch_PTH/Unet/PTH/UNet.pth
```
- python3 sample.py --image_path path_to_image --mask_path path_to_mask --model_name "unet" --pth path_to_weight --batch_size 1 --device 0 --loop 100 --warm_counter 10
+3. Run the Python script
+ ```
+ python3 sample.py --data_path path_to_val_data --pth path_to_weight --batch_size 1 --device 0 --loop 100 --warm_counter 10
```
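+
+ For example, with placeholder paths, a batch-4 run of the same script looks like:
+ ```
+ python3 sample.py --data_path ./test --pth ./UNet.pth --batch_size 4 --device 0 --loop 100 --warm_counter 10
+ ```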
@@ -183,13 +191,6 @@ bash uninstall.sh
UNet model accuracy and performance are shown in the table below
-| Chip Model | Batch Size | Dataset | Accuracy (Dice score) | Performance (FPS) |
| --------- | ---------------- | ---------- | ---------- | --------------- |
-| Ascend310P3 | 1 | Carvana Image Masking Challenge | 97.4% | 50.968 |
-
-
-UNet++ model accuracy and performance are shown in the table below
-
-| Chip Model | Batch Size | Dataset | Accuracy (cosine similarity with torch) | Performance (FPS) |
+| Chip Model | Batch Size | Dataset | Accuracy (IoU) | Performance (FPS) |
| --------- | ---------------- | ---------- | ---------- | --------------- |
-| Ascend310P3 | 1 | -------- | 100.0% | 25.008 |
+| Ascend310P3 | 1 | Carvana Image Masking Challenge | 98.63% | 68.68 |
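+
+The reported IoU is computed per image on binarized masks (sigmoid outputs thresholded at 0.5) as intersection over union, then averaged across the validation set.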
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/__init__.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/__init__.py
deleted file mode 100644
index db72703966c7cc5d68052d1809ab7b93de55cba3..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/__init__.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017 xxxx
-# All rights reserved.
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# ============================================================================
-
-from .model import UNet, NestedUNet
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/model.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/model.py
deleted file mode 100644
index 4ff704ccc1435eea3669c7f611d47565b1877b66..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/model.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017 xxxx
-# All rights reserved.
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# ============================================================================
-
-""" Full assembly of the parts to form the complete network """
-from torch import nn
-
-from .unet_parts import Down, DoubleConv, Up, OutConv
-from .unetpp_parts import ConvBlockNested
-
-
-class UNet(nn.Module):
- def __init__(self, n_channels, n_classes, bilinear=False):
- super(UNet, self).__init__()
- self.n_channels = n_channels
- self.n_classes = n_classes
- self.bilinear = bilinear
-
- self.inc = (DoubleConv(n_channels, 64))
- self.down1 = (Down(64, 128))
- self.down2 = (Down(128, 256))
- self.down3 = (Down(256, 512))
- factor = 2 if bilinear else 1
- self.down4 = (Down(512, 1024 // factor))
- self.up1 = (Up(1024, 512 // factor, bilinear))
- self.up2 = (Up(512, 256 // factor, bilinear))
- self.up3 = (Up(256, 128 // factor, bilinear))
- self.up4 = (Up(128, 64, bilinear))
- self.outc = (OutConv(64, n_classes))
-
- def forward(self, x):
- x1 = self.inc(x)
- x2 = self.down1(x1)
- x3 = self.down2(x2)
- x4 = self.down3(x3)
- x5 = self.down4(x4)
- x = self.up1(x5, x4)
- x = self.up2(x, x3)
- x = self.up3(x, x2)
- x = self.up4(x, x1)
- logits = self.outc(x)
- return logits
-
- def use_checkpointing(self):
- self.inc = torch.utils.checkpoint(self.inc)
- self.down1 = torch.utils.checkpoint(self.down1)
- self.down2 = torch.utils.checkpoint(self.down2)
- self.down3 = torch.utils.checkpoint(self.down3)
- self.down4 = torch.utils.checkpoint(self.down4)
- self.up1 = torch.utils.checkpoint(self.up1)
- self.up2 = torch.utils.checkpoint(self.up2)
- self.up3 = torch.utils.checkpoint(self.up3)
- self.up4 = torch.utils.checkpoint(self.up4)
- self.outc = torch.utils.checkpoint(self.outc)
-
-
-class NestedUNet(nn.Module):
- def __init__(self, in_ch=3, out_ch=1):
- super(NestedUNet, self).__init__()
-
- n1 = 64
- filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
-
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
- self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
-
- self.conv0_0 = ConvBlockNested(in_ch, filters[0], filters[0])
- self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1])
- self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2])
- self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3])
- self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4])
-
- self.conv0_1 = ConvBlockNested(filters[0] + filters[1], filters[0], filters[0])
- self.conv1_1 = ConvBlockNested(filters[1] + filters[2], filters[1], filters[1])
- self.conv2_1 = ConvBlockNested(filters[2] + filters[3], filters[2], filters[2])
- self.conv3_1 = ConvBlockNested(filters[3] + filters[4], filters[3], filters[3])
-
- self.conv0_2 = ConvBlockNested(filters[0] * 2 + filters[1], filters[0], filters[0])
- self.conv1_2 = ConvBlockNested(filters[1] * 2 + filters[2], filters[1], filters[1])
- self.conv2_2 = ConvBlockNested(filters[2] * 2 + filters[3], filters[2], filters[2])
-
- self.conv0_3 = ConvBlockNested(filters[0] * 3 + filters[1], filters[0], filters[0])
- self.conv1_3 = ConvBlockNested(filters[1] * 3 + filters[2], filters[1], filters[1])
-
- self.conv0_4 = ConvBlockNested(filters[0] * 4 + filters[1], filters[0], filters[0])
-
- self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)
-
-
- def forward(self, x):
-
- x0_0 = self.conv0_0(x)
- x1_0 = self.conv1_0(self.pool(x0_0))
- x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
-
- x2_0 = self.conv2_0(self.pool(x1_0))
- x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
- x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
-
- x3_0 = self.conv3_0(self.pool(x2_0))
- x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
- x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
- x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
-
- x4_0 = self.conv4_0(self.pool(x3_0))
- x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
- x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
- x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
- x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
-
- output = self.final(x0_4)
- return output
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unet_parts.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unet_parts.py
deleted file mode 100644
index 9e8ca37b6ec628993c1a5b2e7d8a193f0ddcbc36..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unet_parts.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017 xxxx
-# All rights reserved.
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# ============================================================================
-
-""" Parts of the U-Net model """
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class DoubleConv(nn.Module):
- """(convolution => [BN] => ReLU) * 2"""
-
- def __init__(self, in_channels, out_channels, mid_channels=None):
- super().__init__()
- if not mid_channels:
- mid_channels = out_channels
- self.double_conv = nn.Sequential(
- nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
- nn.BatchNorm2d(mid_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True)
- )
-
- def forward(self, x):
- return self.double_conv(x)
-
-
-class Down(nn.Module):
- """Downscaling with maxpool then double conv"""
-
- def __init__(self, in_channels, out_channels):
- super().__init__()
- self.maxpool_conv = nn.Sequential(
- nn.MaxPool2d(2),
- DoubleConv(in_channels, out_channels)
- )
-
- def forward(self, x):
- return self.maxpool_conv(x)
-
-
-class Up(nn.Module):
- """Upscaling then double conv"""
-
- def __init__(self, in_channels, out_channels, bilinear=True):
- super().__init__()
-
- # if bilinear, use the normal convolutions to reduce the number of channels
- if bilinear:
- self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
- self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
- else:
- self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
- self.conv = DoubleConv(in_channels, out_channels)
-
- def forward(self, x1, x2):
- x1 = self.up(x1)
- # input is CHW
- diff_y = x2.size()[2] - x1.size()[2]
- diff_x = x2.size()[3] - x1.size()[3]
-
- x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
- diff_y // 2, diff_y - diff_y // 2])
- x = torch.cat([x2, x1], dim=1)
- return self.conv(x)
-
-
-class OutConv(nn.Module):
- def __init__(self, in_channels, out_channels):
- super(OutConv, self).__init__()
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
-
- def forward(self, x):
- return self.conv(x)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unetpp_parts.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unetpp_parts.py
deleted file mode 100644
index 81e2d3e547c8433b18725f34c49f3542d3cf48e5..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unetpp_parts.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017 xxxx
-# All rights reserved.
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# ============================================================================
-
-""" Parts of the U-Netpp model """
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class ConvBlockNested(nn.Module):
-
- def __init__(self, in_ch, mid_ch, out_ch):
- super(ConvBlockNested, self).__init__()
- self.activation = nn.ReLU(inplace=True)
- self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
- self.bn1 = nn.BatchNorm2d(mid_ch)
- self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
- self.bn2 = nn.BatchNorm2d(out_ch)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.activation(x)
-
- x = self.conv2(x)
- x = self.bn2(x)
- output = self.activation(x)
-
- return output
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py
index 5e01240bc58782a1d6d6ebf75e200851c36b019b..9a7cb5e06d4a27f932226a4efd8b086b01aee0f8 100644
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py
+++ b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py
@@ -31,6 +31,7 @@
# ============================================================================
import time
+import os
import argparse
from PIL import Image
import numpy as np
@@ -39,41 +40,43 @@ import torch.nn.functional as F
import torch_aie
from torch_aie import _enums
-from model import UNet, NestedUNet
-
-
-def processing(img_path, msk_path, bs, default_size=(512, 512)):
- img = Image.open(img_path)
- img.resize(default_size, resample=Image.BICUBIC)
- img = np.asarray(img)
- img = img.transpose(2, 0, 1)
- if (img > 1).any():
- img = img / 255
- img = img.astype(np.float32)
- img_tensor = torch.from_numpy(img)
- img_tensor = img_tensor.expand(bs, *img_tensor)
-
- msk = Image.open(msk_path)
- msk.resize(default_size, resample=Image.BICUBIC)
- msk = np.asarray(msk)
- mask_tensor = torch.as_tensor(msk.copy()).long().contiguous()
- mask_tensor = img_tensor.expand(bs, *mask_tensor)
-
- return img_tensor, mask_tensor
-
-
-def init_model(model_path, model_type="unet"):
- if model_type == "unet" :
- model_init = UNet(3, 2)
- model_init.load_state_dict(torch.load(model_path, map_location="cpu"))
- model_init.eval()
- return model_init
- else :
- model_init = NestedUNet(3, 2)
- model_init.load_state_dict(torch.load(model_path, map_location="cpu"))
- model_init.eval()
- return model_init
-
+from unet import UNet
+
+
+def processing(data_path, bs=1, img_size=572):
+ images_path = os.path.join(data_path, "val_images")
+ masks_path = os.path.join(data_path, "val_masks")
+ images, masks = [], []
+ default_size = (img_size, img_size)
+ for name in sorted(os.listdir(images_path)):
+ img_path = os.path.join(images_path, name)
+ img = Image.open(img_path)
+ img = img.resize(default_size, resample=Image.BICUBIC)
+ img = np.asarray(img)
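+ # HWC -> CHW to match the NCHW input layout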
+ img = img.transpose(2, 0, 1)
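+ # scale 0-255 images to [0, 1]; already-normalized inputs pass through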
+ if (img > 1).any():
+ img = img / 255
+ img = img.astype(np.float32)
+ img_tensor = torch.from_numpy(img)
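+ # replicate the image along the batch dimension (broadcast, no copy)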
+ img_tensor = img_tensor.expand(bs, 3, img_size, img_size)
+ images.append(img_tensor)
+
+ img_id = name.split(".")[0]
+ msk_name = img_id + "_mask.gif"
+ msk_path = os.path.join(masks_path, msk_name)
+ msk = Image.open(msk_path)
+ msk = msk.resize(default_size, resample=Image.BICUBIC)
+ msk = np.asarray(msk)
+ msk_tensor = torch.as_tensor(msk.copy()).long().contiguous()
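+ # add batch and channel dimensions via broadcasting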
+ msk_tensor = msk_tensor.expand(bs, 1, img_size, img_size)
+ masks.append(msk_tensor)
+ return images, masks
+
+def init_model(model_path, nc=1):
+ model_init = UNet(n_channels=3, n_classes=nc, bilinear=False)
+ model_init.load_state_dict(torch.load(model_path, map_location="cpu"))
+ model_init.eval()
+ return model_init
def compile_model(model_compiled, data, data_info):
trace_model = torch.jit.trace(model_compiled, data)
@@ -82,68 +85,50 @@ def compile_model(model_compiled, data, data_info):
precision_policy=_enums.PrecisionPolicy.FP16,
allow_tensor_replace_int=True,
soc_version="Ascend310P3")
+ print("compile done")
return pt_model
+def get_iou(pred, gt):
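+ # binary IoU: |pred AND gt| / |pred OR gt|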
+ inter = torch.logical_and(pred, gt)
+ union = torch.logical_or(pred, gt)
+ return torch.sum(inter) / torch.sum(union)
-def dice_coeff(pred, target, reduce_batch_first=False, eps=1e-6):
- # Average of Dice coefficient for all batches, or for a single mask
- assert pred.size() == target.size()
- assert pred.dim() == 3 or not reduce_batch_first
-
- sum_dim = (-1, -2) if pred.dim() == 2 or not reduce_batch_first else (-1, -2, -3)
-
- inter = 2 * (pred * target).sum(dim=sum_dim)
- sets_sum = pred.sum(dim=sum_dim) + target.sum(dim=sum_dim)
- sets_sum = torch.where(sets_sum == 0, inter, sets_sum)
-
- dice = (inter + eps) / (sets_sum + eps)
- return dice.mean()
-
-
-def compute_score(mask_pred, mask_true, classes):
- mask_true = F.one_hot(mask_true, 2).permute(0, 3, 1, 2).float()
- mask_pred = F.one_hot(mask_pred.argmax(dim=1), classes).permute(0, 3, 1, 2).float()
- # compute the Dice score, ignoring background
- dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
-
- print(f"The score is {dice_coeff:0.3f}")
+def compute_score(mask_pred, mask_true):
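+ # sigmoid then 0.5 threshold turns the single-channel logits into a binary mask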
+ mask_pred = torch.sigmoid(mask_pred)
+ mask_pred = (mask_pred > 0.5).type_as(mask_true)
+ return get_iou(mask_pred, mask_true)
def compute_fps(model_eval, data, loop_counter, warm_counter):
+ data = data.contiguous().to("npu:0")
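+ # run on a dedicated NPU stream; synchronize so kernels finish before timestamps are taken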
+ stream = torch_aie.npu.Stream("npu:0")
+
loops = loop_counter
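+ # warm-up iterations are excluded from the timed loop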
while warm_counter:
- _ = model_eval(data)
+ with torch_aie.npu.stream(stream):
+ _ = model_eval(data)
+ stream.synchronize()
warm_counter -= 1
t0 = time.time()
- while loops:
- _ = model_eval(data)
- loops -= 1
+ while loop_counter:
+ with torch_aie.npu.stream(stream):
+ _ = model_eval(data)
+ stream.synchronize()
+ loop_counter -= 1
time_cost = time.time() - t0
- print(f"fps: {loop_counter} * {data.shape[0]} / {time_cost : .3f} samples/s")
+ fps = loops * data.shape[0] / time_cost
+ print(f"fps: {fps} samples/s")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
- '--image_path',
- type=str,
- help="path of demo image"
- )
-
- parser.add_argument(
- '--mask_path',
- type=str,
- help="path of demo mask"
- )
-
- parser.add_argument(
- '--model_name',
+ '--data_path',
type=str,
- default="unet",
- help="only support unet/unetpp"
+ help="path to dataset"
)
parser.add_argument(
@@ -159,10 +144,17 @@ def parse_args():
help="batch size, default is 1"
)
+ parser.add_argument(
+ '--image_size',
+ type=int,
+ default=572,
+ help="image size, default is 572"
+ )
+
parser.add_argument(
'--num_class',
type=int,
- default=2,
- help="num of classes, default is 2"
+ default=1,
+ help="number of classes, default is 1"
)
@@ -191,22 +183,28 @@ def parse_args():
if __name__ == "__main__":
+ # init
opts = parse_args()
print(opts)
-
- input_info = [torch_aie.Input(shape=(1, 3, 512, 512))]
- net = init_model(opts.pth, opts.model_name)
- image, mask = processing(opts.image_path, opts.mask_path, opts.batch_size)
- compiled_model = compile_model(net, image, input_info)
- jit_result = net(image)
- aie_result = compiled_model(image)
- jit_dice_score = compute_score(jit_result, mask, opts.num_class)
- aie_dice_score = compute_score(aie_result, mask, opts.num_class)
- print(f"jit infer score: {jit_dice_score}, aie infer score: {aie_dice_score}")
-
- cosine_similarity = torch.cosine_similarity(jit_result, aie_result, 0, 1e-6)
- print(f"cosine similarity between jit result and aie result is: {cosine_similarity}")
-
- compute_fps(compiled_model, image, opts.loop, opts.warm_counter)
+ torch_aie.set_device(0)
+
+ # preprocessing
+ images, masks = processing(opts.data_path, opts.batch_size, opts.image_size)
+
+ # compile
+ input_info = [torch_aie.Input(shape=(opts.batch_size, 3, opts.image_size, opts.image_size))]
+ net = init_model(opts.pth, opts.num_class)
+ compiled_model = compile_model(net, images[0], input_info)
+
+ # infer
+ aie_scores = []
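+ # inference runs on the NPU; the IoU metric is computed on the CPU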
+ for img, msk in zip(images, masks):
+ img = img.contiguous().to("npu:0")
+ aie_result = compiled_model(img).to("cpu")
+ aie_scores.append(compute_score(aie_result, msk))
+ print(f"aie iou score is: {sum(aie_scores)/len(aie_scores)}")
+
+ # performance test
+ compute_fps(compiled_model, images[0], opts.loop, opts.warm_counter)