diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md
index e468296b772d860e00a18e91ae6ad65302d1c54e..2cd61357548625d37842e54554045555a2423fb1 100644
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md
+++ b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/README.md
@@ -3,7 +3,7 @@

 - [Overview](#ZH-CN_TOPIC_0000001172161501)

-  - [Input and output data](#ZH-CN_TOPIC_0000001126281702)
+  - [Input and output data](#section540883920406)

 - [Inference environment](#ZH-CN_TOPIC_0000001126281702)

@@ -27,12 +27,11 @@ UNet/UNet++ are semantic segmentation networks widely used in medical image processing.

 - Reference implementation:

-  ```shell
-  UNet (https://github.com/milesial/Pytorch-UNet)
-  branch=master
-  UNet++ (https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets)
-  branch=master
-  ```
+  ```
+  url=https://github.com/milesial/Pytorch-UNet
+  commit_id=6aa14cbbc445672d97190fec06d5568a0a004740
+  model_name=UNet
+  ```

 ## Input and output data

@@ -40,7 +39,7 @@

   | Input | Data type | Shape | Layout |
   | -------- | -------- | ------------------------- | ------------ |
-  | images | RGB_FP32 | batchsize x 3 x 512 x 512 | NCHW |
+  | images | RGB_FP32 | batchsize x 3 x 572 x 572 | NCHW |

 # Inference environment

@@ -51,12 +50,12 @@

-| Component | Version |
-|-----------------------|-----------------|
-| CANN | 7.0.T3 | - |
-| Python | 3.9 |
-| PyTorch | 2.0.1 |
-| torchVison | 0.15.2 |-
-| Ascend-cann-torch-aie | 7.0.T3
-| Ascend-cann-aie | 7.0.T3
-| Chip type | Ascend310P3 | - |
+| Component             | Version     |
+|-----------------------|-------------|
+| CANN                  | 7.0.0       |
+| Python                | 3.9         |
+| PyTorch               | 2.0.1       |
+| torchvision           | 0.15.2      |
+| Ascend-cann-torch-aie | >= 7.0.T3   |
+| Ascend-cann-aie       | >= 7.0.T3   |
+| Chip type             | Ascend310P3 |

 # Quick start

@@ -155,11 +154,11 @@ bash uninstall.sh
 The dataset is organized as follows:
 ```
 test
-    |-- images
+    |-- val_images
     |   |-- fecea3036c59_01.jpg
     |   |-- fecea3036c59_02.jpg
     |   |-- ...
-    `-- masks
+    `-- val_masks
         |-- fecea3036c59_01_mask.gif
         |-- fecea3036c59_02_mask.gif
         `-- ...
 ```

 ## Model inference

@@ -167,14 +166,23 @@ bash uninstall.sh
-1. Get the weight files.
-```
-(U-Net) https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0
-(U-Net++) use random weights
-```
-2. Run the Python script
+1. Get the source code and copy the `unet` package into the directory that contains `sample.py`:
+   ```
+   git clone https://github.com/milesial/Pytorch-UNet.git
+   cd Pytorch-UNet
+   git reset --hard 6aa14cb
+   cp -r ./unet <directory containing sample.py>
+   ```
+
+2. Get the weight file.
+   ```
+   (U-Net) https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/model/1_PyTorch_PTH/Unet/PTH/UNet.pth
+   ```
+
+3. Run the inference script.
+   ```
+   python3 sample.py --data_path path_to_val_data --pth path_to_weight --batch_size 1 --device 0 --loop 100 --warm_counter 10
+   ```
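+
+   For reference, the core compile-and-infer flow that `sample.py` implements is sketched below. This is a condensed illustration only, not a separate entry point; the weight path and the dummy input are placeholders, and the `torch_aie` calls mirror those used in `sample.py`:
+
+   ```python
+   import torch
+   import torch_aie
+   from torch_aie import _enums
+   from unet import UNet
+
+   torch_aie.set_device(0)
+
+   # Load the FP32 checkpoint and trace the model at the 572x572 input size.
+   model = UNet(n_channels=3, n_classes=1, bilinear=False)
+   model.load_state_dict(torch.load("UNet.pth", map_location="cpu"))
+   model.eval()
+   traced = torch.jit.trace(model, torch.randn(1, 3, 572, 572))
+
+   # Compile for Ascend310P3 with FP16 precision, then run one inference on the NPU.
+   compiled = torch_aie.compile(
+       traced,
+       inputs=[torch_aie.Input(shape=(1, 3, 572, 572))],
+       precision_policy=_enums.PrecisionPolicy.FP16,
+       allow_tensor_replace_int=True,
+       soc_version="Ascend310P3")
+   result = compiled(torch.randn(1, 3, 572, 572).to("npu:0")).to("cpu")
+   ```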
@@ -183,13 +191,6 @@ bash uninstall.sh

 UNet accuracy and performance are listed below.

-| Chip model | Batch Size | Dataset | Accuracy (dice score) | Performance (fps) |
-| --------- | ---------------- | ---------- | ---------- | --------------- |
-| Ascend310P3 | 1 | Carvana Image Masking Challenge | 97.4% | 50.968 |
-
-
-UNet++ accuracy and performance are listed below.
-
-| Chip model | Batch Size | Dataset | Accuracy (cosine similarity with torch) | Performance (fps) |
+| Chip model | Batch Size | Dataset | Accuracy (IoU) | Performance (fps) |
 | --------- | ---------------- | ---------- | ---------- | --------------- |
-| Ascend310P3 | 1 | -------- | 100.0% | 25.008 |
+| Ascend310P3 | 1 | Carvana Image Masking Challenge | 98.63% | 68.68 |
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/__init__.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/__init__.py
deleted file mode 100644
index db72703966c7cc5d68052d1809ab7b93de55cba3..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/__init__.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017 xxxx
-# All rights reserved.
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-#   list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-#   contributors may be used to endorse or promote products derived from
-#   this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# ============================================================================
-
-from .model import UNet, NestedUNet
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/model.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/model.py
deleted file mode 100644
index 4ff704ccc1435eea3669c7f611d47565b1877b66..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/model.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017 xxxx
-# All rights reserved.
-# Copyright 2021 Huawei Technologies Co., Ltd -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# ============================================================================ - -""" Full assembly of the parts to form the complete network """ -from torch import nn - -from .unet_parts import Down, DoubleConv, Up, OutConv -from .unetpp_parts import ConvBlockNested - - -class UNet(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=False): - super(UNet, self).__init__() - self.n_channels = n_channels - self.n_classes = n_classes - self.bilinear = bilinear - - self.inc = (DoubleConv(n_channels, 64)) - self.down1 = (Down(64, 128)) - self.down2 = (Down(128, 256)) - self.down3 = (Down(256, 512)) - factor = 2 if bilinear else 1 - self.down4 = (Down(512, 1024 // factor)) - self.up1 = (Up(1024, 512 // factor, bilinear)) - self.up2 = (Up(512, 256 // factor, bilinear)) - self.up3 = (Up(256, 128 // factor, bilinear)) - self.up4 = (Up(128, 64, bilinear)) - self.outc = (OutConv(64, n_classes)) - - def forward(self, x): - x1 = self.inc(x) - x2 = self.down1(x1) - x3 = self.down2(x2) - x4 = self.down3(x3) - x5 = self.down4(x4) - x = self.up1(x5, x4) - x = self.up2(x, x3) - x = self.up3(x, x2) - x = self.up4(x, x1) - logits = self.outc(x) - return logits - - def use_checkpointing(self): - self.inc = torch.utils.checkpoint(self.inc) - self.down1 = torch.utils.checkpoint(self.down1) - self.down2 = torch.utils.checkpoint(self.down2) - self.down3 = torch.utils.checkpoint(self.down3) - self.down4 = torch.utils.checkpoint(self.down4) - self.up1 = torch.utils.checkpoint(self.up1) - self.up2 = torch.utils.checkpoint(self.up2) - self.up3 = torch.utils.checkpoint(self.up3) - self.up4 = torch.utils.checkpoint(self.up4) - self.outc = torch.utils.checkpoint(self.outc) - - -class NestedUNet(nn.Module): - def __init__(self, in_ch=3, out_ch=1): - super(NestedUNet, self).__init__() - - n1 = 64 - filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] - - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - - self.conv0_0 = 
ConvBlockNested(in_ch, filters[0], filters[0]) - self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1]) - self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2]) - self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3]) - self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4]) - - self.conv0_1 = ConvBlockNested(filters[0] + filters[1], filters[0], filters[0]) - self.conv1_1 = ConvBlockNested(filters[1] + filters[2], filters[1], filters[1]) - self.conv2_1 = ConvBlockNested(filters[2] + filters[3], filters[2], filters[2]) - self.conv3_1 = ConvBlockNested(filters[3] + filters[4], filters[3], filters[3]) - - self.conv0_2 = ConvBlockNested(filters[0] * 2 + filters[1], filters[0], filters[0]) - self.conv1_2 = ConvBlockNested(filters[1] * 2 + filters[2], filters[1], filters[1]) - self.conv2_2 = ConvBlockNested(filters[2] * 2 + filters[3], filters[2], filters[2]) - - self.conv0_3 = ConvBlockNested(filters[0] * 3 + filters[1], filters[0], filters[0]) - self.conv1_3 = ConvBlockNested(filters[1] * 3 + filters[2], filters[1], filters[1]) - - self.conv0_4 = ConvBlockNested(filters[0] * 4 + filters[1], filters[0], filters[0]) - - self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1) - - - def forward(self, x): - - x0_0 = self.conv0_0(x) - x1_0 = self.conv1_0(self.pool(x0_0)) - x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) - - x2_0 = self.conv2_0(self.pool(x1_0)) - x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) - x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) - - x3_0 = self.conv3_0(self.pool(x2_0)) - x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) - x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) - x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) - - x4_0 = self.conv4_0(self.pool(x3_0)) - x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) - x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) - x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1)) - x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) - - output = self.final(x0_4) - return output \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unet_parts.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unet_parts.py deleted file mode 100644 index 9e8ca37b6ec628993c1a5b2e7d8a193f0ddcbc36..0000000000000000000000000000000000000000 --- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unet_parts.py +++ /dev/null @@ -1,106 +0,0 @@ -# BSD 3-Clause License -# -# Copyright (c) 2017 xxxx -# All rights reserved. -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# ============================================================================ - -""" Parts of the U-Net model """ - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class DoubleConv(nn.Module): - """(convolution => [BN] => ReLU) * 2""" - - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(mid_channels), - nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) - ) - - def forward(self, x): - return self.double_conv(x) - - -class Down(nn.Module): - """Downscaling with maxpool then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool2d(2), - DoubleConv(in_channels, out_channels) - ) - - def forward(self, x): - return self.maxpool_conv(x) - - -class Up(nn.Module): - """Upscaling then double conv""" - - def __init__(self, in_channels, out_channels, bilinear=True): - super().__init__() - - # if bilinear, use the normal convolutions to reduce the number of channels - if bilinear: - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) - else: - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2): - x1 = self.up(x1) - # input is CHW - diff_y = x2.size()[2] - x1.size()[2] - diff_x = x2.size()[3] - x1.size()[3] - - x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2, - diff_y // 2, diff_y - diff_y // 2]) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - - -class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): - super(OutConv, self).__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - return self.conv(x) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unetpp_parts.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unetpp_parts.py deleted file mode 100644 index 81e2d3e547c8433b18725f34c49f3542d3cf48e5..0000000000000000000000000000000000000000 --- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/model/unetpp_parts.py +++ /dev/null @@ -1,59 +0,0 @@ -# BSD 3-Clause License -# -# Copyright (c) 2017 xxxx -# All rights reserved. 
-# Copyright 2021 Huawei Technologies Co., Ltd -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# ============================================================================ - -""" Parts of the U-Netpp model """ - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ConvBlockNested(nn.Module): - - def __init__(self, in_ch, mid_ch, out_ch): - super(ConvBlockNested, self).__init__() - self.activation = nn.ReLU(inplace=True) - self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True) - self.bn1 = nn.BatchNorm2d(mid_ch) - self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True) - self.bn2 = nn.BatchNorm2d(out_ch) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.activation(x) - - x = self.conv2(x) - x = self.bn2(x) - output = self.activation(x) - - return output diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py index 5e01240bc58782a1d6d6ebf75e200851c36b019b..9a7cb5e06d4a27f932226a4efd8b086b01aee0f8 100644 --- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py +++ b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/UnetUnet++/sample.py @@ -31,6 +31,7 @@ # ============================================================================ import time +import os import argparse from PIL import Image import numpy as np @@ -39,41 +40,43 @@ import torch.nn.functional as F import torch_aie from torch_aie import _enums -from model import UNet, NestedUNet - - -def processing(img_path, msk_path, bs, default_size=(512, 512)): - img = Image.open(img_path) - img.resize(default_size, resample=Image.BICUBIC) - img = np.asarray(img) - img = img.transpose(2, 0, 1) - if (img > 1).any(): - img = img / 255 - img = img.astype(np.float32) - img_tensor = torch.from_numpy(img) - img_tensor = img_tensor.expand(bs, *img_tensor) - - msk = Image.open(msk_path) - msk.resize(default_size, resample=Image.BICUBIC) - msk = np.asarray(msk) - mask_tensor = 
torch.as_tensor(msk.copy()).long().contiguous()
-    mask_tensor = img_tensor.expand(bs, *mask_tensor)
-
-    return img_tensor, mask_tensor
-
-
-def init_model(model_path, model_type="unet"):
-    if model_type == "unet" :
-        model_init = UNet(3, 2)
-        model_init.load_state_dict(torch.load(model_path, map_location="cpu"))
-        model_init.eval()
-        return model_init
-    else :
-        model_init = NestedUNet(3, 2)
-        model_init.load_state_dict(torch.load(model_path, map_location="cpu"))
-        model_init.eval()
-        return model_init
-
+from unet import UNet
+
+
+def processing(data_path, bs=1, img_size=572):
+    images_path = os.path.join(data_path, "val_images")
+    masks_path = os.path.join(data_path, "val_masks")
+    images, masks = [], []
+    default_size = (img_size, img_size)
+    for name in sorted(os.listdir(images_path)):
+        img_path = os.path.join(images_path, name)
+        img = Image.open(img_path)
+        img = img.resize(default_size, resample=Image.BICUBIC)
+        img = np.asarray(img)
+        img = img.transpose(2, 0, 1)
+        # scale 8-bit pixel values to [0, 1]
+        if (img > 1).any():
+            img = img / 255
+        img = img.astype(np.float32)
+        img_tensor = torch.from_numpy(img)
+        img_tensor = img_tensor.expand(bs, 3, img_size, img_size)
+        images.append(img_tensor)
+
+        img_id = name.split(".")[0]
+        msk_name = img_id + "_mask.gif"
+        msk_path = os.path.join(masks_path, msk_name)
+        msk = Image.open(msk_path)
+        msk = msk.resize(default_size, resample=Image.BICUBIC)
+        msk = np.asarray(msk)
+        msk_tensor = torch.as_tensor(msk.copy()).long().contiguous()
+        msk_tensor = msk_tensor.expand(bs, 1, img_size, img_size)
+        masks.append(msk_tensor)
+    return images, masks
+
+
+def init_model(model_path, nc=1):
+    model_init = UNet(n_channels=3, n_classes=nc, bilinear=False)
+    model_init.load_state_dict(torch.load(model_path, map_location="cpu"))
+    model_init.eval()
+    return model_init
 
 def compile_model(model_compiled, data, data_info):
     trace_model = torch.jit.trace(model_compiled, data)
@@ -82,68 +85,50 @@ def compile_model(model_compiled, data, data_info):
                                  precision_policy=_enums.PrecisionPolicy.FP16,
                                  allow_tensor_replace_int=True,
                                  soc_version="Ascend310P3")
+    print("compile done")
     return pt_model
 
+def get_iou(pred, gt):
+    inter = torch.logical_and(pred, gt)
+    union = torch.logical_or(pred, gt)
+    return torch.sum(inter) / torch.sum(union)
 
-def dice_coeff(pred, target, reduce_batch_first=False, eps=1e-6):
-    # Average of Dice coefficient for all batches, or for a single mask
-    assert pred.size() == target.size()
-    assert pred.dim() == 3 or not reduce_batch_first
-
-    sum_dim = (-1, -2) if pred.dim() == 2 or not reduce_batch_first else (-1, -2, -3)
-
-    inter = 2 * (pred * target).sum(dim=sum_dim)
-    sets_sum = pred.sum(dim=sum_dim) + target.sum(dim=sum_dim)
-    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)
-
-    dice = (inter + eps) / (sets_sum + eps)
-    return dice.mean()
-
-
-def compute_score(mask_pred, mask_true, classes):
-    mask_true = F.one_hot(mask_true, 2).permute(0, 3, 1, 2).float()
-    mask_pred = F.one_hot(mask_pred.argmax(dim=1), classes).permute(0, 3, 1, 2).float()
-    # compute the Dice score, ignoring background
-    dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
-
-    print(f"The score is {dice_coeff:0.3f}")
+def compute_score(mask_pred, mask_true):
+    # binarize the sigmoid output at 0.5, then score against the ground-truth mask
+    mask_pred = torch.sigmoid(mask_pred)
+    mask_pred = (mask_pred > 0.5).type_as(mask_true)
+    return get_iou(mask_pred, mask_true)
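+
+# Quick sanity check for get_iou (hypothetical values, not part of the pipeline):
+#   pred = torch.tensor([[1, 0], [1, 1]]); gt = torch.tensor([[1, 1], [0, 1]])
+#   intersection = 2 pixels, union = 4 pixels, so get_iou(pred, gt) == 0.5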
 
 
 def compute_fps(model_eval, data, loop_counter, warm_counter):
+    # move the input to the NPU once; all warm-up and timed runs reuse it
+    data = data.contiguous().to("npu:0")
+    stream = torch_aie.npu.Stream("npu:0")
+    loops = loop_counter
     while warm_counter:
-        _ = model_eval(data)
+        with torch_aie.npu.stream(stream):
+            _ = model_eval(data)
+        stream.synchronize()
         warm_counter -= 1
     t0 = time.time()
-    while loops:
-        _ = model_eval(data)
-        loops -= 1
+    while loop_counter:
+        with torch_aie.npu.stream(stream):
+            _ = model_eval(data)
+        stream.synchronize()
+        loop_counter -= 1
     time_cost = time.time() - t0
-    print(f"fps: {loop_counter} * {data.shape[0]} / {time_cost : .3f} samples/s")
+    fps = loops * data.shape[0] / time_cost
+    print(f"fps: {fps:.3f} samples/s")
 
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        '--image_path',
-        type=str,
-        help="path of demo image"
-    )
-
-    parser.add_argument(
-        '--mask_path',
-        type=str,
-        help="path of demo mask"
-    )
-
-    parser.add_argument(
-        '--model_name',
+        '--data_path',
         type=str,
-        default="unet",
-        help="only support unet/unetpp"
+        help="path to dataset"
     )
 
     parser.add_argument(
@@ -159,10 +144,17 @@
         help="batch size, default is 1"
     )
 
+    parser.add_argument(
+        '--image_size',
+        type=int,
+        default=572,
+        help="image size, default is 572"
+    )
+
     parser.add_argument(
         '--num_class',
         type=int,
-        default=2,
-        help="num of classes, default is 2"
+        default=1,
+        help="number of classes, default is 1"
     )
 
@@ -191,22 +183,28 @@
 
 
 if __name__ == "__main__":
+    # init
     opts = parse_args()
     print(opts)
-
-    input_info = [torch_aie.Input(shape=(1, 3, 512, 512))]
-    net = init_model(opts.pth, opts.model_name)
-    image, mask = processing(opts.image_path, opts.mask_path, opts.batch_size)
-    compiled_model = compile_model(net, image, input_info)
-    jit_result = net(image)
-    aie_result = compiled_model(image)
-    jit_dice_score = compute_score(jit_result, mask, opts.num_class)
-    aie_dice_score = compute_score(aie_result, mask, opts.num_class)
-    print(f"jit infer score: {jit_dice_score}, aie infer score: {aie_dice_score}")
-
-    cosine_similarity = torch.cosine_similarity(jit_result, aie_result, 0, 1e-6)
-    print(f"cosine similarity between jit result and aie result is: {cosine_similarity}")
-
-    compute_fps(compiled_model, image, opts.loop, opts.warm_counter)
+    torch_aie.set_device(0)
+
+    # preprocessing
+    images, masks = processing(opts.data_path, opts.batch_size, opts.image_size)
+
+    # compile
+    input_info = [torch_aie.Input(shape=(opts.batch_size, 3, opts.image_size, opts.image_size))]
+    net = init_model(opts.pth, opts.num_class)
+    compiled_model = compile_model(net, images[0], input_info)
+
+    # infer
+    aie_scores = []
+    for img, msk in zip(images, masks):
+        img = img.contiguous().to("npu:0")
+        aie_result = compiled_model(img).to("cpu")
+        aie_scores.append(compute_score(aie_result, msk))
+    print(f"aie iou score is: {sum(aie_scores)/len(aie_scores)}")
+
+    # performance test
+    compute_fps(compiled_model, images[0], opts.loop, opts.warm_counter)
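+
+    # Example invocation (paths are placeholders; flags mirror the README):
+    #   python3 sample.py --data_path ./test --pth ./UNet.pth \
+    #       --batch_size 1 --device 0 --loop 100 --warm_counter 10
+    # Expected console output (numbers will vary with hardware and data):
+    #   compile done
+    #   aie iou score is: 0.98...
+    #   fps: 68.6... samples/s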