diff --git a/community/cv/ADCAM/README_CN.md b/community/cv/ADCAM/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..2007560121656f081f841631eb663a56b0e907dc --- /dev/null +++ b/community/cv/ADCAM/README_CN.md @@ -0,0 +1,222 @@ +# 目录 + +- [目录](#目录) +- [ADCAM描述](#ADCAM描述) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [快速入门](#快速入门) +- [脚本说明](#脚本说明) + - [脚本及样例代码](#脚本及样例代码) + - [脚本参数](#脚本参数) + - [训练过程](#训练过程) + - [训练](#训练) + - [分布式训练](#分布式训练) + - [推理过程](#推理过程) + - [评估](#评估) + +# [ADCAM描述](#目录) + +我们搜寻了相关作物种类识别和病虫害检测的论文和模型,对相应模型进行分析比较,提出了ADCAM模型,该模型融合了YOLO的CSPDarknet模块与注意力机制,并利用 WIOU损失函数提升准确率和收敛速度,最后对模型进行迁移到 mindspore中,并在IP102数据集实现了比较优秀的目标检测性能。 + +# [模型架构](#目录) + +ADCAM主要组成:CSP结构和Focus结构作为骨干、空间金字塔池化(SPP)作为附加模块、PANet路径聚合作为颈部、并在其中加入CA(Coordinate Attention)模块。CSP是一个新的骨干网络,可以增强CNN的学习能力。在CSP上添加空间金字塔池化模块来增加更多可接受空间,并分离出最重要的上下文特征。CA注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。 + +# [数据集](#目录) + +使用的数据集:[IP102]() + +需要将IP102数据集处理成COCO格式,如下图所示: + +COCO_ROOT #根目录 + +├── annotations # 存放json格式的标注 + +│ ├── instances_train2017.json + +│ └── instances_val2017.json + +└── train2017 # 存放图片文件 + +│ ├── IP0000000001.jpg + +│ ├── IP0000000002.jpg + +│ └── IP0000000003.jpg + +└── val2017 + +│ ├── IP0000000004.jpg + +│ └── IP0000000005.jpg + +IP102数据集是一 个广泛用于农业害虫检测和分类的数据集。这个数据集特别为农作物害虫的图像识 别任务而设计,包含了多种不同的害虫类型,是进行农业害虫研究和开发相关算法的重要资源。IP102包含约75,222张图像,涵盖了102种不同的害虫类别。 + +# [快速入门](#目录) + +通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估: + +```bash +#通过Python在Ascend进行训练(单卡) +python train.py \ + --data_dir=xxx/dataset \ + --is_distributed=0 \ + --lr=0.01 \ + --max_epoch=320 \ +``` + +```bash +# 在Ascend中运行shell脚本进行分布式训练示例(8卡) +bash run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE] +``` + +```bash +# 通过Python命令在Ascend上运行评估 +python eval.py \ + --data_dir=xxx/dataset \ + --pretrained="***/*.ckpt" \ +``` + +# [脚本说明](#目录) + +## [脚本及样例代码](#目录) + +```text +├── model_zoo + ├── ADCAM + ├── README_CN.md // 启动模型训练教程 + ├── scripts + │ ├──docker_start.sh // 运行shell脚本启动docker + │ ├──run_distribute_train.sh // 在Ascend中进行分布式训练(8卡) + ├──model_utils + │ ├──config.py // 参数配置 + │ ├──device_adapter.py // 获取设备信息 + │ ├──local_adapter.py // 获取设备信息 + │ ├──moxing_adapter.py // 装饰器 + ├── src + │ ├──backbone.py // 骨干网络 + │ ├──distributed_sampler.py // 数据集迭代 + │ ├──initializer.py // 参数初始化 + │ ├──logger.py // 日志函数 + │ ├──loss.py // 损失函数 + │ ├──lr_scheduler.py // 生成学习率 + │ ├──transforms.py // 预处理数据 + │ ├──util.py // Util函数 + │ ├──yolo.py // YOLO网络 + │ ├──yolo_dataset.py // 创建YOLO数据集 + ├── default_config.yaml // 参数配置 + ├── train.py // 训练脚本 + ├── eval.py // 评估脚本 +``` + +## [脚本参数](#目录) + +```text +train.py中主要的参数有: + +可选参数: + + --data_dir 训练数据集目录 + --per_batch_size 训练的批处理大小。默认值:32(单卡),16(Ascend 8卡) + --lr_scheduler 学习率调度器。可选值:exponential或cosine_annealing + 默认值:cosine_annealing + --lr 学习率。默认值:0.01(单卡),0.02(Ascend 8卡) + --lr_epochs 学习率变化轮次,用英文逗号(,)分割。默认值为'220,250'。 + --max_epoch 模型训练最大轮次。默认值为300(8卡)。 + --ckpt_path CKPT文件保存位置。默认值为outputs/。 + --is_distributed 是否进行分布式训练,1表示是,0表示否。默认值为0。 +``` + +## [训练过程](#目录) + +### 训练 + +在Ascend上开始单机训练 + +```shell +#使用python命令进行训练(单卡) +python train.py \ + --data_dir=xxx/dataset \ + --is_distributed=0 \ + --lr=0.01 \ + --max_epoch=320 \ + --per_batch_size=32 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & +``` + +在GPU上进行单卡训练时,应微调参数。 + +上述python命令将在后台运行,您可以通过`log.txt`文件查看结果。 + +训练结束后,您可在默认**outputs**文件夹下找到checkpoint文件。得到如下损失值: + +![image-loss](loss.png) + +### 分布式训练 + +运行shell脚本进行分布式训练示例(8卡) + +```bash +# 在Ascend环境中运行shell脚本进行分布式训练示例(8卡) +bash run_distribute_train.sh 
[DATASET_PATH] [RANK_TABLE_FILE] +``` + +上述shell脚本将在后台运行分布式训练。您可以通过文件train_parallel[X]/log.txt(Ascend)查看结果 得到如下损失值: + +```text +# 分布式训练结果(8卡,动态shape) +... +2024-10-23 13:01:34,116:INFO:epoch[0], iter[200], loss:415.453676, fps:580.07 imgs/sec, lr:0.0002742903889156878 +2024-10-23 13:01:57,588:INFO:epoch[0], iter[300], loss:273.358383, fps:545.96 imgs/sec, lr:0.00041075327317230403 +2024-10-23 13:02:26,247:INFO:epoch[0], iter[400], loss:244.621502, fps:446.64 imgs/sec, lr:0.0005472161574289203 +2024-10-23 13:02:55,532:INFO:epoch[0], iter[500], loss:234.524876, fps:437.10 imgs/sec, lr:0.000683679012581706 +2024-10-23 13:03:25,046:INFO:epoch[0], iter[600], loss:235.185213, fps:434.08 imgs/sec, lr:0.0008201419259421527 +2024-10-23 13:03:54,585:INFO:epoch[0], iter[700], loss:228.878598, fps:433.48 imgs/sec, lr:0.0009566047810949385 +2024-10-23 13:04:23,932:INFO:epoch[0], iter[800], loss:219.259134, fps:436.29 imgs/sec, lr:0.0010930676944553852 +2024-10-23 13:04:52,707:INFO:epoch[0], iter[900], loss:225.741833, fps:444.84 imgs/sec, lr:0.001229530549608171 +2024-10-23 13:05:21,872:INFO:epoch[1], iter[1000], loss:218.811336, fps:438.91 imgs/sec, lr:0.0013659934047609568 +2024-10-23 13:05:51,216:INFO:epoch[1], iter[1100], loss:219.491889, fps:436.50 imgs/sec, lr:0.0015024563763290644 +2024-10-23 13:06:20,546:INFO:epoch[1], iter[1200], loss:219.895906, fps:436.57 imgs/sec, lr:0.0016389192314818501 +2024-10-23 13:06:49,521:INFO:epoch[1], iter[1300], loss:218.516680, fps:441.79 imgs/sec, lr:0.001775382086634636 +2024-10-23 13:07:18,303:INFO:epoch[1], iter[1400], loss:209.922935, fps:444.79 imgs/sec, lr:0.0019118449417874217 +2024-10-23 13:07:47,702:INFO:epoch[1], iter[1500], loss:210.997816, fps:435.60 imgs/sec, lr:0.0020483077969402075 +2024-10-23 13:08:16,482:INFO:epoch[1], iter[1600], loss:210.678421, fps:444.88 imgs/sec, lr:0.002184770768508315 +2024-10-23 13:08:45,568:INFO:epoch[1], iter[1700], loss:203.285874, fps:440.07 imgs/sec, lr:0.0023212337400764227 +2024-10-23 13:09:13,947:INFO:epoch[1], iter[1800], loss:203.014775, fps:451.11 imgs/sec, lr:0.0024576964788138866 +2024-10-23 13:09:42,954:INFO:epoch[2], iter[1900], loss:194.683969, fps:441.28 imgs/sec, lr:0.0025941594503819942 +... 
+```
+
+## [推理过程](#目录)
+
+### 评估
+
+在运行以下命令之前,请检查用于评估的检查点路径。以下脚本中使用的文件**best.ckpt**是最后保存的检查点文件。
+
+```shell
+# 使用python命令进行评估
+python eval.py \
+    --data_dir=xxx/dataset \
+    --pretrained=xxx/best.ckpt \
+    --eval_shape=640 > log.txt 2>&1 &
+```
+
+上述python命令将在后台运行。您可以通过"log.txt"文件查看结果。测试数据集的mAP如下:
+
+```text
+# log.txt
+=============ip102 eval result=========
+Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.341
+Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.620
+Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.348
+Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.340
+Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.342
+Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.346
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.414
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.435
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.435
+Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.420
+Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.377
+Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.440
+2024-10-23 16:34:04,653:INFO:testing cost time 0.35h
+```
diff --git a/community/cv/ADCAM/ciou.py b/community/cv/ADCAM/ciou.py
new file mode 100644
index 0000000000000000000000000000000000000000..a40e6f1df37a905c9dd68f0594386d24c568857e
--- /dev/null
+++ b/community/cv/ADCAM/ciou.py
@@ -0,0 +1,117 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""CIoU loss""" +# import numpy as np +import math +import mindspore +from mindspore import nn, ops + + +class CIou(nn.Cell): + """Calculating CIoU loss""" + + def __init__(self): + super(CIou, self).__init__() + self.min = ops.Minimum() + self.max = ops.Maximum() + self.sub = ops.Sub() + self.add = ops.Add() + self.mul = ops.Mul() + self.div = ops.RealDiv() + self.square = ops.Square() + self.sqrt = ops.Sqrt() + self.atan2 = ops.Atan2() + self.eps = 1e-7 + self.pi = mindspore.Tensor(math.pi, mindspore.float32) + self.cast = ops.Cast() + + def construct(self, boxes1, boxes2): + """ + Args: + boxes1: Tensor of shape (..., 4), format [xmin, ymin, xmax, ymax] + boxes2: Tensor of shape (..., 4), format [xmin, ymin, xmax, ymax] + Returns: + cious: Tensor of CIoU loss values + """ + boxes1 = self.cast(boxes1, mindspore.float32) + boxes2 = self.cast(boxes2, mindspore.float32) + + # Widths and heights + w1 = self.sub(boxes1[..., 2], boxes1[..., 0]) + h1 = self.sub(boxes1[..., 3], boxes1[..., 1]) + w2 = self.sub(boxes2[..., 2], boxes2[..., 0]) + h2 = self.sub(boxes2[..., 3], boxes2[..., 1]) + + w1 = self.max(w1, 0.0) + h1 = self.max(h1, 0.0) + w2 = self.max(w2, 0.0) + h2 = self.max(h2, 0.0) + + # Areas + area1 = self.mul(w1, h1) + area2 = self.mul(w2, h2) + + # Intersection + inter_left_up = self.max(boxes1[..., :2], boxes2[..., :2]) + inter_right_down = self.min(boxes1[..., 2:], boxes2[..., 2:]) + inter_wh = self.max(self.sub(inter_right_down, inter_left_up), 0.0) + inter_area = self.mul(inter_wh[..., 0], inter_wh[..., 1]) + + # Union + union_area = self.add(area1, area2) - inter_area + self.eps + + # IoU + ious = self.div(inter_area, union_area) + ious = ops.clip_by_value(ious, 0.0, 1.0) + + # Enclosing box + enclose_left_up = self.min(boxes1[..., :2], boxes2[..., :2]) + enclose_right_down = self.max(boxes1[..., 2:], boxes2[..., 2:]) + enclose_wh = self.max(self.sub(enclose_right_down, enclose_left_up), 0.0) + enclose_c2 = ( + self.square(enclose_wh[..., 0]) + self.square(enclose_wh[..., 1]) + self.eps + ) + + # Center distances + boxes1_center = self.mul(self.add(boxes1[..., :2], boxes1[..., 2:]), 0.5) + boxes2_center = self.mul(self.add(boxes2[..., :2], boxes2[..., 2:]), 0.5) + center_dist = self.square( + self.sub(boxes1_center[..., 0], boxes2_center[..., 0]) + ) + self.square(self.sub(boxes1_center[..., 1], boxes2_center[..., 1])) + + # Penalty term v + arctan1 = self.atan2(h1, w1) + arctan2 = self.atan2(h2, w2) + v = (4 / (self.pi**2)) * self.square(arctan1 - arctan2) + + # Alpha term + S = 1 - ious + alpha = v / (S + v + self.eps) + + # CIoU + ciou_term = self.div(center_dist, enclose_c2) + cious = ious - (ciou_term + alpha * v) + cious = ops.clip_by_value(cious, -1.0, 1.0) + + return cious + + +boxes_pred = mindspore.Tensor([[50, 50, 150, 150]], mindspore.float32) +boxes_target = mindspore.Tensor([[60, 60, 140, 140]], mindspore.float32) + +ciou_loss = CIou() + +loss = ciou_loss(boxes_pred, boxes_target) +print(loss) diff --git a/community/cv/ADCAM/cpp_infer/CMakeLists.txt b/community/cv/ADCAM/cpp_infer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9db631e98656875af43113d278bfd3dff9ca246a --- /dev/null +++ b/community/cv/ADCAM/cpp_infer/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.14.1) +project(Ascend310Infer) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE 
-Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) +include_directories(${MINDSPORE_PATH}/include) +include_directories(${PROJECT_SRC_ROOT}) +set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) +include_directories(${TOP_DIR}/utils/cpp_infer/example/) # common_inc in top dir +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/) # common_inc in local dir +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../) # common_inc in local dir + +if(EXISTS ${MINDSPORE_PATH}/lib/libmindspore-lite.so) + message(--------------- Compile-with-MindSpore-Lite ----------------) + set(MS_LIB ${MINDSPORE_PATH}/lib/libmindspore-lite.so) + set(MD_LIB ${MINDSPORE_PATH}/lib/libminddata-lite.so) + add_compile_definitions(ENABLE_LITE) +else() + message(--------------- Compile-with-MindSpore ----------------) + set(MS_LIB ${MINDSPORE_PATH}/lib/libmindspore.so) + file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) +endif() + +add_executable(main src/main.cc) +target_link_libraries(main ${MS_LIB} ${MD_LIB}) diff --git a/community/cv/ADCAM/cpp_infer/build.sh b/community/cv/ADCAM/cpp_infer/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..fbe673eb993eb882207540450f1fbd12a90a22cd --- /dev/null +++ b/community/cv/ADCAM/cpp_infer/build.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ ! -d out ]; then + mkdir out +fi +cd out || exit +if [ $MS_LITE_HOME ];then + MINDSPORE_PATH=$MS_LITE_HOME/runtime +else + MINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" + if [[ ! $MINDSPORE_PATH ]];then + MINDSPORE_PATH="`pip show mindspore | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" + fi +fi +cmake .. -DMINDSPORE_PATH=$MINDSPORE_PATH +make diff --git a/community/cv/ADCAM/cpp_infer/src/main.cc b/community/cv/ADCAM/cpp_infer/src/main.cc new file mode 100644 index 0000000000000000000000000000000000000000..18f095e76449489edd51eb8f8b39a30edee3f87a --- /dev/null +++ b/community/cv/ADCAM/cpp_infer/src/main.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common_inc/infer.h" + +DEFINE_string(mindir_path, "", "mindir path"); +DEFINE_string(dataset_path, ".", "dataset path"); +DEFINE_int32(device_id, 0, "device id"); +DEFINE_string(device_type, "CPU", "device type"); +DEFINE_int32(image_height, 640, "image height"); +DEFINE_int32(image_width, 640, "image width"); + +int main(int argc, char **argv) { + if (!ParseCommandLineFlags(argc, argv)) { + std::cout << "Failed to parse args" << std::endl; + return 1; + } + if (RealPath(FLAGS_mindir_path).empty()) { + std::cout << "Invalid mindir" << std::endl; + return 1; + } + + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + ascend310->SetPrecisionMode("preferred_fp32"); + ascend310->SetOpSelectImplMode("high_precision"); + ascend310->SetBufferOptimizeMode("off_optimize"); + mindspore::Model model; + if (!LoadModel(FLAGS_mindir_path, FLAGS_device_type, FLAGS_device_id, ascend310, &model)) { + std::cout << "Failed to load model " << FLAGS_mindir_path << ", device id: " << FLAGS_device_id + << ", device type: " << FLAGS_device_type; + return 1; + } + Status ret; + + auto all_files = GetAllFiles(FLAGS_dataset_path); + std::map costTime_map; + size_t size = all_files.size(); + std::shared_ptr decode(new Decode()); + auto resize = Resize({FLAGS_image_height, FLAGS_image_width}); + Execute composeDecode({decode}); + + for (size_t i = 0; i < size; ++i) { + struct timeval start = {0}; + struct timeval end = {0}; + double startTimeMs; + double endTimeMs; + std::vector inputs; + std::vector outputs; + auto imgDecode = MSTensor(); + auto img = MSTensor(); + composeDecode(ReadFileToTensor(all_files[i]), &imgDecode); + std::vector shape = imgDecode.Shape(); + + if ((static_cast(shape[0]) < static_cast(FLAGS_image_height)) && + (static_cast(shape[1]) < static_cast(FLAGS_image_width))) { + resize = Resize({FLAGS_image_height, FLAGS_image_width}, InterpolationMode::kCubic); + } else if ((static_cast(shape[0]) > static_cast(FLAGS_image_height)) && + (static_cast(shape[1]) > static_cast(FLAGS_image_width))) { + resize = Resize({FLAGS_image_height, FLAGS_image_width}, InterpolationMode::kNearestNeighbour); + } else { + resize = Resize({FLAGS_image_height, FLAGS_image_width}, InterpolationMode::kLinear); + } + if ((sizeof(shape) / sizeof(shape[0])) <= 2) { + std::cout << "image channels is not 3." << std::endl; + return 1; + } + Execute transform(resize); + transform(imgDecode, &img); + + std::vector model_inputs = model.GetInputs(); + inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(), + img.Data().get(), img.DataSize()); + gettimeofday(&start, nullptr); + ret = model.Predict(inputs, &outputs); + gettimeofday(&end, nullptr); + if (ret != kSuccess) { + std::cout << "Predict " << all_files[i] << " failed." 
<< std::endl; + return 1; + } + startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; + costTime_map.insert(std::pair(startTimeMs, endTimeMs)); + WriteResult(all_files[i], outputs); + } + double average = 0.0; + int inferCount = 0; + + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + double diff = iter->second - iter->first; + average += diff; + inferCount++; + } + average = average / inferCount; + std::stringstream timeCost; + timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl; + std::cout << "NN inference cost average time: " << average << "ms of infer_count " << inferCount << std::endl; + std::string fileName = "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream fileStream(fileName.c_str(), std::ios::trunc); + fileStream << timeCost.str(); + fileStream.close(); + costTime_map.clear(); + return 0; +} diff --git a/community/cv/ADCAM/default_config.yaml b/community/cv/ADCAM/default_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..125ec991859bd7fdabf80638146213a3670ea267 --- /dev/null +++ b/community/cv/ADCAM/default_config.yaml @@ -0,0 +1,223 @@ + +enable_modelarts: False + +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path" +device_target: "Ascend" +need_modelarts_dataset_unzip: True +modelarts_dataset_unzip_name: "coco" + + +data_dir: "/data/coco" +per_batch_size: 16 +yolov5_version: "yolov5s" +pretrained_backbone: "" +resume_yolov5: "" +pretrained_checkpoint: "" +output_dir: "./output" +train_img_dir: "train2017" +train_ann_file: "annotations/instances_train2017.json" + +lr_scheduler: "cosine_annealing" +lr: 0.01 +lr_epochs: "220,250" +lr_gamma: 0.1 +eta_min: 0.0 +T_max: 320 +max_epoch: 300 +warmup_epochs: 20 +weight_decay: 0.0005 +momentum: 0.9 +loss_scale: 1024 +label_smooth: 0 +label_smooth_factor: 0.1 +log_interval: 100 +ckpt_path: "outputs/" +is_distributed: 0 +bind_cpu: True +device_num: 8 +rank: 0 +group_size: 1 +need_profiler: 0 +resize_rate: 10 +filter_weight: False +save_ckpt_interval: 1 +save_ckpt_max_num: 10 + + +pretrained: "" +log_path: "outputs/" +ann_val_file: "" +eval_nms_thresh: 0.6 +ignore_threshold: 0.7 +test_ignore_threshold: 0.001 +multi_label: True +multi_label_thresh: 0.1 + +save_prefix: "../eval_parallel" +run_eval: True +eval_epoch_interval: 10 +eval_start_epoch: 100 +eval_parallel: True +val_img_dir: "val2017" +val_ann_file: "annotations/instances_val2017.json" + + +batch_size: 1 +testing_shape: [640, 640] +ckpt_file: "" +file_name: "yolov5" +file_format: "MINDIR" +dataset_path: "" +ann_file: "" + + + +hue: 0.015 +saturation: 1.5 +value: 0.4 +jitter: 0.3 + +num_classes: 102 +# num_classes: 80 +max_box: 150 +checkpoint_filter_list: ['feature_map.back_block1.conv.weight', 'feature_map.back_block1.conv.bias', + 'feature_map.back_block2.conv.weight', 'feature_map.back_block2.conv.bias', + 'feature_map.back_block3.conv.weight', 'feature_map.back_block3.conv.bias'] + +# h->w +anchor_scales: [[12, 16], + [19, 36], + [40, 28], + [36, 75], + [76, 55], + [72, 146], + [142, 110], + [192, 243], + [459, 401]] + +out_channel: 255 # 3 * (num_classes + 5) + +input_shape: [[3, 32, 64, 128, 256, 512, 1], + [3, 48, 96, 192, 384, 768, 2], + [3, 64, 128, 256, 512, 1024, 3], + [3, 80, 160, 320, 640, 1280, 4]] + + +test_img_shape: [640, 640] + +# labels: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 
'train', 'truck', 'boat', +# 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', +# 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', +# 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', +# 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', +# 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', +# 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', +# 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', +# 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', +# 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' ] + +labels: [ 'rice leaf roller', 'rice leaf caterpillar', 'paddy stem maggot', 'asiatic rice borer', + 'yellow rice borer', 'rice gall midge', 'Rice Stemfly', 'brown plant hopper', + 'white backed plant hopper', 'small brown plant hopper', 'rice water weevil', + 'rice leafhopper', 'grain spreader thrips', 'rice shell pest', 'grub', 'mole cricket', + 'wireworm', 'white margined moth', 'black cutworm', 'large cutworm', 'yellow cutworm', + 'red spider', 'corn borer', 'army worm', 'aphids', 'Potosiabre vitarsis', 'peach borer', + 'english grain aphid', 'green bug', 'bird cherry-oataphid', 'wheat blossom midge', + 'penthaleus major', 'longlegged spider mite', 'wheat phloeothrips', 'wheat sawfly', + 'cerodonta denticornis', 'beet fly', 'flea beetle', 'cabbage army worm', + 'beet army worm', 'Beet spot flies', 'meadow moth', 'beet weevil', + 'sericaorient alismots chulsky', 'alfalfa weevil', 'flax budworm', + 'alfalfa plant bug', 'tarnished plant bug', 'Locustoidea', 'lytta polita', + 'legume blister beetle', 'blister beetle', 'therioaphis maculata Buckton', + 'odontothrips loti', 'Thrips', 'alfalfa seed chalcid', 'Pieris canidia', + 'Apolygus lucorum', 'Limacodidae', 'Viteus vitifoliae', 'Colomerus vitis', + 'Brevipoalpus lewisi McGregor', 'oides decempunctata', 'Polyphagotars onemus latus', + 'Pseudococcus comstocki Kuwana', 'parathrene regalis', 'Ampelophaga', + 'Lycorma delicatula', 'Xylotrechus', 'Cicadella viridis', 'Miridae', + 'Trialeurodes vaporariorum', 'Erythroneura apicalis', 'Papilio xuthus', + 'Panonchus citri McGregor', 'Phyllocoptes oleiverus ashmead', 'Icerya purchasi Maskell', + 'Unaspis yanonensis', 'Ceroplastes rubens', 'Chrysomphalus aonidum', + 'Parlatoria zizyphus Lucus', 'Nipaecoccus vastalor', 'Aleurocanthus spiniferus', + 'Tetradacus c Bactrocera minax', 'Dacus dorsalis(Hendel)', 'Bactrocera tsuneonis', + 'Prodenia litura', 'Adristyrannus', 'Phyllocnistis citrella Stainton', + 'Toxoptera citricidus', 'Toxoptera aurantii', 'Aphis citricola Vander Goot', + 'Scirtothrips dorsalis Hood', 'Dasineura sp', 'Lawana imitata Melichar', + 'Salurnis marginella Guerr', 'Deporaus marginatus Pascoe', 'Chlumetia transversa', + 'Mango flat beak leafhopper', 'Rhytidodera bowrinii white', 'Sternochetus frigidus', + 'Cicadellidae' ] + + +# coco_ids: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, +# 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, +# 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, +# 81, 82, 84, 85, 86, 87, 88, 89, 90 ] +coco_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 
27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, + 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, + 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, + 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101] + +result_files: './result_Files' + +--- + + +# data_dir: "Train dataset directory." +# per_batch_size: "Batch size for Training." +# pretrained_backbone: "The ckpt file of CspDarkNet53." +# resume_yolov5: "The ckpt file of YOLOv5, which used to fine tune." +# pretrained_checkpoint: "The ckpt file of YOLOv5CspDarkNet53." +# lr_scheduler: "Learning rate scheduler, options: exponential, cosine_annealing." +# lr: "Learning rate." +# lr_epochs: "Epoch of changing of lr changing, split with ','." +# lr_gamma: "Decrease lr by a factor of exponential lr_scheduler." +# eta_min: "Eta_min in cosine_annealing scheduler." +# T_max: "T-max in cosine_annealing scheduler." +# max_epoch: "Max epoch num to train the model." +# warmup_epochs: "Warmup epochs." +# weight_decay: "Weight decay factor." +# momentum: "Momentum." +# loss_scale: "Static loss scale." +# label_smooth: "Whether to use label smooth in CE." +# label_smooth_factor: "Smooth strength of original one-hot." +# log_interval: "Logging interval steps." +# ckpt_path: "Checkpoint save location." +# ckpt_interval: "Save checkpoint interval." +# is_save_on_master: "Save ckpt on master or all rank, 1 for master, 0 for all ranks." +# is_distributed: "Distribute train or not, 1 for yes, 0 for no." +# bind_cpu: "Whether bind cpu when distributed training." +# device_num: "Device numbers per server" +# rank: "Local rank of distributed." +# group_size: "World size of device." +# need_profiler: "Whether use profiler. 0 for no, 1 for yes." +# resize_rate: "Resize rate for multi-scale training." +# ann_file: "path to annotation" +# each_multiscale: "Apply multi-scale for each scale" +# labels: "the label of train data" +# multi_label: "use multi label to nms" +# multi_label_thresh: "multi label thresh" +# train_img_dir: "relative path of training image directory to data_dir" +# train_ann_file: "relative path of training annotation file to data_dir" + + +# pretrained: "model_path, local pretrained model to load" +# log_path: "checkpoint save location" +# save_prefix: "../eval_parallel" +# run_eval: "Whether enable validation after a training epoch" +# eval_epoch_interval: "Epoch interval to do validation" +# eval_start_epoch: "After which epoch, start to do validatation" +# eval_parallel: "Whether enable parallel evaluation to accelerate the validataion process" +# val_img_dir: "relative path of validation image directory to data_dir" +# val_ann_file: "relative path of validataion annotation file to data_dir" + + + +# device_id: "Device id for export" +# batch_size: "batch size for export" +# testing_shape: "shape for test" +# ckpt_file: "Checkpoint file path for export" +# file_name: "output file name for export" +# file_format: "file format for export" +# result_files: 'path to 310 infer result floder' diff --git a/community/cv/ADCAM/eval.py b/community/cv/ADCAM/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..28bdf04e69049bae32e2a115a88647a7dfb05a6e --- /dev/null +++ b/community/cv/ADCAM/eval.py @@ -0,0 +1,125 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" eval.py """ +import os +import time +import shutil + +import mindspore +from mindspore import ParallelMode +from mindspore.communication.management import init, get_group_size, get_rank + +from src.yolo import YOLOV5 +from src.logger import get_logger +from src.util import DetectionEngine, EvalWrapper +from src.yolo_dataset import create_yolo_dataset + +from model_utils.config import config +from model_utils.moxing_adapter import moxing_wrapper, modelarts_pre_process + + + +def eval_preprocess(): + """ modelarts env prepare""" + config.val_img_dir = os.path.join(config.data_dir, config.val_img_dir) + config.val_ann_file = os.path.join(config.data_dir, config.val_ann_file) + device_id = int(os.getenv("DEVICE_ID", "0")) + mindspore.set_context( + mode=0, device_target=config.device_target, device_id=device_id + ) + parallel_mode = ParallelMode.STAND_ALONE + config.eval_parallel = config.is_distributed and config.eval_parallel + device_num = 1 + if config.eval_parallel: + init() + config.rank = get_rank() + config.group_size = get_group_size() + device_num = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + mindspore.reset_auto_parallel_context() + mindspore.set_auto_parallel_context( + parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num + ) + config.logger = get_logger(config.output_dir, device_id) + + +def load_parameters(network, filename): + """load parameters""" + config.logger.info("yolov5 pretrained network model: %s", filename) + param_dict = mindspore.load_checkpoint(filename) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith("moments."): + continue + elif key.startswith("yolo_network."): + param_dict_new[key[13:]] = values + else: + param_dict_new[key] = values + mindspore.load_param_into_net(network, param_dict_new) + config.logger.info("load_model %s success", filename) + + +@moxing_wrapper(pre_process=modelarts_pre_process, pre_args=[config]) +def run_eval(): + """run_eval""" + eval_preprocess() + start_time = time.time() + config.logger.info("Creating Network....") + dict_version = {"yolov5s": 0, "yolov5m": 1, "yolov5l": 2, "yolov5x": 3} + network = YOLOV5(is_training=False, version=dict_version[config.yolov5_version]) + + if os.path.isfile(config.pretrained): + load_parameters(network, config.pretrained) + else: + raise FileNotFoundError(f"{config.pretrained} is not a filename.") + rank_id = int(os.getenv("RANK_ID", "0")) + if config.eval_parallel: + rank_id = get_rank() + ds = create_yolo_dataset( + config.val_img_dir, + config.val_ann_file, + is_training=False, + batch_size=config.per_batch_size, + device_num=config.group_size, + rank=rank_id, + shuffle=False, + config=config, + ) + + config.logger.info("testing shape : %s", config.test_img_shape) + config.logger.info( + "total %d images to eval", ds.get_dataset_size() * config.per_batch_size + ) + + network.set_train(False) + + detection = DetectionEngine(config, config.test_ignore_threshold) + if config.eval_parallel: + if os.path.exists(config.save_prefix): + 
shutil.rmtree(config.save_prefix, ignore_errors=True) + + config.logger.info("Start inference....") + eval_wrapper = EvalWrapper(config, network, ds, detection) + eval_wrapper.inference() + eval_result, _ = eval_wrapper.get_results() + + cost_time = time.time() - start_time + eval_log_string = "\n=============coco eval result=========\n" + eval_result + config.logger.info(eval_log_string) + config.logger.info("testing cost time %.2f h", cost_time / 3600.0) + + +if __name__ == "__main__": + run_eval() diff --git a/community/cv/ADCAM/fusion_result.json b/community/cv/ADCAM/fusion_result.json new file mode 100644 index 0000000000000000000000000000000000000000..23dcbbe9e44730f0ad51860afce5adafbaf8134c --- /dev/null +++ b/community/cv/ADCAM/fusion_result.json @@ -0,0 +1,132 @@ +[{ + "graph_fusion": { + "AABiasaddConvFusion": { + "effect_times": "0", + "match_times": "12" + }, + "AReduceMeanFusionPass": { + "effect_times": "0", + "match_times": "6" + }, + "ARefreshCubeC0FusionPass": { + "effect_times": "71", + "match_times": "71" + }, + "CastRemoveFusionPass": { + "effect_times": "0", + "match_times": "1" + }, + "ConstToAttrPass": { + "effect_times": "16", + "match_times": "18" + }, + "ConstToAttrStridedSliceFusion": { + "effect_times": "22", + "match_times": "22" + }, + "ConvConcatFusionPass": { + "effect_times": "0", + "match_times": "20" + }, + "ConvFormatRefreshFusionPass": { + "effect_times": "0", + "match_times": "71" + }, + "ConvToFullyConnectionFusionPass": { + "effect_times": "0", + "match_times": "71" + }, + "ConvWeightCompressFusionPass": { + "effect_times": "0", + "match_times": "71" + }, + "CubeTransFixpipeFusionPass": { + "effect_times": "71", + "match_times": "71" + }, + "FIXPIPEAPREQUANTFUSIONPASS": { + "effect_times": "0", + "match_times": "71" + }, + "FIXPIPEFUSIONPASS": { + "effect_times": "0", + "match_times": "71" + }, + "ForceFp16CastFusionPass": { + "effect_times": "0", + "match_times": "1" + }, + "MulAddFusionPass": { + "effect_times": "0", + "match_times": "17" + }, + "MulSquareFusionPass": { + "effect_times": "0", + "match_times": "74" + }, + "RealDiv2MulsFusionPass": { + "effect_times": "0", + "match_times": "7" + }, + "RefreshInt64ToInt32FusionPass": { + "effect_times": "1", + "match_times": "1" + }, + "RemoveCastFusionPass": { + "effect_times": "0", + "match_times": "1" + }, + "ReshapeTransposeFusionPass": { + "effect_times": "0", + "match_times": "3" + }, + "SplitConvConcatFusionPass": { + "effect_times": "0", + "match_times": "20" + }, + "StridedSliceRemovePass": { + "effect_times": "0", + "match_times": "22" + }, + "SubFusionPass": { + "effect_times": "0", + "match_times": "4" + }, + "TileConstToAttrFusion": { + "effect_times": "6", + "match_times": "6" + }, + "TransdataCastFusionPass": { + "effect_times": "0", + "match_times": "223" + }, + "TransdataFz2FzgFusionPass": { + "effect_times": "0", + "match_times": "211" + }, + "TransdataFzg2FzFusionPass": { + "effect_times": "0", + "match_times": "211" + }, + "TransposedUpdateFusionPass": { + "effect_times": "10", + "match_times": "10" + }, + "ZConcatDFusionPass": { + "effect_times": "0", + "match_times": "20" + }, + "ZConcatFusionPass": { + "effect_times": "20", + "match_times": "20" + } + }, + "session_and_graph_id": "0_1", + "ub_fusion": { + "AutomaticUbFusion": { + "effect_times": "78", + "match_times": "78", + "repository_hit_times": "0" + } + } +}] \ No newline at end of file diff --git a/community/cv/ADCAM/infer/Dockerfile b/community/cv/ADCAM/infer/Dockerfile new file mode 100644 index 
0000000000000000000000000000000000000000..360861ede17fb0ab697fbcac190acde7c1e29fef --- /dev/null +++ b/community/cv/ADCAM/infer/Dockerfile @@ -0,0 +1,5 @@ +ARG FROM_IMAGE_NAME +FROM ${FROM_IMAGE_NAME} + +COPY requirements.txt . +RUN pip3.7 install -r requirements.txt diff --git a/community/cv/ADCAM/infer/convert/atc_model_convert.sh b/community/cv/ADCAM/infer/convert/atc_model_convert.sh new file mode 100644 index 0000000000000000000000000000000000000000..00f7a46e33727b7d3f915f915cdd5cdac50a18e2 --- /dev/null +++ b/community/cv/ADCAM/infer/convert/atc_model_convert.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +model_path=../data/models/yolov5.air +output_model_name=../data/models/yolov5 + +atc --framework=1 \ + --model="${model_path}" \ + --input_shape="actual_input_1:1,12,320,320" \ + --output="${output_model_name}" \ + --enable_small_channel=1 \ + --log=error \ + --soc_version=Ascend310 \ + --op_select_implmode=high_precision \ + --output_type=FP32 +exit 0 diff --git a/community/cv/ADCAM/infer/data/models/coco2017.names b/community/cv/ADCAM/infer/data/models/coco2017.names new file mode 100644 index 0000000000000000000000000000000000000000..1db41f581a4b1b54086cfc44f81a192b262ad63c --- /dev/null +++ b/community/cv/ADCAM/infer/data/models/coco2017.names @@ -0,0 +1,81 @@ +# This file is originally from https://github.com/pjreddie/darknet/blob/master/data/coco.names +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/community/cv/ADCAM/infer/data/models/yolov5.cfg b/community/cv/ADCAM/infer/data/models/yolov5.cfg new file mode 100644 index 0000000000000000000000000000000000000000..45773713c37648c9de4a25c610350f6eaee7f8a9 --- /dev/null +++ b/community/cv/ADCAM/infer/data/models/yolov5.cfg @@ -0,0 +1,10 @@ +CLASS_NUM=80 +BIASES_NUM=18 +BIASES=10,13,16,30,33,23,30,61,62,45,59,119,116,90,156,198,373,326 +SCORE_THRESH=0.001 +OBJECTNESS_THRESH=0.001 +IOU_THRESH=0.6 +YOLO_TYPE=3 +ANCHOR_DIM=3 +MODEL_TYPE=0 + diff --git a/community/cv/ADCAM/infer/docker_start_infer.sh b/community/cv/ADCAM/infer/docker_start_infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..7ce943468ec6c33b5196317305d2ddab4bbd8831 --- /dev/null +++ b/community/cv/ADCAM/infer/docker_start_infer.sh 
@@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +docker_image=$1 +data_dir=$2 + +function show_help() { + echo "Usage: docker_start.sh docker_image data_dir" +} + +function param_check() { + if [ -z "${docker_image}" ]; then + echo "please input docker_image" + show_help + exit 1 + fi + + if [ -z "${data_dir}" ]; then + echo "please input data_dir" + show_help + exit 1 + fi +} + +param_check + +docker run -it \ + --device=/dev/davinci0 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v ${data_dir}:${data_dir} \ + ${docker_image} \ + /bin/bash diff --git a/community/cv/ADCAM/infer/mxbase/CMakeLists.txt b/community/cv/ADCAM/infer/mxbase/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..86df663487d3ab2fa1719bc0782b7f34b98f1f41 --- /dev/null +++ b/community/cv/ADCAM/infer/mxbase/CMakeLists.txt @@ -0,0 +1,52 @@ +cmake_minimum_required(VERSION 3.5.2) +SET(CMAKE_BUILD_TYPE "Debug") +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x") +SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g -ggdb") +SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall") +project(yolov5) +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) + +set(TARGET_LIBRARY yolov5postprocessor) +set(TARGET_MAIN yolov5) + +set(ACL_LIB_PATH $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) + +include_directories($ENV{MX_SDK_HOME}/include) +include_directories($ENV{MX_SDK_HOME}/opensource/include) +include_directories($ENV{MX_SDK_HOME}/opensource/include/opencv4) +include_directories($ENV{MX_SDK_HOME}/opensource/include/gstreamer-1.0) +include_directories($ENV{MX_SDK_HOME}/opensource/include/glib-2.0) +include_directories($ENV{MX_SDK_HOME}/opensource/lib/glib-2.0/include) +include_directories($ENV{MX_SDK_HOME}/ascend-toolkit/latest/include) + +link_directories($ENV{MX_SDK_HOME}/lib) +link_directories($ENV{MX_SDK_HOME}/opensource/lib/) + +add_compile_options(-std=c++11 -fPIC -fstack-protector-all -pie -Wno-deprecated-declarations) +add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}") +add_compile_options("-Dgoogle=mindxsdk_private") + +add_definitions(-DENABLE_DVPP_INTERFACE) + +include_directories(${ACL_LIB_PATH}/include) +link_directories(${ACL_LIB_PATH}/lib64/) + +add_compile_options(-std=c++11 -fPIC -fstack-protector-all -pie -Wno-deprecated-declarations) +add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}") +add_compile_options("-Dgoogle=mindxsdk_private") + +add_library(${TARGET_LIBRARY} SHARED src/PostProcess/Yolov5MindSporePost.cpp) + +target_link_libraries(${TARGET_LIBRARY} glib-2.0 gstreamer-1.0 gobject-2.0 gstbase-1.0 gmodule-2.0) +target_link_libraries(${TARGET_LIBRARY} plugintoolkit mxpidatatype mxbase) +target_link_libraries(${TARGET_LIBRARY} -Wl,-z,relro,-z,now,-z,noexecstack) + 
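+# Note on the build layout: this CMakeLists produces two artifacts. ${TARGET_LIBRARY}
+# (libyolov5postprocessor.so, built above from Yolov5MindSporePost.cpp) packages the YOLOv5
+# detection post-processing as a shared library, and ${TARGET_MAIN} (built below from
+# src/main.cpp and src/Yolov5Detection.cpp) is the standalone inference executable that
+# links against this library together with the MxBase, OpenCV and ACL runtime libraries.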
+message("TARGET_LIBRARY:${TARGET_LIBRARY}.") + +add_executable(${TARGET_MAIN} src/main.cpp src/Yolov5Detection.cpp) +target_link_libraries(${TARGET_MAIN} glog cpprest mxbase libascendcl.so + libruntime.so libopencv_world.so.4.3 opencv_world) +target_link_libraries(${TARGET_MAIN} ${TARGET_LIBRARY} glog cpprest mxbase libascendcl.so + libruntime.so libopencv_world.so.4.3 opencv_world) diff --git a/community/cv/ADCAM/infer/mxbase/build.sh b/community/cv/ADCAM/infer/mxbase/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..df079d2c8d7a3119f2776bc085c6d70ef39ebe73 --- /dev/null +++ b/community/cv/ADCAM/infer/mxbase/build.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +path_cur=$(dirname $0) + +function check_env() +{ + # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user + if [ ! "${ASCEND_HOME}" ]; then + export ASCEND_HOME=/usr/local/Ascend/ + echo "Set ASCEND_HOME to the default value: ${ASCEND_HOME}" + else + echo "ASCEND_HOME is set to ${ASCEND_HOME} by user" + fi + + if [ ! "${ASCEND_VERSION}" ]; then + export ASCEND_VERSION=nnrt/latest + echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}" + else + echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user" + fi + + if [ ! "${ARCH_PATTERN}" ]; then + # set ARCH_PATTERN to ./ when it was not specified by user + export ARCH_PATTERN=./ + echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}" + else + echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user" + fi +} + +function build_east() +{ + cd $path_cur + rm -rf build + mkdir -p build + cd build + cmake .. + make + ret=$? + if [ ${ret} -ne 0 ]; then + echo "Failed to build east." + exit ${ret} + fi + make install +} + +check_env +build_east diff --git a/community/cv/ADCAM/infer/mxbase/src/PostProcess/Yolov5MindSporePost.cpp b/community/cv/ADCAM/infer/mxbase/src/PostProcess/Yolov5MindSporePost.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b433f1731f7c7f7863e381d30c36ab0379f4de9 --- /dev/null +++ b/community/cv/ADCAM/infer/mxbase/src/PostProcess/Yolov5MindSporePost.cpp @@ -0,0 +1,290 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "Yolov5MindSporePost.h" +#include +#include +#include +#include "MxBase/Log/Log.h" +#include "MxBase/CV/ObjectDetection/Nms/Nms.h" + +namespace { + const int SCALE = 32; + const int BIASESDIM = 2; + const int OFFSETWIDTH = 2; + const int OFFSETHEIGHT = 3; + const int OFFSETBIASES = 1; + const int OFFSETOBJECTNESS = 1; + + const int NHWC_HEIGHTINDEX = 1; + const int NHWC_WIDTHINDEX = 2; + const int NCHW_HEIGHTINDEX = 2; + const int NCHW_WIDTHINDEX = 3; + const int YOLO_INFO_DIM = 5; + + auto uint8Deleter = [] (uint8_t* p) { }; +} // namespace + +namespace localParameter { + const uint32_t VECTOR_FIRST_INDEX = 0; + const uint32_t VECTOR_SECOND_INDEX = 1; + const uint32_t VECTOR_THIRD_INDEX = 2; + const uint32_t VECTOR_FOURTH_INDEX = 3; + const uint32_t VECTOR_FIFTH_INDEX = 4; +} + +namespace MxBase { + Yolov5PostProcess& Yolov5PostProcess::operator=(const Yolov5PostProcess &other) { + if (this == &other) { + return *this; + } + ObjectPostProcessBase::operator=(other); + objectnessThresh_ = other.objectnessThresh_; // Threshold of objectness value + iouThresh_ = other.iouThresh_; + anchorDim_ = other.anchorDim_; + biasesNum_ = other.biasesNum_; + yoloType_ = other.yoloType_; + modelType_ = other.modelType_; + inputType_ = other.inputType_; + biases_ = other.biases_; + return *this; + } + + APP_ERROR Yolov5PostProcess::Init(const std::map>& postConfig) { + LogDebug << "Start to Init Yolov5PostProcess."; + APP_ERROR ret = ObjectPostProcessBase::Init(postConfig); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Fail to superInit in ObjectPostProcessBase."; + return ret; + } + + configData_.GetFileValue("BIASES_NUM", biasesNum_); + std::string str; + configData_.GetFileValue("BIASES", str); + configData_.GetFileValue("OBJECTNESS_THRESH", objectnessThresh_); + configData_.GetFileValue("IOU_THRESH", iouThresh_); + configData_.GetFileValue("YOLO_TYPE", yoloType_); + configData_.GetFileValue("MODEL_TYPE", modelType_); + configData_.GetFileValue("YOLO_VERSION", yoloVersion_); + configData_.GetFileValue("INPUT_TYPE", inputType_); + configData_.GetFileValue("ANCHOR_DIM", anchorDim_); + ret = GetBiases(&str); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Failed to get biases."; + return ret; + } + LogDebug << "End to Init Yolov5PostProcess."; + return APP_ERR_OK; + } + + APP_ERROR Yolov5PostProcess::DeInit() { + return APP_ERR_OK; + } + + bool Yolov5PostProcess::IsValidTensors(const std::vector &tensors) const { + if (tensors.size() != (size_t)yoloType_) { + LogError << "number of tensors (" << tensors.size() << ") " << "is unequal to yoloType_(" + << yoloType_ << ")"; + return false; + } + if (yoloVersion_ == YOLOV5_VERSION) { + for (size_t i = 0; i < tensors.size(); i++) { + auto shape = tensors[i].GetShape(); + if (shape.size() < localParameter::VECTOR_FIFTH_INDEX) { + LogError << "dimensions of tensor [" << i << "] is less than " << + localParameter::VECTOR_FIFTH_INDEX << "."; + return false; + } + uint32_t channelNumber = 1; + int startIndex = modelType_ ? localParameter::VECTOR_SECOND_INDEX : localParameter::VECTOR_FOURTH_INDEX; + int endIndex = modelType_ ? 
(shape.size() - localParameter::VECTOR_THIRD_INDEX) : shape.size(); + for (int j = startIndex; j < endIndex; j++) { + channelNumber *= shape[j]; + } + if (channelNumber != anchorDim_ * (classNum_ + YOLO_INFO_DIM)) { + LogError << "channelNumber(" << channelNumber << ") != anchorDim_ * (classNum_ + 5)."; + return false; + } + } + } + return true; + } + + void Yolov5PostProcess::ObjectDetectionOutput(const std::vector& tensors, + std::vector> *objectInfos, + const std::vector& resizedImageInfos) { + LogDebug << "Yolov5PostProcess start to write results."; + if (tensors.size() == 0) { + return; + } + auto shape = tensors[0].GetShape(); + if (shape.size() == 0) { + return; + } + uint32_t batchSize = shape[0]; + for (uint32_t i = 0; i < batchSize; i++) { + std::vector> featLayerData = {}; + std::vector> featLayerShapes = {}; + for (uint32_t j = 0; j < tensors.size(); j++) { + auto dataPtr = reinterpret_cast (tensors[j].GetBuffer()) + + i * tensors[j].GetByteSize() / batchSize; + std::shared_ptr tmpPointer; + tmpPointer.reset(dataPtr, uint8Deleter); + featLayerData.push_back(tmpPointer); + shape = tensors[j].GetShape(); + std::vector featLayerShape(shape.size()); + transform(shape.begin(), shape.end(), featLayerShape.begin(), [](uint32_t s) { return (size_t)s; }); + featLayerShapes.push_back(featLayerShape); + } + std::vector objectInfo; + GenerateBbox(featLayerData, &objectInfo, featLayerShapes, resizedImageInfos[i].widthResize, + resizedImageInfos[i].heightResize); + MxBase::NmsSort(objectInfo, iouThresh_); + objectInfos->push_back(objectInfo); + } + LogDebug << "Yolov5PostProcess write results success."; + } + + APP_ERROR Yolov5PostProcess::Process(const std::vector &tensors, + std::vector> &objectInfos, + const std::vector &resizedImageInfos, + const std::map> &configParamMap) { + LogDebug << "Start to Process Yolov5PostProcess."; + APP_ERROR ret = APP_ERR_OK; + auto inputs = tensors; + ret = CheckAndMoveTensors(inputs); + if (ret != APP_ERR_OK) { + LogError << "CheckAndMoveTensors failed. 
ret=" << ret; + return ret; + } + + ObjectDetectionOutput(inputs, &objectInfos, resizedImageInfos); + + for (uint32_t i = 0; i < resizedImageInfos.size(); i++) { + CoordinatesReduction(i, resizedImageInfos[i], objectInfos[i]); + } + LogObjectInfos(objectInfos); + LogDebug << "End to Process Yolov5PostProcess."; + return APP_ERR_OK; + } + + void Yolov5PostProcess::CompareProb(int *classID, float *maxProb, float classProb, int classNum) { + if (classProb > (*maxProb)) { + (*maxProb) = classProb; + (*classID) = classNum; + } + } + + void Yolov5PostProcess::SelectClassNHWC(std::shared_ptr netout, NetInfo info, + std::vector *detBoxes, int stride) { + const int offsetY = 1; + for (int j = 0; j < stride; ++j) { + for (int k = 0; k < info.anchorDim; ++k) { + int bIdx = (info.bboxDim + 1 + info.classNum) * info.anchorDim * j + + k * (info.bboxDim + 1 + info.classNum); + int oIdx = bIdx + info.bboxDim; // objectness index + float objectness = static_cast(netout.get())[oIdx]; + if (objectness < objectnessThresh_) { + continue; + } + int classID = -1; + float maxProb = scoreThresh_; + for (int c = 0; c < info.classNum; ++c) { + float clsProb = static_cast(netout.get())[bIdx + (info.bboxDim + + OFFSETOBJECTNESS + c)] * objectness; + CompareProb(&classID, &maxProb, clsProb, c); + } + if (classID < 0) continue; + MxBase::ObjectInfo det; + float x = static_cast(netout.get())[bIdx]; + float y = static_cast(netout.get())[bIdx + offsetY]; + float width = static_cast(netout.get())[bIdx + OFFSETWIDTH]; + float height = static_cast(netout.get())[bIdx + OFFSETHEIGHT]; + det.x0 = std::max(0.0f, x - width / COORDINATE_PARAM); + det.x1 = std::min(1.0f, x + width / COORDINATE_PARAM); + det.y0 = std::max(0.0f, y - height / COORDINATE_PARAM); + det.y1 = std::min(1.0f, y + height / COORDINATE_PARAM); + det.classId = classID; + det.className = configData_.GetClassName(classID); + det.confidence = maxProb; + if (det.confidence < separateScoreThresh_[classID]) continue; + detBoxes->emplace_back(det); + } + } + } + + void Yolov5PostProcess::GenerateBbox(std::vector> featLayerData, + std::vector *detBoxes, + const std::vector>& featLayerShapes, const int netWidth, + const int netHeight) { + NetInfo netInfo; + netInfo.anchorDim = anchorDim_; + netInfo.bboxDim = BOX_DIM; + netInfo.classNum = classNum_; + netInfo.netWidth = netWidth; + netInfo.netHeight = netHeight; + for (int i = 0; i < yoloType_; ++i) { + int widthIndex_ = modelType_ ? NCHW_WIDTHINDEX : NHWC_WIDTHINDEX; + int heightIndex_ = modelType_ ? 
NCHW_HEIGHTINDEX : NHWC_HEIGHTINDEX; + OutputLayer layer = {featLayerShapes[i][widthIndex_], featLayerShapes[i][heightIndex_]}; + int logOrder = log(featLayerShapes[i][widthIndex_] * SCALE / netWidth) / log(BIASESDIM); + int startIdx = (yoloType_ - 1 - logOrder) * netInfo.anchorDim * BIASESDIM; + int endIdx = startIdx + netInfo.anchorDim * BIASESDIM; + int idx = 0; + for (int j = startIdx; j < endIdx; ++j) { + layer.anchors[idx++] = biases_[j]; + } + int stride = layer.width * layer.height; + std::shared_ptr netout = featLayerData[i]; + SelectClassNHWC(netout, netInfo, detBoxes, stride); + } + } + + APP_ERROR Yolov5PostProcess::GetBiases(std::string *strBiases) { + if (biasesNum_ <= 0) { + LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "Failed to get biasesNum (" << biasesNum_ << ")."; + return APP_ERR_COMM_INVALID_PARAM; + } + biases_.clear(); + int i = 0; + int num = strBiases->find(","); + while (num >= 0 && i < biasesNum_) { + std::string tmp = strBiases->substr(0, num); + num++; + (*strBiases) = strBiases->substr(num, strBiases->size()); + biases_.push_back(stof(tmp)); + i++; + num = strBiases->find(","); + } + if (i != biasesNum_ - 1 || strBiases->size() == 0) { + LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "biasesNum (" << biasesNum_ + << ") is not equal to total number of biases (" << (*strBiases) <<")."; + return APP_ERR_COMM_INVALID_PARAM; + } + biases_.push_back(stof((*strBiases))); + return APP_ERR_OK; + } + +#ifndef ENABLE_POST_PROCESS_INSTANCE + extern "C" { + std::shared_ptr GetObjectInstance() { + LogInfo << "Begin to get Yolov5PostProcess instance."; + auto instance = std::make_shared(); + LogInfo << "End to get Yolov5PostProcess instance."; + return instance; + } + } +#endif +} // namespace MxBase diff --git a/community/cv/ADCAM/infer/mxbase/src/PostProcess/Yolov5MindSporePost.h b/community/cv/ADCAM/infer/mxbase/src/PostProcess/Yolov5MindSporePost.h new file mode 100644 index 0000000000000000000000000000000000000000..deaf93bd1a6531c43d8667dabb958076d929831e --- /dev/null +++ b/community/cv/ADCAM/infer/mxbase/src/PostProcess/Yolov5MindSporePost.h @@ -0,0 +1,104 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef YOLOV5_POST_PROCESS_H +#define YOLOV5_POST_PROCESS_H +#include +#include +#include +#include +#include +#include +#include "MxBase/ErrorCode/ErrorCode.h" +#include "MxBase/CV/Core/DataType.h" +#include "MxBase/PostProcessBases/ObjectPostProcessBase.h" + +const float DEFAULT_OBJECTNESS_THRESH = 0.3; +const float DEFAULT_IOU_THRESH = 0.6; +const int DEFAULT_ANCHOR_DIM = 3; +const int DEFAULT_BIASES_NUM = 18; +const int DEFAULT_YOLO_TYPE = 3; +const int DEFAULT_YOLO_VERSION = 5; +const int YOLOV3_VERSION = 3; +const int YOLOV4_VERSION = 4; +const int YOLOV5_VERSION = 5; +const int ANCHOR_NUM = 9; +struct OutputLayer { + size_t width; + size_t height; + float anchors[ANCHOR_NUM]; +}; + +struct NetInfo { + int anchorDim; + int classNum; + int bboxDim; + int netWidth; + int netHeight; +}; + +namespace MxBase { +class Yolov5PostProcess : public ObjectPostProcessBase { + public: + Yolov5PostProcess() = default; + + ~Yolov5PostProcess() = default; + + Yolov5PostProcess(const Yolov5PostProcess &other) = default; + + Yolov5PostProcess &operator=(const Yolov5PostProcess &other); + + APP_ERROR Init(const std::map> &postConfig) override; + + APP_ERROR DeInit() override; + + APP_ERROR Process(const std::vector &tensors, std::vector> &objectInfos, + const std::vector &resizedImageInfos = {}, + const std::map> &configParamMap = {}) override; + + protected: + bool IsValidTensors(const std::vector &tensors) const; + + void ObjectDetectionOutput(const std::vector &tensors, + std::vector> *objectInfos, + const std::vector &resizedImageInfos = {}); + + void CompareProb(int *classID, float *maxProb, float classProb, int classNum); + void SelectClassNHWC(std::shared_ptr netout, NetInfo info, std::vector *detBoxes, + int stride); + void GenerateBbox(std::vector> featLayerData, + std::vector *detBoxes, + const std::vector>& featLayerShapes, + const int netWidth, const int netHeight); + APP_ERROR GetBiases(std::string *strBiases); + + protected: + float objectnessThresh_ = DEFAULT_OBJECTNESS_THRESH; // Threshold of objectness value + float iouThresh_ = DEFAULT_IOU_THRESH; // Non-Maximum Suppression threshold + int anchorDim_ = DEFAULT_ANCHOR_DIM; + int biasesNum_ = DEFAULT_BIASES_NUM; // anchors, generate from train data, coco dataset + int yoloType_ = DEFAULT_YOLO_TYPE; + int modelType_ = 0; + int yoloVersion_ = DEFAULT_YOLO_VERSION; + int inputType_ = 0; + std::vector biases_ = {}; +}; +#ifndef ENABLE_POST_PROCESS_INSTANCE + extern "C" { + std::shared_ptr GetObjectInstance(); + } +#endif +} // namespace MxBase +#endif diff --git a/community/cv/ADCAM/infer/mxbase/src/Yolov5Detection.cpp b/community/cv/ADCAM/infer/mxbase/src/Yolov5Detection.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc274cf09722971e88b0de8dd453cac22ac4d971 --- /dev/null +++ b/community/cv/ADCAM/infer/mxbase/src/Yolov5Detection.cpp @@ -0,0 +1,350 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "Yolov5Detection.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "MxBase/DeviceManager/DeviceManager.h" + +namespace { + const uint32_t MODEL_HEIGHT = 640; + const uint32_t MODEL_WIDTH = 640; + const uint32_t MODEL_CHANNEL = 12; + const uint32_t BATCH_NUM = 1; + const std::vector MEAN = {0.485, 0.456, 0.406}; + const std::vector STD = {0.229, 0.224, 0.225}; + const float MODEL_MAX = 255.0; + const int DATA_SIZE = 1228800; + const int coco_class_nameid[80] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90}; +} // namespace + +APP_ERROR Yolov5Detection::LoadLabels(const std::string &labelPath, std::map *labelMap) { + std::ifstream infile; + // open label file + infile.open(labelPath, std::ios_base::in); + std::string s; + // check label file validity + if (infile.fail()) { + LogError << "Failed to open label file: " << labelPath << "."; + return APP_ERR_COMM_OPEN_FAIL; + } + labelMap->clear(); + // construct label map + int count = 0; + while (std::getline(infile, s)) { + if (s[0] == '#') { + continue; + } + size_t eraseIndex = s.find_last_not_of("\r\n\t"); + if (eraseIndex != std::string::npos) { + s.erase(eraseIndex + 1, s.size() - eraseIndex); + } + labelMap->insert(std::pair(count, s)); + count++; + } + infile.close(); + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::Init(const InitParam &initParam) { + deviceId_ = initParam.deviceId; + APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices(); + if (ret != APP_ERR_OK) { + LogError << "Init devices failed, ret=" << ret << "."; + return ret; + } + ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId); + if (ret != APP_ERR_OK) { + LogError << "Set context failed, ret=" << ret << "."; + return ret; + } + + model_ = std::make_shared(); + ret = model_->Init(initParam.modelPath, modelDesc_); + if (ret != APP_ERR_OK) { + LogError << "ModelInferenceProcessor init failed, ret=" << ret << "."; + return ret; + } + MxBase::ConfigData configData; + const std::string checkTensor = initParam.checkTensor ? 
"true" : "false"; + configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum)); + configData.SetJsonValue("BIASES_NUM", std::to_string(initParam.biasesNum)); + configData.SetJsonValue("BIASES", initParam.biases); + configData.SetJsonValue("OBJECTNESS_THRESH", initParam.objectnessThresh); + configData.SetJsonValue("IOU_THRESH", initParam.iouThresh); + configData.SetJsonValue("SCORE_THRESH", initParam.scoreThresh); + configData.SetJsonValue("YOLO_TYPE", std::to_string(initParam.yoloType)); + configData.SetJsonValue("MODEL_TYPE", std::to_string(initParam.modelType)); + configData.SetJsonValue("INPUT_TYPE", std::to_string(initParam.inputType)); + configData.SetJsonValue("ANCHOR_DIM", std::to_string(initParam.anchorDim)); + configData.SetJsonValue("CHECK_MODEL", checkTensor); + + auto jsonStr = configData.GetCfgJson().serialize(); + std::map> config; + config["postProcessConfigContent"] = std::make_shared(jsonStr); + config["labelPath"] = std::make_shared(initParam.labelPath); + + post_ = std::make_shared(); + ret = post_->Init(config); + if (ret != APP_ERR_OK) { + LogError << "Yolov5PostProcess init failed, ret=" << ret << "."; + return ret; + } + // load labels from file + ret = LoadLabels(initParam.labelPath, &labelMap_); + if (ret != APP_ERR_OK) { + LogError << "Failed to load labels, ret=" << ret << "."; + return ret; + } + LogInfo << "End to Init Yolov5DetectionOpencv."; + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::DeInit() { + model_->DeInit(); + MxBase::DeviceManager::GetInstance()->DestroyDevices(); + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::ReadImage(const std::string &imgPath, cv::Mat *imageMat) { + (*imageMat) = cv::imread(imgPath, cv::IMREAD_COLOR); + imageWidth_ = (*imageMat).cols; + imageHeight_ = (*imageMat).rows; + + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::Resize(cv::Mat *srcImageMat, cv::Mat *dstImageMat) { + cv::resize((*srcImageMat), (*dstImageMat), cv::Size(MODEL_WIDTH, MODEL_HEIGHT)); + + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::WhcToChw(const cv::Mat &srcImageMat, std::vector *imgData) { + int channel = srcImageMat.channels(); + std::vector bgrChannels(channel); + cv::split(srcImageMat, bgrChannels); + for (int i = channel - 1; i >= 0; i--) { + std::vector data = std::vector(bgrChannels[i].reshape(1, 1)); + std::transform(data.begin(), data.end(), data.begin(), + [&](float item) {return ((item / MODEL_MAX - MEAN[channel - i - 1]) / STD[channel - i - 1]); }); + imgData->insert(imgData->end(), data.begin(), data.end()); + } + + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::Focus(const cv::Mat &srcImageMat, float* data) { + int outIdx = 0; + int imgIdx = 0; + int height = static_cast(srcImageMat.rows); + int width = static_cast(srcImageMat.cols); + int channel = static_cast(srcImageMat.channels()); + int newHeight = height / 2; + int newWidth = width / 2; + int newChannel = MODEL_CHANNEL; + + std::vector tmp; + WhcToChw(srcImageMat, &tmp); + + for (int newC = 0; newC < newChannel; newC++) { + int c = newC % channel; + for (int newH = 0; newH < newHeight; newH++) { + for (int newW = 0; newW < newWidth; newW++) { + if (newC < channel) { + outIdx = newC * newHeight * newWidth + newH * newWidth + newW; + imgIdx = c * height * width + newH * 2 * width + newW * 2; + } else if (channel <= newC && newC < channel * 2) { + outIdx = newC * newHeight * newWidth + newH * newWidth + newW; + imgIdx = c * height * width + static_cast((newH + 0.5) * 2 * width) + newW * 2; + } else if (channel * 2 <= newC && newC < channel * 3) { 
+ outIdx = newC * newHeight * newWidth + newH * newWidth + newW; + imgIdx = c * height * width + newH * 2 * width + static_cast((newW + 0.5) * 2); + } else if (channel * 3 <= newC && newC < channel * 4) { + outIdx = newC * newHeight * newWidth + newH * newWidth + newW; + imgIdx = c * height * width + static_cast((newH + 0.5) * 2 * width) + + static_cast((newW + 0.5) * 2); + } else { + LogError << "new channels Out of range."; + return APP_ERR_OK; + } + data[outIdx] = tmp[imgIdx]; + } + } + } + + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::CVMatToTensorBase(float* data, MxBase::TensorBase *tensorBase) { + uint32_t height = MODEL_HEIGHT / 2; + uint32_t width = MODEL_WIDTH / 2; + const uint32_t dataSize = MODEL_CHANNEL * height * width * sizeof(float); + MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_); + MxBase::MemoryData memoryDataSrc(data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC); + + APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Memory malloc failed."; + return ret; + } + std::vector shape = {BATCH_NUM, MODEL_CHANNEL, height, width}; + (*tensorBase) = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32); + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::Inference(const std::vector &inputs, + std::vector *outputs) { + auto dtypes = model_->GetOutputDataType(); + for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) { + std::vector shape = {}; + for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) { + shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]); + } + MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_); + APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor); + if (ret != APP_ERR_OK) { + LogError << "TensorBaseMalloc failed, ret=" << ret << "."; + return ret; + } + (*outputs).push_back(tensor); + } + + MxBase::DynamicInfo dynamicInfo = {}; + dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH; + auto startTime = std::chrono::high_resolution_clock::now(); + APP_ERROR ret = model_->ModelInference(inputs, (*outputs), dynamicInfo); + auto endTime = std::chrono::high_resolution_clock::now(); + double costMs = std::chrono::duration(endTime - startTime).count(); + g_inferCost.push_back(costMs); + if (ret != APP_ERR_OK) { + LogError << "ModelInference failed, ret=" << ret << "."; + return ret; + } + + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::PostProcess(const std::vector& tensors, + std::vector> *objInfos) { + MxBase::ResizedImageInfo imgInfo; + imgInfo.widthOriginal = imageWidth_; + imgInfo.heightOriginal = imageHeight_; + imgInfo.widthResize = MODEL_WIDTH; + imgInfo.heightResize = MODEL_HEIGHT; + imgInfo.resizeType = MxBase::RESIZER_STRETCHING; + std::vector imageInfoVec = {}; + imageInfoVec.push_back(imgInfo); + + APP_ERROR ret = post_->Process(tensors, (*objInfos), imageInfoVec); + if (ret != APP_ERR_OK) { + LogInfo << "Process failed, ret=" << ret << "."; + return ret; + } + + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::WriteResult(const std::vector> &objInfos, + const std::string &imgPath, std::vector *jsonText) { + uint32_t batchSize = objInfos.size(); + + int pos = imgPath.rfind('/'); + std::string fileName(imgPath, pos + 1); + fileName = fileName.substr(0, fileName.find('.')); + // write inference result into file + int image_id = std::stoi(fileName), cnt = 0; + for 
(uint32_t i = 0; i < batchSize; i++) { + for (auto &obj : objInfos[i]) { + jsonText->push_back("{\"image_id\": " + std::to_string(image_id) + ", \"category_id\": " + + std::to_string(coco_class_nameid[static_cast(obj.classId)]) + ", \"bbox\": [" + + std::to_string(obj.x0) + ", " + std::to_string(obj.y0) + ", " + + std::to_string(obj.x1 - obj.x0) + ", " + std::to_string(obj.y1 - obj.y0) + "], " + + "\"score\": " + std::to_string(obj.confidence) + "}"); + cnt++; + } + } + return APP_ERR_OK; +} + +APP_ERROR Yolov5Detection::Process(const std::string &imgPath, std::vector *jsonText) { + // process image + cv::Mat imageMat; + APP_ERROR ret = ReadImage(imgPath, &imageMat); + if (ret != APP_ERR_OK) { + LogError << "ReadImage failed, ret=" << ret << "."; + return ret; + } + + ret = Resize(&imageMat, &imageMat); + if (ret != APP_ERR_OK) { + LogError << "Resize failed, ret=" << ret << "."; + return ret; + } + + float data[DATA_SIZE]; + ret = Focus(imageMat, data); + + if (ret != APP_ERR_OK) { + LogError << "Focus failed, ret=" << ret << "."; + return ret; + } + + std::vector inputs = {}; + std::vector outputs = {}; + MxBase::TensorBase tensorBase; + ret = CVMatToTensorBase(data, &tensorBase); + + if (ret != APP_ERR_OK) { + LogError << "CVMatToTensorBase failed, ret=" << ret << "."; + return ret; + } + + inputs.push_back(tensorBase); + ret = Inference(inputs, &outputs); + if (ret != APP_ERR_OK) { + LogError << "Inference failed, ret=" << ret << "."; + return ret; + } + + + std::vector> objInfos; + ret = PostProcess(outputs, &objInfos); + if (ret != APP_ERR_OK) { + LogError << "PostProcess failed, ret=" << ret << "."; + return ret; + } + + ret = WriteResult(objInfos, imgPath, jsonText); + if (ret != APP_ERR_OK) { + LogError << "WriteResult failed, ret=" << ret << "."; + return ret; + } + + imageMat.release(); + return APP_ERR_OK; +} diff --git a/community/cv/ADCAM/infer/mxbase/src/Yolov5Detection.h b/community/cv/ADCAM/infer/mxbase/src/Yolov5Detection.h new file mode 100644 index 0000000000000000000000000000000000000000..c75705fa262d6de9bedc55f3e7010f5da129956e --- /dev/null +++ b/community/cv/ADCAM/infer/mxbase/src/Yolov5Detection.h @@ -0,0 +1,83 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef YOLOV5_RECOGNITION_H +#define YOLOV5_RECOGNITION_H + +#include +#include +#include +#include + +#include + +#include "MxBase/DvppWrapper/DvppWrapper.h" +#include "MxBase/ModelInfer/ModelInferenceProcessor.h" +#include "MxBase/DeviceManager/DeviceManager.h" +#include "MxBase/Tensor/TensorContext/TensorContext.h" +#include "PostProcess/Yolov5MindSporePost.h" +#include "MxBase/PostProcessBases/TextObjectPostProcessBase.h" + + +extern std::vector g_inferCost; + +struct InitParam { + uint32_t deviceId; + std::string labelPath; + bool checkTensor; + std::string modelPath; + uint32_t classNum; + uint32_t biasesNum; + std::string biases; + std::string objectnessThresh; + std::string iouThresh; + std::string scoreThresh; + uint32_t yoloType; + uint32_t modelType; + uint32_t inputType; + uint32_t anchorDim; +}; + +class Yolov5Detection { + public: + APP_ERROR Init(const InitParam &initParam); + APP_ERROR DeInit(); + APP_ERROR ReadImage(const std::string &imgPath, cv::Mat *imageMat); + APP_ERROR Resize(cv::Mat *srcImageMat, cv::Mat *dstImageMat); + APP_ERROR WhcToChw(const cv::Mat &srcImageMat, std::vector *imgData); + APP_ERROR Focus(const cv::Mat &srcImageMat, float* data); + APP_ERROR CVMatToTensorBase(float* data, MxBase::TensorBase *tensorBase); + APP_ERROR LoadLabels(const std::string &labelPath, std::map *labelMap); + APP_ERROR Inference(const std::vector &inputs, std::vector *outputs); + APP_ERROR PostProcess(const std::vector& tensors, + std::vector> *objInfos); + APP_ERROR WriteResult(const std::vector> &objInfos, + const std::string &imgPath, std::vector *jsonText); + APP_ERROR Process(const std::string &imgPath, std::vector *jsonText); + // get infer time + double GetInferCostMilliSec() const {return inferCostTimeMilliSec;} + + private: + std::shared_ptr model_; + std::shared_ptr post_; + MxBase::ModelDesc modelDesc_; + std::map labelMap_; + uint32_t deviceId_ = 0; + uint32_t imageWidth_ = 0; + uint32_t imageHeight_ = 0; + // infer time + double inferCostTimeMilliSec = 0.0; +}; +#endif diff --git a/community/cv/ADCAM/infer/mxbase/src/main.cpp b/community/cv/ADCAM/infer/mxbase/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6bb0c76fb04b000c2e3aafbca2e20aba4d552896 --- /dev/null +++ b/community/cv/ADCAM/infer/mxbase/src/main.cpp @@ -0,0 +1,154 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include "Yolov5Detection.h" +#include "MxBase/Log/Log.h" + +std::vector g_inferCost; + +void ShowUsage() { + LogInfo << "Usage : ./yolov5 <--image or --dir> [Option]" << std::endl; + LogInfo << "Options :" << std::endl; + LogInfo << " --image infer_image_path the path of single infer image, such as " + "./yolov5 --image /home/infer/images/test.jpg." << std::endl; + LogInfo << " --dir infer_image_dir the dir of batch infer images, such as " + "./yolov5 --dir /home/infer/images." 
<< std::endl; + return; +} + +void InitYolov4TinyParam(InitParam *initParam) { + initParam->deviceId = 0; + initParam->labelPath = "../../data/models/coco2017.names"; + initParam->checkTensor = true; + initParam->modelPath = "../../data/models/yolov5.om"; + initParam->classNum = 80; + initParam->biasesNum = 18; + initParam->biases = "12,16,19,36,40,28,36,75,76,55,72,146,142,110,192,243,459,401"; + initParam->objectnessThresh = "0.001"; + initParam->iouThresh = "0.6"; + initParam->scoreThresh = "0.001"; + initParam->yoloType = 3; + initParam->modelType = 0; + initParam->inputType = 0; + initParam->anchorDim = 3; +} + +APP_ERROR saveResult(const std::vector &jsonText, const std::string &savePath) { + // create result directory when it does not exit + std::string resultPath = savePath; + if (access(resultPath.c_str(), 0) != 0) { + int ret = mkdir(resultPath.c_str(), S_IRUSR | S_IWUSR | S_IXUSR); + if (ret != 0) { + LogError << "Failed to create result directory: " << resultPath << ", ret = " << ret; + return APP_ERR_COMM_OPEN_FAIL; + } + } + // create result file under result directory + resultPath = resultPath + "/predict.json"; + std::ofstream tfile(resultPath, std::ofstream::out|std::ofstream::trunc); + if (tfile.fail()) { + LogError << "Failed to open result file: " << resultPath; + return APP_ERR_COMM_OPEN_FAIL; + } + tfile << "["; + for (uint32_t i = 0; i < jsonText.size(); i++) { + tfile << jsonText[i]; + if (i != jsonText.size() - 1) tfile << ", "; + } + tfile << "]"; + tfile.close(); + + return APP_ERR_OK; +} + +APP_ERROR ReadImagesPath(const std::string& dir, std::vector *imagesPath) { + DIR *dirPtr = opendir(dir.c_str()); + if (dirPtr == nullptr) { + LogError << "opendir failed. dir: " << dir; + return APP_ERR_INTERNAL_ERROR; + } + dirent *direntPtr = nullptr; + while ((direntPtr = readdir(dirPtr)) != nullptr) { + std::string fileName = direntPtr->d_name; + if (fileName == "." || fileName == "..") { + continue; + } + (*imagesPath).emplace_back(dir + "/" + fileName); + } + closedir(dirPtr); + return APP_ERR_OK; +} + +int main(int argc, char* argv[]) { + if (argc != 3) { + LogInfo << "Please use as follows." << std::endl; + ShowUsage(); + return APP_ERR_OK; + } + + std::string option = argv[1]; + std::string imgPath = argv[2]; + + if (option != "--image" && option != "--dir") { + LogInfo << "Please use as follows." 
<< std::endl; + ShowUsage(); + return APP_ERR_OK; + } + + InitParam initParam = {}; + InitYolov4TinyParam(&initParam); + auto yolov5 = std::make_shared(); + APP_ERROR ret = yolov5->Init(initParam); + if (ret != APP_ERR_OK) { + LogInfo << "Yolov5DetectionOpencv init failed, ret=" << ret << "."; + return ret; + } + LogInfo << "End to Init yolov5."; + + std::vector imagesPath; + if (option == "--image") { + imagesPath.emplace_back(imgPath); + } else { + ret = ReadImagesPath(imgPath, &imagesPath); + } + + if (ret != APP_ERR_OK) { + LogInfo << "read file failed, ret=" << ret << "."; + return ret; + } + LogInfo << "read file success."; + std::vector jsonText; + for (auto path : imagesPath) { + LogInfo << "read image path " << path; + yolov5->Process(path, &jsonText); + } + + std::string resultPathName = "../result/"; + saveResult(jsonText, resultPathName); + + yolov5->DeInit(); + double costSum = 0; + for (uint32_t i = 0; i < g_inferCost.size(); i++) { + costSum += g_inferCost[i]; + } + LogInfo << "Infer images sum " << g_inferCost.size() << ", cost total time: " << costSum << " ms."; + LogInfo << "The throughput: " << g_inferCost.size() * 1000 / costSum << " images/sec."; + return APP_ERR_OK; +} diff --git a/community/cv/ADCAM/infer/requirements.txt b/community/cv/ADCAM/infer/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1e199daef6f7f5b00d9a3af9f6d5f20c4e65bffb --- /dev/null +++ b/community/cv/ADCAM/infer/requirements.txt @@ -0,0 +1,4 @@ +numpy +pillow +opencv-python +pycocotools >= 2.0.5 \ No newline at end of file diff --git a/community/cv/ADCAM/infer/sdk/api/infer.py b/community/cv/ADCAM/infer/sdk/api/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..01e6a44e772ebd9618d75a0a376001796be46a41 --- /dev/null +++ b/community/cv/ADCAM/infer/sdk/api/infer.py @@ -0,0 +1,160 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Inference Api +""" +import json +import logging +from config import config as cfg +import MxpiDataType_pb2 as MxpiDataType +from StreamManagerApi import ( + StreamManagerApi, + MxDataInput, + InProtobufVector, + MxProtobufIn, +) + + +class SdkApi: + """ + Manage pieline stream + """ + + INFER_TIMEOUT = cfg.INFER_TIMEOUT + STREAM_NAME = cfg.STREAM_NAME + + def __init__(self, pipeline_cfg): + """ + Parameter initialization + """ + self.pipeline_cfg = pipeline_cfg + self._stream_api = None + self._data_input = None + self._device_id = None + + def init(self): + """ + Stream initialization + """ + with open(self.pipeline_cfg, "r") as fp: + self._device_id = int( + json.loads(fp.read())[self.STREAM_NAME]["stream_config"]["deviceId"] + ) + + print(f"The device id: {self._device_id}.") + + # create api + self._stream_api = StreamManagerApi() + + # init stream mgr + ret = self._stream_api.InitManager() + if ret != 0: + print(f"Failed to init stream manager, ret={ret}.") + return False + + # create streams + with open(self.pipeline_cfg, "rb") as fp: + pipe_line = fp.read() + + ret = self._stream_api.CreateMultipleStreams(pipe_line) + if ret != 0: + print(f"Failed to create stream, ret={ret}.") + return False + + self._data_input = MxDataInput() + return True + + def __del__(self): + if not self._stream_api: + return + + self._stream_api.DestroyAllStreams() + + def send_data_input(self, stream_name, plugin_id, input_data): + data_input = MxDataInput() + data_input.data = input_data + unique_id = self._stream_api.SendData(stream_name, plugin_id, data_input) + if unique_id < 0: + logging.error("Fail to send data to stream.") + return False + return True + + def get_protobuf(self, stream_name, plugin_id, keyVec): + result = self._stream_api.GetProtobuf(stream_name, plugin_id, keyVec) + return result + + def _send_protobuf(self, stream_name, plugin_id, element_name, buf_type, pkg_list): + """ + Input image data + """ + protobuf = MxProtobufIn() + protobuf.key = element_name.encode("utf-8") + protobuf.type = buf_type + protobuf.protobuf = pkg_list.SerializeToString() + protobuf_vec = InProtobufVector() + protobuf_vec.push_back(protobuf) + err_code = self._stream_api.SendProtobuf(stream_name, plugin_id, protobuf_vec) + if err_code != 0: + logging.error( + "Failed to send data to stream, stream_name(%s), plugin_id(%s), element_name(%s), " + "buf_type(%s), err_code(%s).", + stream_name, + plugin_id, + element_name, + buf_type, + err_code, + ) + return False + return True + + def send_img_input(self, stream_name, plugin_id, element_name, input_data, img_size): + """ + input image data after preprocess + """ + vision_list = MxpiDataType.MxpiVisionList() + vision_vec = vision_list.visionVec.add() + vision_vec.visionInfo.format = 1 + vision_vec.visionInfo.width = img_size[1] + vision_vec.visionInfo.height = img_size[0] + vision_vec.visionInfo.widthAligned = img_size[1] + vision_vec.visionInfo.heightAligned = img_size[0] + vision_vec.visionData.memType = 0 + vision_vec.visionData.dataStr = input_data + vision_vec.visionData.dataSize = len(input_data) + + buf_type = b"MxTools.MxpiVisionList" + return self._send_protobuf( + stream_name, plugin_id, element_name, buf_type, vision_list + ) + + def send_tensor_input(self, stream_name, plugin_id, element_name, input_data, input_shape, data_type): + """ + get image tensor + """ + tensor_list = MxpiDataType.MxpiTensorPackageList() + tensor_pkg = tensor_list.tensorPackageVec.add() + # 
init tensor vector + tensor_vec = tensor_pkg.tensorVec.add() + tensor_vec.deviceId = self._device_id + tensor_vec.memType = 0 + tensor_vec.tensorShape.extend(input_shape) + tensor_vec.tensorDataType = data_type + tensor_vec.dataStr = input_data + tensor_vec.tensorDataSize = len(input_data) + + buf_type = b"MxTools.MxpiTensorPackageList" + return self._send_protobuf( + stream_name, plugin_id, element_name, buf_type, tensor_list + ) diff --git a/community/cv/ADCAM/infer/sdk/api/postprocess.py b/community/cv/ADCAM/infer/sdk/api/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..33c9f21d9d67418623bccca9153888a8349f7f8e --- /dev/null +++ b/community/cv/ADCAM/infer/sdk/api/postprocess.py @@ -0,0 +1,255 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""post process""" +import sys + +from collections import defaultdict +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + + +class Redirct: + def __init__(self): + self.content = "" + + def write(self, content): + self.content += content + + def flush(self): + self.content = "" + + +class DetectionEngine: + """Detection engine.""" + + def __init__(self, args_detection): + self.ignore_threshold = args_detection.ignore_threshold + self.args = args_detection + self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', + 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', + 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', + 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] + self.num_classes = len(self.labels) + self.results = {} + self.file_path = '' + self.ann_file = args_detection.ann_file + self._coco = COCO(self.ann_file) + self._img_ids = list(sorted(self._coco.imgs.keys())) + self.det_boxes = [] + self.nms_thresh = args_detection.nms_thresh + self.multi_label = args_detection.multi_label + self.multi_label_thresh = args_detection.multi_label_thresh + self.coco_catIds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, + 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 84, 85, 86, 87, 88, 89, 
90] + + + def do_nms_for_results(self): + """Get result boxes.""" + for image_id in self.results: + for clsi in self.results[image_id]: + dets = self.results[image_id][clsi] + dets = np.array(dets) + keep_index = self._diou_nms(dets, thresh=self.nms_thresh) + + keep_box = [{'image_id': int(image_id), 'category_id': int(clsi), + 'bbox': list(dets[i][:4].astype(float)), + 'score': dets[i][4].astype(float)} for i in keep_index] + self.det_boxes.extend(keep_box) + + def _nms(self, predicts, threshold): + """Calculate NMS.""" + # convert xywh -> xmin ymin xmax ymax + x1 = predicts[:, 0] + y1 = predicts[:, 1] + x2 = x1 + predicts[:, 2] + y2 = y1 + predicts[:, 3] + scores = predicts[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + reserved_boxes = [] + while order.size > 0: + i = order[0] + reserved_boxes.append(i) + max_x1 = np.maximum(x1[i], x1[order[1:]]) + max_y1 = np.maximum(y1[i], y1[order[1:]]) + min_x2 = np.minimum(x2[i], x2[order[1:]]) + min_y2 = np.minimum(y2[i], y2[order[1:]]) + + intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1) + intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1) + intersect_area = intersect_w * intersect_h + ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area) + + indexes = np.where(ovr <= threshold)[0] + order = order[indexes + 1] + return reserved_boxes + + def _diou_nms(self, dets, thresh=0.5): + """ + convert xywh -> xmin ymin xmax ymax + """ + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = x1 + dets[:, 2] + y2 = y1 + dets[:, 3] + scores = dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + center_x1 = (x1[i] + x2[i]) / 2 + center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2 + center_y1 = (y1[i] + y2[i]) / 2 + center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2 + inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2 + out_max_x = np.maximum(x2[i], x2[order[1:]]) + out_max_y = np.maximum(y2[i], y2[order[1:]]) + out_min_x = np.minimum(x1[i], x1[order[1:]]) + out_min_y = np.minimum(y1[i], y1[order[1:]]) + outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2 + diou = ovr - inter_diag / outer_diag + diou = np.clip(diou, -1, 1) + inds = np.where(diou <= thresh)[0] + order = order[inds + 1] + return keep + + def write_result(self): + """Save result to file.""" + import json + try: + self.file_path = self.args.result_files + '/predict' + '.json' + f = open(self.file_path, 'w') + json.dump(self.det_boxes, f) + except IOError as e: + raise RuntimeError("Unable to open json file to dump. 
What(): {}".format(str(e))) + else: + f.close() + return self.file_path + + def get_eval_result(self): + """Get eval result.""" + coco_gt = COCO(self.ann_file) + coco_dt = coco_gt.loadRes(self.file_path) + coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') + coco_eval.evaluate() + coco_eval.accumulate() + rdct = Redirct() + stdout = sys.stdout + sys.stdout = rdct + coco_eval.summarize() + sys.stdout = stdout + return rdct.content + + def detect(self, outputs, batch, img_shape, image_id): + """Detect boxes.""" + outputs_num = len(outputs) + # output [|32, 52, 52, 3, 85| ] + for batch_id in range(batch): + for out_id in range(outputs_num): + # 32, 52, 52, 3, 85 + out_item = outputs[out_id] + # 52, 52, 3, 85 + out_item_single = out_item[batch_id, :] + # get number of items in one head, [B, gx, gy, anchors, 5+80] + dimensions = out_item_single.shape[:-1] + out_num = 1 + for d in dimensions: + out_num *= d + ori_w, ori_h = img_shape[batch_id] + img_id = int(image_id[batch_id]) + x = out_item_single[..., 0] * ori_w + y = out_item_single[..., 1] * ori_h + w = out_item_single[..., 2] * ori_w + h = out_item_single[..., 3] * ori_h + + conf = out_item_single[..., 4:5] + cls_emb = out_item_single[..., 5:] + cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1) + x = x.reshape(-1) + y = y.reshape(-1) + w = w.reshape(-1) + h = h.reshape(-1) + x_top_left = x - w / 2. + y_top_left = y - h / 2. + cls_emb = cls_emb.reshape(-1, self.num_classes) + if self.multi_label: + conf = conf.reshape(-1, 1) + # create all False + confidence = cls_emb * conf + flag = cls_emb > self.multi_label_thresh + flag = flag.nonzero() + for index in range(len(flag[0])): + i = flag[0][index] + j = flag[1][index] + confi = confidence[i][j] + if confi < self.ignore_threshold: + continue + if img_id not in self.results: + self.results[img_id] = defaultdict(list) + x_lefti = max(0, x_top_left[i]) + y_lefti = max(0, y_top_left[i]) + wi = min(w[i], ori_w) + hi = min(h[i], ori_h) + clsi = j + # transform catId to match coco + coco_clsi = self.coco_catIds[clsi] + self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi]) + else: + cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1) + conf = conf.reshape(-1) + cls_argmax = cls_argmax.reshape(-1) + + # create all False + flag = np.random.random(cls_emb.shape) > sys.maxsize + for i in range(flag.shape[0]): + c = cls_argmax[i] + flag[i, c] = True + confidence = cls_emb[flag] * conf + + for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, + cls_argmax): + if confi < self.ignore_threshold: + continue + if img_id not in self.results: + self.results[img_id] = defaultdict(list) + x_lefti = max(0, x_lefti) + y_lefti = max(0, y_lefti) + wi = min(wi, ori_w) + hi = min(hi, ori_h) + # transform catId to match coco + coco_clsi = self.coco_catIds[clsi] + self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi]) diff --git a/community/cv/ADCAM/infer/sdk/config/config.py b/community/cv/ADCAM/infer/sdk/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4bce8053eadf098cbecd9d3ac5bc92d0a718ca31 --- /dev/null +++ b/community/cv/ADCAM/infer/sdk/config/config.py @@ -0,0 +1,28 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Inference parameter configuration +""" +MODEL_WIDTH = 640 +MODEL_HEIGHT = 640 +NUM_CLASSES = 80 +SCORE_THRESH = 0.3 +STREAM_NAME = "im_yolov5" + +INFER_TIMEOUT = 100000 + +TENSOR_DTYPE_FLOAT32 = 0 +TENSOR_DTYPE_FLOAT16 = 1 +TENSOR_DTYPE_INT8 = 2 diff --git a/community/cv/ADCAM/infer/sdk/config/yolov5.pipeline b/community/cv/ADCAM/infer/sdk/config/yolov5.pipeline new file mode 100644 index 0000000000000000000000000000000000000000..34e2ac343b84f423a3b44f1d423ac915f79b635f --- /dev/null +++ b/community/cv/ADCAM/infer/sdk/config/yolov5.pipeline @@ -0,0 +1,30 @@ +{ + "im_yolov5": { + "stream_config": { + "deviceId": "0" + }, + "appsrc0": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_tensorinfer0" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource": "appsrc0", + "modelPath": "../data/models/yolov5.om", + "waitingTime": "2000", + "outputDeviceId": "-1" + }, + "factory": "mxpi_tensorinfer", + "next": "appsink0" + }, + "appsink0": { + "props": { + "blocksize": "4096000" + }, + "factory": "appsink" + } + } +} diff --git a/community/cv/ADCAM/infer/sdk/eval/eval_by_sdk.py b/community/cv/ADCAM/infer/sdk/eval/eval_by_sdk.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e11c0d2fc8d47ec1eebacee1c525d4ce1a44b0 --- /dev/null +++ b/community/cv/ADCAM/infer/sdk/eval/eval_by_sdk.py @@ -0,0 +1,40 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""eval_by_sdk""" +import argparse +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + + +def get_eval_result(ann_file, result_file): + """Get eval result.""" + coco_gt = COCO(ann_file) + coco_dt = coco_gt.loadRes(result_file) + coco_eval = COCOeval(coco_gt, coco_dt, "bbox") + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="yolov5 eval") + parser.add_argument("--ann_file", type=str, default="", help="path to annotation") + parser.add_argument( + "--result_file", type=str, default="", help="path to annotation" + ) + + args = parser.parse_args() + + get_eval_result(args.ann_file, args.result_file) diff --git a/community/cv/ADCAM/infer/sdk/main.py b/community/cv/ADCAM/infer/sdk/main.py new file mode 100644 index 0000000000000000000000000000000000000000..dc5c9bf7f19b9c68187b22f081c6eb7d3aae9c6d --- /dev/null +++ b/community/cv/ADCAM/infer/sdk/main.py @@ -0,0 +1,182 @@ +# !/usr/bin/env python +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Sdk internece +""" +import argparse +import os +import time +import ast +import numpy as np + +from PIL import Image +from pycocotools.coco import COCO + +from api.infer import SdkApi +from api.postprocess import DetectionEngine +import MxpiDataType_pb2 as MxpiDataType +from StreamManagerApi import StringVector +from config import config as cfg + + +def parser_args(): + """ + configuration parameter, input from outside + """ + parser = argparse.ArgumentParser(description="yolov5 inference") + parser.add_argument( + "--pipeline_path", + type=str, + required=False, + default="config/yolov5.pipeline", + help="pipeline file path. The default is 'config/centernet.pipeline'. 
", + ) + + parser.add_argument( + "--nms_thresh", type=float, default=0.6, help="threshold for NMS" + ) + parser.add_argument("--ann_file", type=str, default="", help="path to annotation") + parser.add_argument( + "--ignore_threshold", + type=float, + default=0.001, + help="threshold to throw low quality boxes", + ) + + parser.add_argument( + "--dataset_path", type=str, default="", help="path of image dataset" + ) + parser.add_argument( + "--result_files", + type=str, + default="./result", + help="path to 310 infer result path", + ) + parser.add_argument( + "--multi_label", + type=ast.literal_eval, + default=True, + help="whether to use multi label", + ) + parser.add_argument( + "--multi_label_thresh", + type=float, + default=0.1, + help="threshold to throw low quality boxes", + ) + + arg = parser.parse_args() + return arg + + +def process_img(img_file): + """ + Preprocessing the images + """ + # Computed from random subset of ImageNet training images + mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) + std = np.array([0.229, 0.224, 0.225], dtype=np.float32) + img = Image.open(img_file).convert("RGB") + img = img.resize((cfg.MODEL_HEIGHT, cfg.MODEL_WIDTH), 0) + img = np.array(img, dtype=np.float32) + img = img / 255.0 + img = (img - mean) / std + img = img.transpose(2, 0, 1) + img = np.expand_dims(img, 0) + img = np.concatenate( + ( + img[..., ::2, ::2], + img[..., 1::2, ::2], + img[..., ::2, 1::2], + img[..., 1::2, 1::2], + ), + axis=1, + ) + + return img + + +def image_inference(pipeline_path, stream_name, img_dir, detection): + """ + image inference: get inference for images + """ + sdk_api = SdkApi(pipeline_path) + if not sdk_api.init(): + return + + img_data_plugin_id = 0 + print(f"\nBegin to inference for {img_dir}.\n") + + file_list = os.listdir(img_dir) + coco = COCO(args.ann_file) + start_time = time.time() + + for file_name in file_list: + if not file_name.lower().endswith((".jpg", "jpeg")): + continue + + img_ids_name = file_name.split(".")[0] + img_id_ = int(np.squeeze(img_ids_name)) + imgIds = coco.getImgIds(imgIds=[img_id_]) + img = coco.loadImgs(imgIds[np.random.randint(0, len(imgIds))])[0] + image_shape = ((img["width"], img["height"]),) + img_id_ = (np.squeeze(img_ids_name),) + + imgs = process_img(os.path.join(img_dir, file_name)) + sdk_api.send_tensor_input( + stream_name, + img_data_plugin_id, + "appsrc0", + imgs.tobytes(), + imgs.shape, + cfg.TENSOR_DTYPE_FLOAT32, + ) + + keys = [b"mxpi_tensorinfer0"] + keyVec = StringVector() + for key in keys: + keyVec.push_back(key) + infer_result = sdk_api.get_protobuf(stream_name, 0, keyVec) + + result = MxpiDataType.MxpiTensorPackageList() + result.ParseFromString(infer_result[0].messageBuf) + output_small = np.frombuffer( + result.tensorPackageVec[0].tensorVec[0].dataStr, dtype="float32" + ).reshape((1, 20, 20, 3, 85)) + output_me = np.frombuffer( + result.tensorPackageVec[0].tensorVec[1].dataStr, dtype="float32" + ).reshape((1, 40, 40, 3, 85)) + output_big = np.frombuffer( + result.tensorPackageVec[0].tensorVec[2].dataStr, dtype="float32" + ).reshape((1, 80, 80, 3, 85)) + print("process {}...".format(file_name)) + detection.detect([output_small, output_me, output_big], 1, image_shape, img_id_) + + print("do_nms_for_results...") + detection.do_nms_for_results() + detection.write_result() + + cost_time = time.time() - start_time + print("testing cost time {:.2f}h".format(cost_time / 3600.0)) + + +if __name__ == "__main__": + args = parser_args() + detections = DetectionEngine(args) + stream_name0 = 
cfg.STREAM_NAME.encode("utf-8") + print("stream_name0:") + print(stream_name0) + image_inference(args.pipeline_path, stream_name0, args.dataset_path, detections) diff --git a/community/cv/ADCAM/infer/sdk/run.sh b/community/cv/ADCAM/infer/sdk/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..e685af2595412a2919db9aca294b39dc795b42a3 --- /dev/null +++ b/community/cv/ADCAM/infer/sdk/run.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +pipeline_path=./config/yolov5.pipeline +dataset_path=../data/image/ +ann_file=../data/instances_val2017.json +result_files=./result +# help message +if [[ $1 == --help || $1 == -h ]];then + echo "usage:bash ./run.sh " + echo "parameter explain: + --pipeline_path set SDK infer pipeline, e.g. --pipeline_path=./config/yolov5.pipeline + --dataset_path root path of processed images, e.g. --dataset_path=../data/image + --ann_file the folder to save the semantic mask images, default: --ann_file=./result + -h/--help show help message + " + exit 1 +fi + +for para in "$@" +do + if [[ $para == --pipeline_path* ]];then + pipeline_path=`echo ${para#*=}` + elif [[ $para == --dataset_path* ]];then + dataset_path=`echo ${para#*=}` + elif [[ $para == --ann_file* ]];then + ann_file=`echo ${para#*=}` + elif [[ $para == --result_files* ]];then + result_files=`echo ${para#*=}` + fi +done + +if [[ $pipeline_path == "" ]];then + echo "[Error] para \"pipeline_path \" must be config" + exit 1 +fi +if [[ $dataset_path == "" ]];then + echo "[Error] para \"dataset_path \" must be config" + exit 1 +fi +if [[ $ann_file == "" ]];then + echo "[Error] para \"ann_file \" must be config" + exit 1 +fi + +python3 main.py --pipeline_path=$pipeline_path \ + --dataset_path=$dataset_path \ + --ann_file=$ann_file \ + --result_files=$result_files + +exit 0 diff --git a/community/cv/ADCAM/loss.png b/community/cv/ADCAM/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..f290a6d647da80b6b86cb749d861a9bf2495549b Binary files /dev/null and b/community/cv/ADCAM/loss.png differ diff --git a/community/cv/ADCAM/mindspore_hub_conf.py b/community/cv/ADCAM/mindspore_hub_conf.py new file mode 100644 index 0000000000000000000000000000000000000000..ef0ae91a02d12eb31901bae3eb93b49486b58fc5 --- /dev/null +++ b/community/cv/ADCAM/mindspore_hub_conf.py @@ -0,0 +1,22 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MindSpore Hub config file.""" +from src.yolo import YOLOV5s + +def create_network(name, *args, **kwargs): + if name == "yolov5s": + yolov5s_net = YOLOV5s(is_training=True) + return yolov5s_net + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/community/cv/ADCAM/model_utils/__init__.py b/community/cv/ADCAM/model_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40bc337dc255b845ff202691d41a4adc9e7032fa --- /dev/null +++ b/community/cv/ADCAM/model_utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/community/cv/ADCAM/model_utils/config.py b/community/cv/ADCAM/model_utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..933e6b8d9f912e2ae5a6d60961afb5bc5324f1a6 --- /dev/null +++ b/community/cv/ADCAM/model_utils/config.py @@ -0,0 +1,154 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""config""" +import os +import ast +import argparse +from pprint import pformat +import yaml + + +class Config: + """ + Configuration namespace. Convert dictionary to members. + """ + + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. 
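+        choices: Allowed values for each argument, keyed by argument name.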
+ """ + parser = argparse.ArgumentParser( + description="[REPLACE THIS at config.py]", parents=[parser] + ) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = ( + helper[item] + if item in helper + else "Please reference to {}".format(cfg_path) + ) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument( + "--" + item, + type=ast.literal_eval, + default=cfg[item], + choices=choice, + help=help_description, + ) + else: + parser.add_argument( + "--" + item, + type=type(cfg[item]), + default=cfg[item], + choices=choice, + help=help_description, + ) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. + """ + with open(yaml_path, "r") as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + cfg_choices = {} + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + cfg_choices = {} + elif len(cfgs) == 3: + cfg, cfg_helper, cfg_choices = cfgs + else: + raise ValueError( + "At most 3 docs (config, description for help, choices) are supported in config yaml" + ) + print(cfg_helper) + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper, cfg_choices + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument( + "--config_path", + type=str, + default=os.path.join(current_dir, "../default_config.yaml"), + help="Config file path", + ) + path_args, _ = parser.parse_known_args() + default, helper, choices = parse_yaml(path_args.config_path) + args = parse_cli_to_yaml( + parser=parser, + cfg=default, + helper=helper, + choices=choices, + cfg_path=path_args.config_path, + ) + final_config = merge(args, default) + return Config(final_config) + + +config = get_config() diff --git a/community/cv/ADCAM/model_utils/device_adapter.py b/community/cv/ADCAM/model_utils/device_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..26a25cf03a64929cbf1b4b4312997474bc987320 --- /dev/null +++ b/community/cv/ADCAM/model_utils/device_adapter.py @@ -0,0 +1,24 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +# Device adapter for ModelArts +"""Device adapter for ModelArts.""" +from .config import config + +if config.enable_modelarts: + from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = ["get_device_id", "get_device_num", "get_rank_id", "get_job_id"] diff --git a/community/cv/ADCAM/model_utils/local_adapter.py b/community/cv/ADCAM/model_utils/local_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0c2e6ef10ed25215c39992308e7025c10a3131 --- /dev/null +++ b/community/cv/ADCAM/model_utils/local_adapter.py @@ -0,0 +1,35 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""local adapter""" +import os + + +def get_device_id(): + device_id = os.getenv("DEVICE_ID", "0") + return int(device_id) + + +def get_device_num(): + device_num = os.getenv("RANK_SIZE", "1") + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv("RANK_ID", "0") + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/community/cv/ADCAM/model_utils/moxing_adapter.py b/community/cv/ADCAM/model_utils/moxing_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..984a68e0ca4916357cc3e93a1015fb4e9abbf981 --- /dev/null +++ b/community/cv/ADCAM/model_utils/moxing_adapter.py @@ -0,0 +1,208 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""moxing adapter.""" +import os +import functools +import mindspore +from .config import config + +_global_sync_count = 0 + + +def get_device_id(): + device_id = os.getenv("DEVICE_ID", "0") + return int(device_id) + + +def get_device_num(): + device_num = os.getenv("RANK_SIZE", "1") + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv("RANK_ID", "0") + return int(global_rank_id) + + +def get_job_id(): + job_id = os.getenv("JOB_ID") + job_id = job_id if job_id != "" else "default" + return job_id + + +def sync_data(from_path, to_path): + """ + Download data from remote obs to local directory if the first url is remote url and the second one is local path + Upload data from local directory to remote obs in contrast. 
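+    Only the first device on each server performs the copy; the other devices wait on a shared lock file.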
+ """ + import moxing as mox + import time + + global _global_sync_count + sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) + _global_sync_count += 1 + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("from path: ", from_path) + print("to path: ", to_path) + mox.file.copy_parallel(from_path, to_path) + print("===finish data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + print("===save flag===") + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Finish sync data from {} to {}.".format(from_path, to_path)) + + +def modelarts_pre_process(args): + """modelarts pre process function.""" + + def unzip(zip_file, save_dir): + import zipfile + + s_time = time.time() + if not os.path.exists(os.path.join(save_dir, args.modelarts_dataset_unzip_name)): + zip_isexist = zipfile.is_zipfile(zip_file) + if zip_isexist: + fz = zipfile.ZipFile(zip_file, "r") + data_num = len(fz.namelist()) + print("Extract Start...") + print("unzip file num: {}".format(data_num)) + data_print = int(data_num / 100) if data_num > 100 else 1 + i = 0 + for file in fz.namelist(): + if i % data_print == 0: + print( + "unzip percent: {}%".format(int(i * 100 / data_num)), + flush=True, + ) + i += 1 + fz.extract(file, save_dir) + print( + "cost time: {}min:{}s.".format( + int((time.time() - s_time) / 60), + int(int(time.time() - s_time) % 60), + ) + ) + print("Extract Done.") + else: + print("This is not zip.") + else: + print("Zip has been extracted.") + + if args.need_modelarts_dataset_unzip: + zip_file_1 = os.path.join( + args.data_path, args.modelarts_dataset_unzip_name + ".zip" + ) + save_dir_1 = os.path.join(args.data_path) + + sync_lock = "/tmp/unzip_sync.lock" + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("Zip file path: ", zip_file_1) + print("Unzip file save dir: ", save_dir_1) + unzip(zip_file_1, save_dir_1) + print("===Finish extract data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print( + "Device: {}, Finish sync unzip data from {} to {}.".format( + get_device_id(), zip_file_1, save_dir_1 + ) + ) + + args.output_dir = os.path.join(args.output_path, args.output_dir) + args.ckpt_path = os.path.join(args.output_path, args.ckpt_path) + + +def modelarts_post_process(): + sync_data(from_path="/cache/output", to_path="obs://hit-cyf/yolov5_npu/outputs/") + + +def modelarts_export_preprocess(args): + args.file_name = os.path.join(args.output_path, args.file_name) + + +def moxing_wrapper(pre_process=None, post_process=None, **kwargs): + """ + Moxing wrapper to download dataset and upload outputs. 
+ """ + + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print("Dataset downloaded: ", os.listdir(config.data_path)) + if config.checkpoint_url: + sync_data(config.checkpoint_url, config.load_path) + print("Preload downloaded: ", os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print("Workspace downloaded: ", os.listdir(config.output_path)) + + mindspore.set_context( + save_graphs_path=os.path.join( + config.output_path, str(get_rank_id()) + ) + ) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + if "pre_args" in kwargs.keys(): + pre_process(*kwargs["pre_args"]) + else: + pre_process() + + # Run the main function + run_func(*args, **kwargs) + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + if "post_args" in kwargs.keys(): + post_process(*kwargs["post_args"]) + else: + post_process() + + if config.train_url: + print("Start to copy output directory") + sync_data(config.output_path, config.train_url) + + return wrapped_func + + return wrapper diff --git a/community/cv/ADCAM/modelarts/train_start.py b/community/cv/ADCAM/modelarts/train_start.py new file mode 100644 index 0000000000000000000000000000000000000000..d2068505d8bafd98e5ea4d91367afc1cf86f8aab --- /dev/null +++ b/community/cv/ADCAM/modelarts/train_start.py @@ -0,0 +1,202 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""train_start.py""" +import os +import time +import numpy as np +import mindspore +import mindspore.nn as nn +import mindspore.communication as comm +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net +from mindspore import Tensor + +from src.yolo import YOLOV5, YoloWithLossCell, YOLOV5s_Infer +from src.logger import get_logger +from src.util import AverageMeter, get_param_groups, cpu_affinity +from src.lr_scheduler import get_lr +from src.yolo_dataset import create_yolo_dataset +from src.initializer import default_recurisive_init, load_yolov5_params + +from model_utils.config import config +from model_utils.device_adapter import get_device_id +from model_utils.moxing_adapter import moxing_wrapper, modelarts_pre_process + + +mindspore.set_seed(1) + + +def init_distribute(): + comm.init() + config.rank = comm.get_rank() + config.group_size = comm.get_group_size() + mindspore.set_auto_parallel_context( + parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, + gradients_mean=True, + device_num=config.group_size, + ) + + +def train_preprocess(): + """train_preprocess""" + if config.lr_scheduler == "cosine_annealing" and config.max_epoch > config.T_max: + config.T_max = config.max_epoch + + config.lr_epochs = list(map(int, config.lr_epochs.split(","))) + config.data_root = os.path.join(config.data_dir, config.train_img_dir) + config.annFile = os.path.join(config.data_dir, config.train_json_file) + if config.pretrained_checkpoint: + config.pretrained_checkpoint = os.path.join( + config.load_path, config.pretrained_checkpoint + ) + device_id = get_device_id() + mindspore.set_context( + mode=0, device_target=config.device_target, device_id=device_id + ) + + if config.is_distributed: + # init distributed + init_distribute() + + # for promoting performance in GPU device + if config.device_target == "GPU" and config.bind_cpu: + cpu_affinity(config.rank, min(config.group_size, config.device_num)) + + # logger module is managed by config, it is used in other function. e.x. 
config.logger.info("xxx") + config.logger = get_logger(config.output_dir, config.rank) + config.logger.save_args(config) + + +def export_models(ckpt_path): + """export_models""" + config.logger.info("exporting best model....") + dict_version = {"yolov5s": 0, "yolov5m": 1, "yolov5l": 2, "yolov5x": 3} + net = YOLOV5s_Infer( + config.testing_shape[0], version=dict_version[config.yolov5_version] + ) + net.set_train(False) + + outputs_path = os.path.join(config.output_dir, "yolov5") + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + input_arr = Tensor( + np.zeros([1, 12, config.testing_shape[0] // 2, config.testing_shape[1] // 2]), + mindspore.float32, + ) + export(net, input_arr, file_name=outputs_path, file_format=config.file_format) + config.logger.info("export best model finished....") + + +@moxing_wrapper(pre_process=modelarts_pre_process, pre_args=[config]) +def run_train(): + """run_train""" + train_preprocess() + + loss_meter = AverageMeter("loss") + dict_version = {"yolov5s": 0, "yolov5m": 1, "yolov5l": 2, "yolov5x": 3} + network = YOLOV5(is_training=True, version=dict_version[config.yolov5_version]) + # default is kaiming-normal + default_recurisive_init(network) + load_yolov5_params(config, network) + network = YoloWithLossCell(network) + + ds = create_yolo_dataset( + image_dir=config.data_root, + anno_path=config.annFile, + is_training=True, + batch_size=config.per_batch_size, + device_num=config.group_size, + rank=config.rank, + config=config, + ) + config.logger.info("Finish loading dataset") + + steps_per_epoch = ds.get_dataset_size() + lr = get_lr(config, steps_per_epoch) + opt = nn.Momentum( + params=get_param_groups(network), + momentum=config.momentum, + learning_rate=mindspore.Tensor(lr), + weight_decay=config.weight_decay, + loss_scale=config.loss_scale, + ) + network = nn.TrainOneStepCell(network, opt, config.loss_scale // 2) + network.set_train() + + data_loader = ds.create_tuple_iterator(do_copy=False) + first_step = True + t_end = time.time() + + for epoch_idx in range(config.max_epoch): + for step_idx, data in enumerate(data_loader): + images = data[0] + input_shape = images.shape[2:4] + input_shape = mindspore.Tensor(tuple(input_shape[::-1]), mindspore.float32) + loss = network( + images, + data[2], + data[3], + data[4], + data[5], + data[6], + data[7], + input_shape, + ) + loss_meter.update(loss.asnumpy()) + + # it is used for loss, performance output per config.log_interval steps. 
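+            # The first log point computes throughput from a single batch; later
+            # log points average fps and per-step time over config.log_interval steps.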
+ if (epoch_idx * steps_per_epoch + step_idx) % config.log_interval == 0: + time_used = time.time() - t_end + if first_step: + fps = config.per_batch_size * config.group_size / time_used + per_step_time = time_used * 1000 + first_step = False + else: + fps = ( + config.per_batch_size + * config.log_interval + * config.group_size + / time_used + ) + per_step_time = time_used / config.log_interval * 1000 + config.logger.info( + "epoch[{}], iter[{}], {}, fps:{:.2f} imgs/sec, " + "lr:{}, per step time: {}ms".format( + epoch_idx + 1, + step_idx + 1, + loss_meter, + fps, + lr[step_idx], + per_step_time, + ) + ) + t_end = time.time() + loss_meter.reset() + if config.rank == 0: + ckpt_name = os.path.join( + config.output_dir, + "yolov5_{}_{}.ckpt".format(epoch_idx + 1, steps_per_epoch), + ) + mindspore.save_checkpoint(network, ckpt_name) + export_models(ckpt_name) + config.logger.info("==========end training===============") + + if config.enable_modelarts: + import moxing as mox + + mox.file.copy_parallel(src_url=config.output_dir, dst_url=config.outputs_url) + + +if __name__ == "__main__": + run_train() diff --git a/community/cv/ADCAM/postprocess.py b/community/cv/ADCAM/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..89477fb94d20fce281b717b4f0abdd3d1c212ba7 --- /dev/null +++ b/community/cv/ADCAM/postprocess.py @@ -0,0 +1,63 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""post process""" +import os +import time +import numpy as np +from pycocotools.coco import COCO +from src.logger import get_logger +from src.util import DetectionEngine +from model_utils.config import config + + +if __name__ == "__main__": + start_time = time.time() + config.output_dir = config.log_path + config.logger = get_logger(config.output_dir, 0) + + detection = DetectionEngine(config, config.test_ignore_threshold) + + coco = COCO(config.ann_file) + result_path = config.result_files + + files = os.listdir(config.dataset_path) + + for file in files: + img_ids_name = file.split('.')[0] + img_id_ = int(np.squeeze(img_ids_name)) + imgIds = coco.getImgIds(imgIds=[img_id_]) + img = coco.loadImgs(imgIds[np.random.randint(0, len(imgIds))])[0] + image_shape = ((img['width'], img['height']),) + img_id_ = (np.squeeze(img_ids_name),) + + result_path_0 = os.path.join(result_path, img_ids_name + "_0.bin") + result_path_1 = os.path.join(result_path, img_ids_name + "_1.bin") + result_path_2 = os.path.join(result_path, img_ids_name + "_2.bin") + + output_small = np.fromfile(result_path_0, dtype=np.float32).reshape(1, 20, 20, 3, 85) + output_me = np.fromfile(result_path_1, dtype=np.float32).reshape(1, 40, 40, 3, 85) + output_big = np.fromfile(result_path_2, dtype=np.float32).reshape(1, 80, 80, 3, 85) + + detection.detect([output_small, output_me, output_big], config.batch_size, image_shape, img_id_) + + config.logger.info('Calculating mAP...') + detection.do_nms_for_results() + detection.write_result() + eval_result, _ = detection.get_eval_result() + + cost_time = time.time() - start_time + config.logger.info('=============IP102 infer reulst=========') + config.logger.info(eval_result) + config.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) diff --git a/community/cv/ADCAM/requirements.txt b/community/cv/ADCAM/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..74cbc62920a3f893bbd49b13e1a4ccbb9c775d00 --- /dev/null +++ b/community/cv/ADCAM/requirements.txt @@ -0,0 +1,5 @@ +numpy +pillow +opencv-python +pycocotools >= 2.0.5 +onnxruntime-gpu diff --git a/community/cv/ADCAM/scripts/docker_start.sh b/community/cv/ADCAM/scripts/docker_start.sh new file mode 100644 index 0000000000000000000000000000000000000000..09926ada657ea243b35b415c84690e968cf44ba8 --- /dev/null +++ b/community/cv/ADCAM/scripts/docker_start.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +docker_image=$1 +data_dir=$2 +model_dir=$3 + +docker run -it --ipc=host \ + --device=/dev/davinci0 \ + --device=/dev/davinci1 \ + --device=/dev/davinci2 \ + --device=/dev/davinci3 \ + --device=/dev/davinci4 \ + --device=/dev/davinci5 \ + --device=/dev/davinci6 \ + --device=/dev/davinci7 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \ + -v ${model_dir}:${model_dir} \ + -v ${data_dir}:${data_dir} \ + -v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash \ No newline at end of file diff --git a/community/cv/ADCAM/scripts/run_distribute_train.sh b/community/cv/ADCAM/scripts/run_distribute_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..6f83102c65aa187031e6cc78af5f8836d2762be5 --- /dev/null +++ b/community/cv/ADCAM/scripts/run_distribute_train.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ $# != 2 ] +then + echo "Usage: bash run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +RANK_TABLE_FILE=$(get_real_path $2) +echo $DATASET_PATH +echo $RANK_TABLE_FILE + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + +if [ ! -f $RANK_TABLE_FILE ] +then + echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file" +exit 1 +fi + +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$RANK_TABLE_FILE + +cpus=`cat /proc/cpuinfo| grep "processor"| wc -l` +avg=`expr $cpus \/ $DEVICE_NUM` +gap=`expr $avg \- 1` + +for((i=0; i<${DEVICE_NUM}; i++)) +do + start=`expr $i \* $avg` + end=`expr $start \+ $gap` + cmdopt=$start"-"$end + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp ../*.yaml ./train_parallel$i + cp -r ../src ./train_parallel$i + cp -r ../model_utils ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + taskset -c $cmdopt python train.py \ + --data_dir=$DATASET_PATH \ + --is_distributed=1 \ + --lr=0.02 \ + --per_batch_size=16 > log.txt 2>&1 & + cd .. 
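+    # rank $i has been launched from its own train_parallel$i directory; its output goes to train_parallel$i/log.txt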
+done diff --git a/community/cv/ADCAM/src/__init__.py b/community/cv/ADCAM/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40bc337dc255b845ff202691d41a4adc9e7032fa --- /dev/null +++ b/community/cv/ADCAM/src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/community/cv/ADCAM/src/backbone.py b/community/cv/ADCAM/src/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd457dc40656a6200e01e72df2e7b295bcc102a --- /dev/null +++ b/community/cv/ADCAM/src/backbone.py @@ -0,0 +1,251 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""DarkNet model"""
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+
+class CAAttention(nn.Cell):
+    """CAAttention"""
+    def __init__(self, inp, oup, groups=32, act=True):
+        super(CAAttention, self).__init__()
+        mip = max(8, inp // groups)
+
+        self.conv1 = nn.Conv2d(
+            inp, mip, kernel_size=1, stride=1, padding=0, has_bias=True
+        )
+        self.bn1 = nn.BatchNorm2d(mip)
+        self.conv2 = nn.Conv2d(
+            mip, oup, kernel_size=1, stride=1, padding=0, has_bias=True
+        )
+        self.conv3 = nn.Conv2d(
+            mip, oup, kernel_size=1, stride=1, padding=0, has_bias=True
+        )
+        self.relu = (
+            SiLU()
+            if act is True
+            else (act if isinstance(act, nn.Cell) else ops.Identity())
+        )
+
+        self.mean = ops.ReduceMean(keep_dims=True)
+        self.concat = ops.Concat(axis=2)
+        self.transpose = ops.Transpose()
+        self.sigmoid = ops.Sigmoid()
+        self.tile = ops.Tile()
+
+    def construct(self, x):
+        """construct"""
+        identity = x
+        _, _, h, w = x.shape  # x is an NCHW feature map
+
+        x_h = self.mean(x, 3)
+        x_w = self.mean(x, 2)
+        x_w = self.transpose(x_w, (0, 1, 3, 2))
+
+        y = self.concat((x_h, x_w))
+
+        y = self.conv1(y)
+        y = self.bn1(y)
+        y = self.relu(y)
+
+        x_h, x_w = y[:, :, :h, :], y[:, :, h:, :]
+
+        x_w = self.transpose(x_w, (0, 1, 3, 2))
+
+        x_h = self.sigmoid(self.conv2(x_h))
+        x_w = self.sigmoid(self.conv3(x_w))
+
+        x_h = self.tile(x_h, (1, 1, 1, w))
+        x_w = self.tile(x_w, (1, 1, h, 1))
+
+        y = identity * x_h * x_w
+        return y + identity
+
+
+class Bottleneck(nn.Cell):
+    """Standard bottleneck"""
+    # ch_in, ch_out, shortcut, groups, expansion
+    def __init__(self, c1, c2, shortcut=True, e=0.5):
+        super(Bottleneck, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.conv1 = Conv(c1, c_, 1, 1)
+        self.conv2 = Conv(c_, c2, 3, 1)
+        self.add = shortcut and c1 == c2
+
+    def construct(self, x):
+        c1 = self.conv1(x)
+        c2 = self.conv2(c1)
+        out = c2
+        if self.add:
+            out = x + out
+        return out
+
+
+class BottleneckCSP(nn.Cell):
+    """CSP Bottleneck with 3 convolutions"""
+    def __init__(self, c1, c2, n=1, shortcut=True, e=0.5):
+        super(BottleneckCSP, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.conv1 = Conv(c1, c_, 1, 1)
+        self.conv2 = Conv(c1, c_, 1, 1)
+        self.conv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
+        self.m = nn.SequentialCell(
+            [Bottleneck(c_, c_, shortcut, e=1.0) for _ in range(n)]
+        )
+        self.concat = ops.Concat(axis=1)
+
+    def construct(self, x):
+        c1 = self.conv1(x)
+        c2 = self.m(c1)
+        c3 = self.conv2(x)
+        c4 = self.concat((c2, c3))
+        c5 = self.conv3(c4)
+        return c5
+
+
+class BottleneckCSPWithCA(nn.Cell):
+    """CSP Bottleneck with 3 convolutions"""
+    def __init__(self, c1, c2, n=1, shortcut=True, e=0.5):
+        super(BottleneckCSPWithCA, self).__init__()
+        print("BottleneckCSPWithCA........", "c1===", c1, "c2===", c2)
+        c_ = int(c2 * e)  # hidden channels
+        self.conv1 = Conv(c1, c_, 1, 1)
+        self.conv2 = Conv(c1, c_, 1, 1)
+        self.conv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
+        self.m = nn.SequentialCell(
+            [Bottleneck(c_, c_, shortcut, e=1.0) for _ in range(n)]
+        )
+        self.concat = ops.Concat(axis=1)
+        self.coordatt = CAAttention(c1, c2)
+
+    def construct(self, x):
+        c1 = self.conv1(x)
+        c2 = self.m(c1)
+        c3 = self.conv2(x)
+        c4 = self.concat((c2, c3))
+        c5 = self.conv3(c4)
+        c6 = self.coordatt(c5)
+        return c6
+
+
+class SPP(nn.Cell):
+    """Spatial pyramid pooling layer used in YOLOv3-SPP"""
+    def __init__(self, c1, c2, k=(5, 9, 13)):
+        super(SPP, self).__init__()
+        c_ = c1 // 2  # hidden channels
+        self.conv1 = Conv(c1, c_, 1, 1)
+        self.conv2 = Conv(c_ * (len(k)
+ 1), c2, 1, 1) + + self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, pad_mode="same") + self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, pad_mode="same") + self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, pad_mode="same") + self.concat = ops.Concat(axis=1) + + def construct(self, x): + c1 = self.conv1(x) + m1 = self.maxpool1(c1) + m2 = self.maxpool2(c1) + m3 = self.maxpool3(c1) + c4 = self.concat((c1, m1, m2, m3)) + c5 = self.conv2(c4) + return c5 + + +class Focus(nn.Cell): + """Focus wh information into c-space""" + def __init__(self, c1, c2, k=1, s=1, p=None, act=True): + super(Focus, self).__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, act) + + def construct(self, x): + c1 = self.conv(x) + return c1 + + +class SiLU(nn.Cell): + def __init__(self): + super(SiLU, self).__init__() + self.sigmoid = ops.Sigmoid() + + def construct(self, x): + return x * self.sigmoid(x) + + +def auto_pad(k, p=None): + # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +class Conv(nn.Cell): + """Standard convolution""" + def __init__(self, c1, c2, k=1, s=1, p=None, dilation=1, alpha=0.1, + momentum=0.97, eps=1e-3, pad_mode="same", act=True,): + # ch_in, ch_out, kernel, stride, padding + super(Conv, self).__init__() + self.padding = auto_pad(k, p) + self.pad_mode = None + if self.padding == 0: + self.pad_mode = "same" + elif self.padding == 1: + self.pad_mode = "pad" + self.conv = nn.Conv2d( + c1, c2, k, s, padding=self.padding, pad_mode=self.pad_mode, has_bias=False + ) + self.bn = nn.BatchNorm2d(c2, momentum=momentum, eps=eps) + self.act = ( + SiLU() + if act is True + else (act if isinstance(act, nn.Cell) else ops.Identity()) + ) + + def construct(self, x): + return self.act(self.bn(self.conv(x))) + + +class YOLOv5Backbone(nn.Cell): + """YOLOv5 backbone""" + def __init__(self, shape): + super(YOLOv5Backbone, self).__init__() + self.focus = Focus(shape[0], shape[1], k=3, s=1) + self.conv1 = Conv(shape[1], shape[2], k=3, s=2) + self.CSP1 = BottleneckCSP(shape[2], shape[2], n=1 * shape[6]) + self.conv2 = Conv(shape[2], shape[3], k=3, s=2) + self.CSP2 = BottleneckCSP(shape[3], shape[3], n=3 * shape[6]) + self.conv3 = Conv(shape[3], shape[4], k=3, s=2) + self.CSP3 = BottleneckCSP(shape[4], shape[4], n=3 * shape[6]) + self.conv4 = Conv(shape[4], shape[5], k=3, s=2) + self.spp = SPP(shape[5], shape[5], k=[5, 9, 13]) + self.CSP4 = BottleneckCSP(shape[5], shape[5], n=1 * shape[6], shortcut=False) + + def construct(self, x): + """construct method""" + c1 = self.focus(x) + c2 = self.conv1(c1) + c3 = self.CSP1(c2) + c4 = self.conv2(c3) + # out + c5 = self.CSP2(c4) + c6 = self.conv3(c5) + # out + c7 = self.CSP3(c6) + c8 = self.conv4(c7) + c9 = self.spp(c8) + # out + c10 = self.CSP4(c9) + + return c5, c7, c10 diff --git a/community/cv/ADCAM/src/distributed_sampler.py b/community/cv/ADCAM/src/distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..5a729b9ee5de24f2bcbb27d42c4294012a4e1aa0 --- /dev/null +++ b/community/cv/ADCAM/src/distributed_sampler.py @@ -0,0 +1,68 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""distributed_sampler.py""" +from __future__ import division +import math +import numpy as np + + +class DistributedSampler: + """Distributed sampler.""" + + def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + print( + "***********Setting world_size to 1 since it is not passed in ******************" + ) + num_replicas = 1 + if rank is None: + print( + "***********Setting rank to 0 since it is not passed in ******************" + ) + rank = 0 + self.dataset_size = dataset_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + indices = np.random.RandomState(seed=self.epoch).permutation( + self.dataset_size + ) + # np.array type. number from 0 to len(dataset_size)-1, used as + # index of dataset + indices = indices.tolist() + self.epoch += 1 + # change to list type + else: + indices = list(range(self.dataset_size)) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples diff --git a/community/cv/ADCAM/src/initializer.py b/community/cv/ADCAM/src/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..f35fd0b8fb82e2f2f29aeeca39e0603f26061f33 --- /dev/null +++ b/community/cv/ADCAM/src/initializer.py @@ -0,0 +1,90 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""initializer."""
+import math
+import mindspore
+from mindspore import nn
+
+
+def default_recurisive_init(custom_cell):
+    """Initialize parameter."""
+    for _, cell in custom_cell.cells_and_names():
+        if isinstance(cell, (nn.Conv2d, nn.Dense)):
+            cell.weight.set_data(
+                mindspore.common.initializer.initializer(
+                    mindspore.common.initializer.HeUniform(math.sqrt(5)),
+                    cell.weight.shape,
+                    cell.weight.dtype,
+                )
+            )
+
+
+def load_yolov5_params(args, network):
+    """Load yolov5 backbone parameter from checkpoint."""
+    if args.resume_yolov5:
+        param_dict = mindspore.load_checkpoint(args.resume_yolov5)
+        param_dict_new = {}
+        for key, values in param_dict.items():
+            if key.startswith("moments."):
+                continue
+            elif key.startswith("yolo_network."):
+                param_dict_new[key[13:]] = values
+                args.logger.info("in resume {}".format(key))
+            else:
+                param_dict_new[key] = values
+                args.logger.info("in resume {}".format(key))
+
+        args.logger.info("resume finished")
+        mindspore.load_param_into_net(network, param_dict_new)
+        args.logger.info("load_model {} success".format(args.resume_yolov5))
+
+    if args.pretrained_checkpoint:
+        param_dict = mindspore.load_checkpoint(args.pretrained_checkpoint)
+        param_dict_new = {}
+        for key, values in param_dict.items():
+            if key.startswith("moments."):
+                continue
+            elif (key.startswith("yolo_network.") and key[13:] in args.checkpoint_filter_list):
+                args.logger.info("remove {}".format(key))
+                continue
+            elif key.startswith("yolo_network."):
+                param_dict_new[key[13:]] = values
+                args.logger.info("in load {}".format(key))
+            else:
+                param_dict_new[key] = values
+                args.logger.info("in load {}".format(key))
+
+        args.logger.info("pretrained finished")
+        mindspore.load_param_into_net(network, param_dict_new)
+        args.logger.info("load_model {} success".format(args.pretrained_checkpoint))
+
+    if args.pretrained_backbone:
+        param_dict = mindspore.load_checkpoint(args.pretrained_backbone)
+        param_dict_new = {}
+        for key, values in param_dict.items():
+            if key.startswith("moments."):
+                continue
+            elif key.startswith("yolo_network."):
+                param_dict_new[key[13:]] = values
+                args.logger.info("in resume {}".format(key))
+            else:
+                param_dict_new[key] = values
+                args.logger.info("in resume {}".format(key))
+
+        args.logger.info("pretrained finished")
+        mindspore.load_param_into_net(network, param_dict_new)
+        args.logger.info("load_model {} success".format(args.pretrained_backbone))
diff --git a/community/cv/ADCAM/src/logger.py b/community/cv/ADCAM/src/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a9c0915127e430ee27576752db85d3dbbd7f9ca
--- /dev/null
+++ b/community/cv/ADCAM/src/logger.py
@@ -0,0 +1,83 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Logger.""" +import os +import sys +import logging +from datetime import datetime + + +class LOGGER(logging.Logger): + """ + Logger. + + Args: + logger_name: String. Logger name. + rank: Integer. Rank id. + """ + + def __init__(self, logger_name, rank=0): + super(LOGGER, self).__init__(logger_name) + self.rank = rank + if rank % 8 == 0: + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s:%(levelname)s:%(message)s") + console.setFormatter(formatter) + self.addHandler(console) + + def setup_logging_file(self, log_dir, rank=0): + """Setup logging file.""" + self.rank = rank + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_name = datetime.now().strftime( + "%Y-%m-%d_time_%H_%M_%S" + ) + "_rank_{}.log".format(rank) + self.log_fn = os.path.join(log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s:%(levelname)s:%(message)s") + fh.setFormatter(formatter) + self.addHandler(fh) + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO): + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info("Args:") + args_dict = vars(args) + for key in args_dict.keys(): + self.info("--> %s: %s", key, args_dict[key]) + self.info("") + + def important_info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.rank == 0: + line_width = 2 + important_msg = "\n" + important_msg += ("*" * 70 + "\n") * line_width + important_msg += ("*" * line_width + "\n") * 2 + important_msg += "*" * line_width + " " * 8 + msg + "\n" + important_msg += ("*" * line_width + "\n") * 2 + important_msg += ("*" * 70 + "\n") * line_width + self.info(important_msg, *args, **kwargs) + + +def get_logger(path, rank): + """Get Logger.""" + logger = LOGGER("YOLOV5", rank) + logger.setup_logging_file(path, rank) + return logger diff --git a/community/cv/ADCAM/src/loss.py b/community/cv/ADCAM/src/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..96d6f7d5d4e877708b479a47a26c0feb7a912b31 --- /dev/null +++ b/community/cv/ADCAM/src/loss.py @@ -0,0 +1,49 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""loss.""" +import mindspore.ops as ops +import mindspore.nn as nn + + +class ConfidenceLoss(nn.Cell): + """Loss for confidence.""" + + def __init__(self): + super(ConfidenceLoss, self).__init__() + self.cross_entropy = ops.SigmoidCrossEntropyWithLogits() + self.reduce_sum = ops.ReduceSum() + + def construct(self, object_mask, predict_confidence, ignore_mask): + confidence_loss = self.cross_entropy(predict_confidence, object_mask) + confidence_loss = ( + object_mask * confidence_loss + + (1 - object_mask) * confidence_loss * ignore_mask + ) + confidence_loss = self.reduce_sum(confidence_loss, ()) + return confidence_loss + + +class ClassLoss(nn.Cell): + """Loss for classification.""" + + def __init__(self): + super(ClassLoss, self).__init__() + self.cross_entropy = ops.SigmoidCrossEntropyWithLogits() + self.reduce_sum = ops.ReduceSum() + + def construct(self, object_mask, predict_class, class_probs): + class_loss = object_mask * self.cross_entropy(predict_class, class_probs) + class_loss = self.reduce_sum(class_loss, ()) + return class_loss diff --git a/community/cv/ADCAM/src/lr_scheduler.py b/community/cv/ADCAM/src/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..8e14bb62c681ffb8fbd505f6bc24da2b9138ef21 --- /dev/null +++ b/community/cv/ADCAM/src/lr_scheduler.py @@ -0,0 +1,207 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""lr_scheduler.py""" +import math +from collections import Counter + +import numpy as np + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + """Linear learning rate.""" + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """Warmup step learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma ** milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Cosine annealing learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = ( + eta_min + + (base_lr - eta_min) + * (1.0 + math.cos(math.pi * last_epoch / T_max)) + / 2 + ) + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Cosine annealing learning rate V2.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + last_lr = 0 + last_epoch_V1 = 0 + + T_max_V2 = int(max_epoch * 1 / 3) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + if i < total_steps * 2 / 3: + lr = ( + eta_min + + (base_lr - eta_min) + * (1.0 + math.cos(math.pi * last_epoch / T_max)) + / 2 + ) + last_lr = lr + last_epoch_V1 = last_epoch + else: + base_lr = last_lr + last_epoch = last_epoch - last_epoch_V1 + lr = ( + eta_min + + (base_lr - eta_min) + * (1.0 + math.cos(math.pi * last_epoch / T_max_V2)) + / 2 + ) + + lr_each_step.append(lr) + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Warmup cosine annealing learning rate.""" + start_sample_epoch = 60 + step_sample = 2 + tobe_sampled_epoch = 60 + end_sampled_epoch = start_sample_epoch + step_sample * tobe_sampled_epoch + 
max_sampled_epoch = max_epoch + tobe_sampled_epoch + T_max = max_sampled_epoch + + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + total_sampled_steps = int(max_sampled_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + + for i in range(total_sampled_steps): + last_epoch = i // steps_per_epoch + if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample): + continue + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = ( + eta_min + + (base_lr - eta_min) + * (1.0 + math.cos(math.pi * last_epoch / T_max)) + / 2 + ) + lr_each_step.append(lr) + + assert total_steps == len(lr_each_step) + return np.array(lr_each_step).astype(np.float32) + + +def get_lr(args, steps_per_epoch): + """generate learning rate.""" + if args.lr_scheduler == "exponential": + lr = warmup_step_lr( + args.lr, + args.lr_epochs, + steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma, + ) + elif args.lr_scheduler == "cosine_annealing": + lr = warmup_cosine_annealing_lr( + args.lr, + steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min, + ) + elif args.lr_scheduler == "cosine_annealing_V2": + lr = warmup_cosine_annealing_lr_V2( + args.lr, + steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min, + ) + elif args.lr_scheduler == "cosine_annealing_sample": + lr = warmup_cosine_annealing_lr_sample( + args.lr, + steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min, + ) + else: + raise NotImplementedError(args.lr_scheduler) + return lr diff --git a/community/cv/ADCAM/src/transforms.py b/community/cv/ADCAM/src/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..bfbe8149c812fde461854436acd22f5bc0269d0e --- /dev/null +++ b/community/cv/ADCAM/src/transforms.py @@ -0,0 +1,526 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Transforms.""" +import random +import threading +import copy +import numpy as np +from PIL import Image +import cv2 +import mindspore.dataset.vision as vision + + +def _rand(a=0., b=1.): + return np.random.rand() * (b - a) + a + + +def bbox_iou(bbox_a, bbox_b, offset=0): + """Calculate the iou of boxes""" + if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: + raise IndexError("Bounding boxes axis 1 must have at least length 4") + + tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) + br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) + + area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) + area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) + area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) + return area_i / (area_a[:, None] + area_b - area_i) + + +def get_interp_method(interp, sizes=()): + """Get the interpolation method for resize""" + if interp == 9: + if sizes: + assert len(sizes) == 4 + oh, ow, nh, nw = sizes + if nh > oh and nw > ow: + return 2 + if nh < oh and nw < ow: + return 0 + return 1 + return 2 + if interp == 10: + return random.randint(0, 4) + if interp not in (0, 1, 2, 3, 4): + raise ValueError('Unknown interp method %d' % interp) + return interp + + +def pil_image_reshape(interp): + reshape_type = { + 0: Image.NEAREST, + 1: Image.BILINEAR, + 2: Image.BICUBIC, + 3: Image.NEAREST, + 4: Image.LANCZOS, + } + return reshape_type[interp] + + +def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, max_boxes, label_smooth, + label_smooth_factor=0.1, iou_threshold=0.213): + """Preprocess true boxes to training input format""" + anchors = np.array(anchors) + num_layers = anchors.shape[0] // 3 + anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + true_boxes = np.array(true_boxes, dtype='float32') + input_shape = np.array(in_shape, dtype='int32') + boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. + # trans to box center point + boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] + # input_shape is [h, w] + true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] + true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] + # true_boxes = [xywh] + grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] + # grid_shape [h, w] + y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), + 5 + num_classes), dtype='float32') for l in range(num_layers)] + # y_true [gridy, gridx] + anchors = np.expand_dims(anchors, 0) + anchors_max = anchors / 2. + anchors_min = -anchors_max + valid_mask = boxes_wh[..., 0] > 0 + wh = boxes_wh[valid_mask] + if wh.size != 0: + wh = np.expand_dims(wh, -2) + # wh shape[box_num, 1, 2] + boxes_max = wh / 2. + boxes_min = -boxes_max + intersect_min = np.maximum(boxes_min, anchors_min) + intersect_max = np.minimum(boxes_max, anchors_max) + intersect_wh = np.maximum(intersect_max - intersect_min, 0.) 
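+        # IoU between each ground-truth box and every anchor, with both centered at
+        # the origin; used below to pick the top-k anchors and the single best
+        # anchor for each box.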
+ intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] + box_area = wh[..., 0] * wh[..., 1] + anchor_area = anchors[..., 0] * anchors[..., 1] + iou = intersect_area / (box_area + anchor_area - intersect_area) + + # topk iou + topk = 4 + topk_flag = iou.argsort() + topk_flag = topk_flag >= topk_flag.shape[1] - topk + flag = topk_flag.nonzero() + for index in range(len(flag[0])): + t = flag[0][index] + n = flag[1][index] + if iou[t][n] < iou_threshold: + continue + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x + + k = anchor_mask[l].index(n) + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + + # lable-smooth + if label_smooth: + sigma = label_smooth_factor / (num_classes - 1) + y_true[l][j, i, k, 5:] = sigma + y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor + else: + y_true[l][j, i, k, 5 + c] = 1. + # best anchor for gt + best_anchor = np.argmax(iou, axis=-1) + for t, n in enumerate(best_anchor): + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x + + k = anchor_mask[l].index(n) + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + + # lable-smooth + if label_smooth: + sigma = label_smooth_factor / (num_classes - 1) + y_true[l][j, i, k, 5:] = sigma + y_true[l][j, i, k, 5 + c] = 1 - label_smooth_factor + else: + y_true[l][j, i, k, 5 + c] = 1. + + # pad_gt_boxes for avoiding dynamic shape + pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + + mask0 = np.reshape(y_true[0][..., 4:5], [-1]) + gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) + # gt_box [boxes, [x,y,w,h]] + gt_box0 = gt_box0[mask0 == 1] + # gt_box0: get all boxes which have object + if gt_box0.shape[0] < max_boxes: + pad_gt_box0[:gt_box0.shape[0]] = gt_box0 + else: + pad_gt_box0 = gt_box0[:max_boxes] + # gt_box0.shape[0]: total number of boxes in gt_box0 + # top N of pad_gt_box0 is real box, and after are pad by zero + + mask1 = np.reshape(y_true[1][..., 4:5], [-1]) + gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) + gt_box1 = gt_box1[mask1 == 1] + if gt_box1.shape[0] < max_boxes: + pad_gt_box1[:gt_box1.shape[0]] = gt_box1 + else: + pad_gt_box1 = gt_box1[:max_boxes] + + mask2 = np.reshape(y_true[2][..., 4:5], [-1]) + gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) + + gt_box2 = gt_box2[mask2 == 1] + if gt_box2.shape[0] < max_boxes: + pad_gt_box2[:gt_box2.shape[0]] = gt_box2 + else: + pad_gt_box2 = gt_box2[:max_boxes] + return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 + + +class PreprocessTrueBox: + """PreprocessTrueBox.""" + def __init__(self, config): + self.anchor_scales = config.anchor_scales + self.num_classes = config.num_classes + self.max_box = config.max_box + self.label_smooth = config.label_smooth + self.label_smooth_factor = config.label_smooth_factor + + def __call__(self, anno, input_shape): + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=input_shape, + num_classes=self.num_classes, 
max_boxes=self.max_box, + label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor) + return anno, np.array(bbox_true_1), np.array(bbox_true_2), np.array(bbox_true_3), \ + np.array(gt_box1), np.array(gt_box2), np.array(gt_box3) + + +def _reshape_data(image, image_size): + """_reshape_data.""" + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + ori_w, ori_h = image.size + ori_image_shape = np.array([ori_w, ori_h], np.int32) + # original image shape fir:H sec:W + h, w = image_size + interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w)) + image = image.resize((w, h), pil_image_reshape(interp)) + image_data = np.array(image) + if len(image_data.shape) == 2: + image_data = np.expand_dims(image_data, axis=-1) + image_data = np.concatenate([image_data, image_data, image_data], axis=-1) + return image_data, ori_image_shape + + +def color_distortion(img, hue, sat, val, device_num): + """color_distortion.""" + hue = _rand(-hue, hue) + sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) + val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) + if device_num != 1: + cv2.setNumThreads(1) + x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) + x = x / 255. + x[..., 0] += hue + x[..., 0][x[..., 0] > 1] -= 1 + x[..., 0][x[..., 0] < 0] += 1 + x[..., 1] *= sat + x[..., 2] *= val + x[x > 1] = 1 + x[x < 0] = 0 + x = x * 255. + x = x.astype(np.uint8) + image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL) + return image_data + + +def filp_pil_image(img): + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def convert_gray_to_color(img): + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + img = np.concatenate([img, img, img], axis=-1) + return img + + +def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box): + iou = bbox_iou(box, crop_box) + return min_iou <= iou.min() and max_iou >= iou.max() + + +def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints): + """_choose_candidate_by_constraints.""" + if use_constraints: + constraints = ( + (0.1, None), + (0.3, None), + (0.5, None), + (0.7, None), + (0.9, None), + (None, 1), + ) + else: + constraints = ((None, None),) + # add default candidate + candidates = [(0, 0, input_w, input_h)] + for constraint in constraints: + min_iou, max_iou = constraint + min_iou = -np.inf if min_iou is None else min_iou + max_iou = np.inf if max_iou is None else max_iou + + for _ in range(max_trial): + # box_data should have at least one box + new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter) + scale = _rand(0.5, 2) + + if new_ar < 1: + nh = int(scale * input_h) + nw = int(nh * new_ar) + else: + nw = int(scale * input_w) + nh = int(nw / new_ar) + + dx = int(_rand(0, input_w - nw)) + dy = int(_rand(0, input_h - nh)) + + if box.size > 0: + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + + crop_box = np.array((0, 0, input_w, input_h)) + if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]): + continue + else: + candidates.append((dx, dy, nw, nh)) + else: + raise Exception("!!! 
annotation box is less than 1") + return candidates + + +def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, + image_h, flip, box, box_data, allow_outside_center, max_boxes): + """_correct_bbox_by_candidates.""" + while candidates: + if len(candidates) > 1: + # ignore default candidate which do not crop + candidate = candidates.pop(np.random.randint(1, len(candidates))) + else: + candidate = candidates.pop(np.random.randint(0, len(candidates))) + dx, dy, nw, nh = candidate + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + if flip: + t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]] + + if allow_outside_center: + pass + else: + t_box = t_box[ + np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. >= 0., (t_box[:, 1] + t_box[:, 3]) / 2. >= 0.)] + t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w, + (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)] + + # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero + t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 + # recorrect w,h not higher than input size + t_box[:, 2][t_box[:, 2] > input_w] = input_w + t_box[:, 3][t_box[:, 3] > input_h] = input_h + box_w = t_box[:, 2] - t_box[:, 0] + box_h = t_box[:, 3] - t_box[:, 1] + # discard invalid box: w or h smaller than 1 pixel + t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] + + if t_box.shape[0] > 0: + # break if number of find t_box + box_data[: len(t_box)] = t_box + return box_data, candidate + return np.zeros(shape=[max_boxes, 5], dtype=np.float64), (0, 0, nw, nh) + + +def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, + anchors, num_classes, max_trial=10, device_num=1): + """_data_aug.""" + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + + image_w, image_h = image.size + input_h, input_w = image_input_size + + np.random.shuffle(box) + if len(box) > max_boxes: + box = box[:max_boxes] + flip = _rand() < .5 + box_data = np.zeros((max_boxes, 5)) + + candidates = _choose_candidate_by_constraints(use_constraints=False, max_trial=max_trial, input_w=input_w, + input_h=input_h, image_w=image_w, image_h=image_h, + jitter=jitter, box=box) + box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, input_w=input_w, input_h=input_h, + image_w=image_w, image_h=image_h, flip=flip, box=box, + box_data=box_data, allow_outside_center=True, max_boxes=max_boxes) + dx, dy, nw, nh = candidate + interp = get_interp_method(interp=10) + image = image.resize((nw, nh), pil_image_reshape(interp)) + # place image, gray color as back graoud + new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128)) + new_image.paste(image, (dx, dy)) + image = new_image + + if flip: + image = filp_pil_image(image) + + image = np.array(image) + image = convert_gray_to_color(image) + image_data = color_distortion(image, hue, sat, val, device_num) + return image_data, box_data + + +def preprocess_fn(image, box, config, input_size, device_num): + """preprocess_fn.""" + config_anchors = config.anchor_scales + anchors = np.array([list(x) for x in config_anchors]) + max_boxes = config.max_box + num_classes = config.num_classes + jitter = config.jitter + hue = config.hue + sat = config.saturation + val = config.value + image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val, + image_input_size=input_size, max_boxes=max_boxes, + num_classes=num_classes, anchors=anchors, 
device_num=device_num) + return image, anno + + +def reshape_fn(image, img_id, config): + """reshape_fn.""" + input_size = config.test_img_shape + image, ori_image_shape = _reshape_data(image, image_size=input_size) + return image, ori_image_shape, img_id + + +class MultiScaleTrans: + """MultiScaleTrans.""" + def __init__(self, config, device_num): + self.config = config + self.seed = 0 + self.size_list = [] + self.resize_rate = config.resize_rate + self.dataset_size = config.dataset_size + self.size_dict = {} + self.seed_num = int(1e6) + self.seed_list = self.generate_seed_list(seed_num=self.seed_num) + self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) + self.device_num = device_num + self.anchor_scales = config.anchor_scales + self.num_classes = config.num_classes + self.max_box = config.max_box + self.label_smooth = config.label_smooth + self.label_smooth_factor = config.label_smooth_factor + + def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): + seed_list = [] + random.seed(init_seed) + for _ in range(seed_num): + seed = random.randint(seed_range[0], seed_range[1]) + seed_list.append(seed) + return seed_list + + def __call__(self, img, anno, input_size, mosaic_flag): + if mosaic_flag[0] == 0: + img = vision.Decode(True)(img) + img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) + return img, anno, np.array(img.shape[0:2]) + + +def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3): + """thread_batch_preprocess_true_box.""" + i = 0 + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1[result_index + i] = bbox_true_1 + batch_bbox_true_2[result_index + i] = bbox_true_2 + batch_bbox_true_3[result_index + i] = bbox_true_3 + batch_gt_box1[result_index + i] = gt_box1 + batch_gt_box2[result_index + i] = gt_box2 + batch_gt_box3[result_index + i] = gt_box3 + i = i + 1 + + +def batch_preprocess_true_box(annos, config, input_shape): + """batch_preprocess_true_box.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + batch_gt_box3 = [] + threads = [] + + step = 4 + for index in range(0, len(annos), step): + for _ in range(step): + batch_bbox_true_1.append(None) + batch_bbox_true_2.append(None) + batch_bbox_true_3.append(None) + batch_gt_box1.append(None) + batch_gt_box2.append(None) + batch_gt_box3.append(None) + step_anno = annos[index: index + step] + t = threading.Thread(target=thread_batch_preprocess_true_box, + args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3)) + t.start() + threads.append(t) + + for t in threads: + t.join() + + return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) + + +def batch_preprocess_true_box_single(annos, config, input_shape): + """batch_preprocess_true_box_single.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + 
batch_gt_box3 = [] + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1.append(bbox_true_1) + batch_bbox_true_2.append(bbox_true_2) + batch_bbox_true_3.append(bbox_true_3) + batch_gt_box1.append(gt_box1) + batch_gt_box2.append(gt_box2) + batch_gt_box3.append(gt_box3) + + return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) diff --git a/community/cv/ADCAM/src/util.py b/community/cv/ADCAM/src/util.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4d77d4d25066853b9dacdc8742503bebdc1227 --- /dev/null +++ b/community/cv/ADCAM/src/util.py @@ -0,0 +1,506 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Util class or function""" +import os +import sys +from collections import defaultdict +import datetime +import copy +import json +from typing import Union, List +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +import mindspore +import mindspore.nn as nn +from mindspore import Tensor, ops + +from .yolo import YoloLossBlock + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f', tb_writer=None): + self.name = name + self.fmt = fmt + self.reset() + self.tb_writer = tb_writer + self.cur_step = 1 + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + if self.tb_writer is not None: + self.tb_writer.add_scalar(self.name, self.val, self.cur_step) + self.cur_step += 1 + + def __str__(self): + fmtstr = '{name}:{avg' + self.fmt + '}' + return fmtstr.format(**self.__dict__) + + +def default_wd_filter(x): + """default weight decay filter.""" + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + return False + if parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not + # include BN + return False + if parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not + # include BN + return False + + return True + + +def get_param_groups(network): + """Param groups for optimizer.""" + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + 
elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not + # include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not + # include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, + {'params': decay_params}] + + +class ShapeRecord: + """Log image shape.""" + + def __init__(self): + self.shape_record = { + 416: 0, + 448: 0, + 480: 0, + 512: 0, + 544: 0, + 576: 0, + 608: 0, + 640: 0, + 672: 0, + 704: 0, + 736: 0, + 'total': 0 + } + + def set(self, shape): + if len(shape) > 1: + shape = shape[0] + shape = int(shape) + self.shape_record[shape] += 1 + self.shape_record['total'] += 1 + + def show(self, logger): + for key in self.shape_record: + rate = self.shape_record[key] / float(self.shape_record['total']) + logger.info('shape {}: {:.2f}%'.format(key, rate * 100)) + + +def keep_loss_fp32(network): + """Keep loss of network with float32""" + for _, cell in network.cells_and_names(): + if isinstance(cell, (YoloLossBlock,)): + cell.to_float(mindspore.float32) + + +class Redirct: + def __init__(self): + self.content = "" + + def write(self, content): + self.content += content + + def flush(self): + self.content = "" + + +def cpu_affinity(rank_id, device_num): + """Bind CPU cores according to rank_id and device_num.""" + import psutil + cores = psutil.cpu_count() + if cores < device_num: + return + process = psutil.Process() + used_cpu_num = cores // device_num + rank_id = rank_id % device_num + used_cpu_list = [i for i in range(rank_id * used_cpu_num, (rank_id + 1) * used_cpu_num)] + process.cpu_affinity(used_cpu_list) + print(f"==== {rank_id}/{device_num} ==== bind cpu: {used_cpu_list}") + + +class COCOEvaluator: + """COCO Evaluator.""" + def __init__(self, detection_config) -> None: + self.coco_gt = COCO(detection_config.val_ann_file) + self.coco_catIds = self.coco_gt.getCatIds() + self.coco_imgIds = list(sorted(self.coco_gt.imgs.keys())) + self.coco_transformed_catIds = detection_config.coco_ids + self.logger = detection_config.logger + self.last_mAP = 0.0 + + def get_mAP(self, coco_dt_ann_file: Union[str, List[str]]): + if isinstance(coco_dt_ann_file, str): + return self.get_mAP_single_file(coco_dt_ann_file) + if isinstance(coco_dt_ann_file, list): + return self.get_mAP_multiple_file(coco_dt_ann_file) + raise ValueError("Invalid 'coco_dt_ann_file' type. 
Support str or List[str].") + + def merge_result_files(self, file_path: List[str]) -> List: + """Merge result files.""" + dt_list = [] + dt_ids_set = set([]) + self.logger.info(f"Total {len(file_path)} json files") + self.logger.info(f"File list: {file_path}") + + for path in file_path: + ann_list = [] + try: + with open(path, 'r') as f: + ann_list = json.load(f) + except json.decoder.JSONDecodeError: + pass # json file is empty + else: + ann_ids = set(ann['image_id'] for ann in ann_list) + diff_ids = ann_ids - dt_ids_set + ann_list = [ann for ann in ann_list if ann['image_id'] in diff_ids] + dt_ids_set = dt_ids_set | diff_ids + dt_list.extend(ann_list) + return dt_list + + def get_coco_from_dt_list(self, dt_list) -> COCO: + """Get coco from dt list.""" + cocoDt = COCO() + cocoDt.dataset = {} + cocoDt.dataset['images'] = [img for img in self.coco_gt.dataset['images']] + cocoDt.dataset['categories'] = copy.deepcopy(self.coco_gt.dataset['categories']) + self.logger.info(f"Number of dt boxes: {len(dt_list)}") + for idx, ann in enumerate(dt_list): + bb = ann['bbox'] + x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] + if 'segmentation' not in ann: + ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + ann['area'] = bb[2] * bb[3] + ann['id'] = idx + 1 + ann['iscrowd'] = 0 + cocoDt.dataset['annotations'] = dt_list + cocoDt.createIndex() + return cocoDt + + def get_mAP_multiple_file(self, coco_dt_ann_file: List[str]) -> str: + dt_list = self.merge_result_files(coco_dt_ann_file) + coco_dt = self.get_coco_from_dt_list(dt_list) + return self.compute_coco_mAP(coco_dt) + + def get_mAP_single_file(self, coco_dt_ann_file: str) -> str: + coco_dt = self.coco_gt.loadRes(coco_dt_ann_file) + return self.compute_coco_mAP(coco_dt) + + def compute_coco_mAP(self, coco_dt: COCO) -> str: + """Compute coco mAP.""" + coco_eval = COCOeval(self.coco_gt, coco_dt, 'bbox') + coco_eval.evaluate() + coco_eval.accumulate() + rdct = Redirct() + stdout = sys.stdout + sys.stdout = rdct + coco_eval.summarize() + sys.stdout = stdout + self.last_mAP = coco_eval.stats[0] + return rdct.content + + +class DetectionEngine: + """Detection engine.""" + + def __init__(self, args_detection, threshold): + self.ignore_threshold = threshold + self.labels = args_detection.labels + self.num_classes = len(self.labels) + self.results = {} + self.file_path = '' + self.save_prefix = args_detection.output_dir + self.ann_file = args_detection.val_ann_file + self.det_boxes = [] + self.nms_thresh = args_detection.eval_nms_thresh + self.multi_label = args_detection.multi_label + self.multi_label_thresh = args_detection.multi_label_thresh + + self.logger = args_detection.logger + self.eval_parallel = args_detection.eval_parallel + if self.eval_parallel: + self.save_prefix = args_detection.save_prefix + self.rank_id = args_detection.rank + self.dir_path = '' + self.coco_evaluator = COCOEvaluator(args_detection) + self.coco_catids = self.coco_evaluator.coco_gt.getCatIds() + self.coco_catIds = args_detection.coco_ids + self._img_ids = list(sorted(self.coco_evaluator.coco_gt.imgs.keys())) + + def do_nms_for_results(self): + """Get result boxes.""" + for img_id in self.results: + for clsi in self.results[img_id]: + dets = self.results[img_id][clsi] + dets = np.array(dets) + keep_index = self._diou_nms(dets, thresh=self.nms_thresh) + + keep_box = [{'image_id': int(img_id), 'category_id': int(clsi), + 'bbox': list(dets[i][:4].astype(float)), + 'score': dets[i][4].astype(float)} for i in keep_index] + self.det_boxes.extend(keep_box) + + def 
_nms(self, predicts, threshold): + """Calculate NMS.""" + # convert xywh -> xmin ymin xmax ymax + x1 = predicts[:, 0] + y1 = predicts[:, 1] + x2 = x1 + predicts[:, 2] + y2 = y1 + predicts[:, 3] + scores = predicts[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + reserved_boxes = [] + while order.size > 0: + i = order[0] + reserved_boxes.append(i) + max_x1 = np.maximum(x1[i], x1[order[1:]]) + max_y1 = np.maximum(y1[i], y1[order[1:]]) + min_x2 = np.minimum(x2[i], x2[order[1:]]) + min_y2 = np.minimum(y2[i], y2[order[1:]]) + + intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1) + intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1) + intersect_area = intersect_w * intersect_h + ovr = intersect_area / \ + (areas[i] + areas[order[1:]] - intersect_area) + + indexes = np.where(ovr <= threshold)[0] + order = order[indexes + 1] + return reserved_boxes + + def _diou_nms(self, dets, thresh=0.5): + """ + convert xywh -> xmin ymin xmax ymax + """ + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = x1 + dets[:, 2] + y2 = y1 + dets[:, 3] + scores = dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + center_x1 = (x1[i] + x2[i]) / 2 + center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2 + center_y1 = (y1[i] + y2[i]) / 2 + center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2 + inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2 + out_max_x = np.maximum(x2[i], x2[order[1:]]) + out_max_y = np.maximum(y2[i], y2[order[1:]]) + out_min_x = np.minimum(x1[i], x1[order[1:]]) + out_min_y = np.minimum(y1[i], y1[order[1:]]) + outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2 + diou = ovr - inter_diag / outer_diag + diou = np.core.umath.clip(diou, -1, 1) + inds = np.where(diou <= thresh)[0] + order = order[inds + 1] + return keep + + def write_result(self, cur_epoch=0, cur_step=0): + """Save result to file.""" + self.logger.info("Save bbox prediction result.") + if self.eval_parallel: + rank_id = self.rank_id + self.dir_path = os.path.join(self.save_prefix, f"eval_epoch{cur_epoch}-step{cur_step}") + if not os.path.exists(self.dir_path): + os.makedirs(self.dir_path, exist_ok=True) + file_name = f"epoch{cur_epoch}-step{cur_step}-rank{rank_id}.json" + self.file_path = os.path.join(self.dir_path, file_name) + else: + t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S') + self.file_path = self.save_prefix + '/predict' + t + '.json' + try: + with open(self.file_path, 'w') as f: + json.dump(self.det_boxes, f) + except IOError as e: + raise RuntimeError("Unable to open json file to dump. 
What(): {}".format(str(e))) + else: + self.logger.info(f'Result file path: {self.file_path}') + self.det_boxes.clear() + + def get_eval_result(self): + """Get eval result.""" + if self.eval_parallel: + file_paths = [os.path.join(self.dir_path, path) for path in os.listdir(self.dir_path)] + eval_results = self.coco_evaluator.get_mAP(file_paths) + else: + eval_results = self.coco_evaluator.get_mAP(self.file_path) + mAP = self.coco_evaluator.last_mAP + return eval_results, mAP + + def detect(self, outputs, batch, image_shape, image_id): + """Detect boxes.""" + # output [|32, 52, 52, 3, 85| ] + for batch_id in range(batch): + for out_item in outputs: + # 52, 52, 3, 85 + out_item_single = out_item[batch_id, :] + ori_w, ori_h = image_shape[batch_id] + img_id = int(image_id[batch_id]) + if img_id not in self.results: + self.results[img_id] = defaultdict(list) + x = ori_w * out_item_single[..., 0].reshape(-1) + y = ori_h * out_item_single[..., 1].reshape(-1) + w = ori_w * out_item_single[..., 2].reshape(-1) + h = ori_h * out_item_single[..., 3].reshape(-1) + conf = out_item_single[..., 4:5] + cls_emb = out_item_single[..., 5:] + x_top_left = x - w / 2. + y_top_left = y - h / 2. + cls_emb = cls_emb.reshape(-1, self.num_classes) + if self.multi_label: + conf = conf.reshape(-1, 1) + confidence = conf * cls_emb + # create all False + flag = (cls_emb > self.multi_label_thresh) & (confidence >= self.ignore_threshold) + i, j = flag.nonzero() + x_left, y_left = np.maximum(0, x_top_left[i]), np.maximum(0, y_top_left[i]) + w, h = np.minimum(ori_w, w[i]), np.minimum(ori_h, h[i]) + cls_id = np.array(self.coco_catIds)[j] + conf = confidence[i, j] + for (x_i, y_i, w_i, h_i, conf_i, cls_id_i) in zip(x_left, y_left, w, h, conf, cls_id): + self.results[img_id][cls_id_i].append([x_i, y_i, w_i, h_i, conf_i]) + else: + cls_argmax = np.argmax(cls_emb, axis=-1) + # create all False + flag = np.random.random(cls_emb.shape) > sys.maxsize + for i in range(flag.shape[0]): + c = cls_argmax[i] + flag[i, c] = True + confidence = conf.reshape(-1) * cls_emb[flag] + for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, + w, h, confidence, cls_argmax): + if confi < self.ignore_threshold: + continue + x_lefti, y_lefti = max(0, x_lefti), max(0, y_lefti) + wi, hi = min(wi, ori_w), min(hi, ori_h) + # transform catId to match coco + coco_clsi = self.coco_catids[clsi] + self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi]) + + +class AllReduce(nn.Cell): + """AllReduce.""" + def __init__(self): + super(AllReduce, self).__init__() + self.all_reduce = ops.AllReduce() + + def construct(self, x): + return self.all_reduce(x) + + +class EvalWrapper: + """EvalWrapper.""" + def __init__(self, config, network, dataset, engine: DetectionEngine) -> None: + self.logger = config.logger + self.network = network + self.dataset = dataset + self.per_batch_size = config.per_batch_size + self.device_num = config.group_size + self.input_shape = Tensor(tuple(config.test_img_shape), mindspore.float32) + self.engine = engine + self.eval_parallel = config.eval_parallel + if config.eval_parallel: + self.reduce = AllReduce() + + def synchronize(self): + sync = Tensor(np.array([1]).astype(np.int32)) + sync = self.reduce(sync) # For synchronization + sync = sync.asnumpy()[0] + if sync != self.device_num: + raise ValueError( + f"Sync value {sync} is not equal to number of device {self.device_num}. " + f"There might be wrong with devices." 
+ ) + + def inference(self): + """inference""" + for index, data in enumerate(self.dataset.create_dict_iterator(output_numpy=True, num_epochs=1)): + image = data["image"] + image = mindspore.Tensor(image) + image_shape_ = data["image_shape"] + image_id_ = data["img_id"] + output_big, output_me, output_small = self.network(image, self.input_shape) + output_big = output_big.asnumpy() + output_me = output_me.asnumpy() + output_small = output_small.asnumpy() + self.engine.detect([output_small, output_me, output_big], self.per_batch_size, image_shape_, image_id_) + + if index % 50 == 0: + self.logger.info('Processing... {:.2f}% '.format(index / self.dataset.get_dataset_size() * 100)) + + def get_results(self, cur_epoch=0, cur_step=0): + """get_results""" + self.logger.info('Calculating mAP...') + self.engine.do_nms_for_results() + self.engine.write_result(cur_epoch=cur_epoch, cur_step=cur_step) + if self.eval_parallel: + self.synchronize() # Synchronize to avoid read incomplete results + return self.engine.get_eval_result() diff --git a/community/cv/ADCAM/src/yolo.py b/community/cv/ADCAM/src/yolo.py new file mode 100644 index 0000000000000000000000000000000000000000..141f3ea175b3d4fbb4b02ca628c293ce8177cb8e --- /dev/null +++ b/community/cv/ADCAM/src/yolo.py @@ -0,0 +1,688 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""yolo""" +import math +import numpy as np +import mindspore +import mindspore.nn as nn +import mindspore.ops as ops +from src.backbone import ( + YOLOv5Backbone, + Conv, + BottleneckCSP, + CAAttention, +) +from src.loss import ConfidenceLoss, ClassLoss + +from model_utils.config import config as default_config + + +class YOLO(nn.Cell): + """YOLO""" + def __init__(self, backbone, shape): + super(YOLO, self).__init__() + self.backbone = backbone + self.config = default_config + self.config.out_channel = (self.config.num_classes + 5) * 3 + + self.conv1 = Conv(shape[5], shape[4], k=1, s=1) + self.CSP5 = BottleneckCSP(shape[5], shape[4], n=1 * shape[6], shortcut=False) + self.conv2 = Conv(shape[4], shape[3], k=1, s=1) + self.CSP6 = BottleneckCSP(shape[4], shape[3], n=1 * shape[6], shortcut=False) + self.conv3 = Conv(shape[3], shape[3], k=3, s=2) + self.CSP7 = BottleneckCSP(shape[4], shape[4], n=1 * shape[6], shortcut=False) + self.conv4 = Conv(shape[4], shape[4], k=3, s=2) + print("************----------********************") + self.CSP8 = BottleneckCSP(shape[5], shape[5], n=1 * shape[6], shortcut=False) + print("************----------********************") + + self.back_block1 = YoloBlock(shape[3], self.config.out_channel) + self.back_block2 = YoloBlock(shape[4], self.config.out_channel) + self.back_block3 = YoloBlock(shape[5], self.config.out_channel) + + self.pre_back_block1 = CAAttention( + self.config.out_channel, self.config.out_channel + ) + self.pre_back_block2 = CAAttention( + self.config.out_channel, self.config.out_channel + ) + self.pre_back_block3 = CAAttention( + self.config.out_channel, self.config.out_channel + ) + + self.concat = ops.Concat(axis=1) + + def construct(self, x): + """construct method""" + + img_height = x.shape[2] * 2 + img_width = x.shape[3] * 2 + + feature_map1, feature_map2, feature_map3 = self.backbone(x) + + c1 = self.conv1(feature_map3) + ups1 = ops.ResizeNearestNeighbor((img_height // 16, img_width // 16))(c1) + c2 = self.concat((ups1, feature_map2)) + c3 = self.CSP5(c2) + c4 = self.conv2(c3) + ups2 = ops.ResizeNearestNeighbor((img_height // 8, img_width // 8))(c4) + c5 = self.concat((ups2, feature_map1)) + # out + c6 = self.CSP6(c5) + c7 = self.conv3(c6) + + c8 = self.concat((c7, c4)) + # out + c9 = self.CSP7(c8) + c10 = self.conv4(c9) + c11 = self.concat((c10, c1)) + # out + c12 = self.CSP8(c11) + + c6 = self.back_block1(c6) + c9 = self.back_block2(c9) + c12 = self.back_block3(c12) + + small_object_output = self.pre_back_block1(c6) + medium_object_output = self.pre_back_block2(c9) + big_object_output = self.pre_back_block3(c12) + + # print("c6",c6.shape,"c9",c9.shape,"c12",c12.shape) + return small_object_output, medium_object_output, big_object_output + + +class YoloBlock(nn.Cell): + """YoloBlock""" + + def __init__(self, in_channels, out_channels): + super(YoloBlock, self).__init__() + + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, has_bias=True + ) + + def construct(self, x): + """construct method""" + + out = self.conv(x) + return out + + +class DetectionBlock(nn.Cell): + """DetectionBlock""" + + def __init__(self, scale, config=default_config, is_training=True): + super(DetectionBlock, self).__init__() + self.config = config + if scale == "s": + idx = (0, 1, 2) + self.scale_x_y = 1.2 + self.offset_x_y = 0.1 + elif scale == "m": + idx = (3, 4, 5) + self.scale_x_y = 1.1 + self.offset_x_y = 0.05 + elif scale == "l": + idx = (6, 7, 8) + 
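+            # Descriptive note (added): scale_x_y stretches the sigmoid-decoded box
+            # center symmetrically around [0, 1] (offset_x_y = (scale_x_y - 1) / 2),
+            # so predicted centers can reach the borders of a grid cell; see the
+            # box_xy decoding in construct() below.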
self.scale_x_y = 1.05 + self.offset_x_y = 0.025 + else: + raise KeyError("Invalid scale value for DetectionBlock") + self.anchors = mindspore.Tensor( + [self.config.anchor_scales[i] for i in idx], mindspore.float32 + ) + self.num_anchors_per_scale = 3 + self.num_attrib = 4 + 1 + self.config.num_classes + self.lambda_coord = 1 + + self.sigmoid = nn.Sigmoid() + self.reshape = ops.Reshape() + self.tile = ops.Tile() + self.concat = ops.Concat(axis=-1) + self.pow = ops.Pow() + self.transpose = ops.Transpose() + self.exp = ops.Exp() + self.conf_training = is_training + + def construct(self, x, input_shape): + """construct method""" + num_batch = x.shape[0] + grid_size = x.shape[2:4] + + # Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib] + prediction = self.reshape( + x, + ( + num_batch, + self.num_anchors_per_scale, + self.num_attrib, + grid_size[0], + grid_size[1], + ), + ) + prediction = self.transpose(prediction, (0, 3, 4, 1, 2)) + + grid_x = mindspore.numpy.arange(grid_size[1]) + grid_y = mindspore.numpy.arange(grid_size[0]) + # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid + # [batch, gridx, gridy, 1, 1] + grid_x = self.tile( + self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1) + ) + grid_y = self.tile( + self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1) + ) + # Shape is [grid_size[0], grid_size[1], 1, 2] + grid = self.concat((grid_x, grid_y)) + + box_xy = prediction[:, :, :, :, :2] + box_wh = prediction[:, :, :, :, 2:4] + box_confidence = prediction[:, :, :, :, 4:5] + box_probs = prediction[:, :, :, :, 5:] + + # gridsize1 is x + # gridsize0 is y + box_xy = ( + self.scale_x_y * self.sigmoid(box_xy) - self.offset_x_y + grid + ) / ops.cast( + ops.tuple_to_array((grid_size[1], grid_size[0])), mindspore.float32 + ) + # box_wh is w->h + box_wh = self.exp(box_wh) * self.anchors / input_shape + + box_confidence = self.sigmoid(box_confidence) + box_probs = self.sigmoid(box_probs) + + if self.conf_training: + return prediction, box_xy, box_wh + return self.concat((box_xy, box_wh, box_confidence, box_probs)) + + +class Iou(nn.Cell): + """Calculate the iou of boxes""" + def __init__(self): + super(Iou, self).__init__() + self.min = ops.Minimum() + self.max = ops.Maximum() + self.squeeze = ops.Squeeze(-1) + + def construct(self, box1, box2): + """ + box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h] + box2: gt_box [batch, 1, 1, 1, maxbox, 4] + convert to topLeft and rightDown + """ + box1_xy = box1[:, :, :, :, :, :2] + box1_wh = box1[:, :, :, :, :, 2:4] + box1_mins = box1_xy - box1_wh / ops.scalar_to_tensor(2.0) # topLeft + box1_maxs = box1_xy + box1_wh / ops.scalar_to_tensor(2.0) # rightDown + + box2_xy = box2[:, :, :, :, :, :2] + box2_wh = box2[:, :, :, :, :, 2:4] + box2_mins = box2_xy - box2_wh / ops.scalar_to_tensor(2.0) + box2_maxs = box2_xy + box2_wh / ops.scalar_to_tensor(2.0) + + intersect_mins = self.max(box1_mins, box2_mins) + intersect_maxs = self.min(box1_maxs, box2_maxs) + intersect_wh = self.max( + intersect_maxs - intersect_mins, ops.scalar_to_tensor(0.0) + ) + # self.squeeze: for effiecient slice + intersect_area = self.squeeze(intersect_wh[:, :, :, :, :, 0:1]) * self.squeeze( + intersect_wh[:, :, :, :, :, 1:2] + ) + box1_area = self.squeeze(box1_wh[:, :, :, :, :, 0:1]) * self.squeeze( + box1_wh[:, :, :, :, :, 1:2] + ) + box2_area = self.squeeze(box2_wh[:, :, :, :, :, 0:1]) * self.squeeze( + box2_wh[:, :, :, :, :, 1:2] + ) + iou = 
intersect_area / (box1_area + box2_area - intersect_area) + # iou : [batch, gx, gy, anchors, maxboxes] + return iou + + +class YoloLossBlock(nn.Cell): + """ + Loss block cell of YOLOV5 network. + """ + + def __init__(self, scale, config=default_config): + super(YoloLossBlock, self).__init__() + self.config = config + if scale == "s": + # anchor mask + idx = (0, 1, 2) + elif scale == "m": + idx = (3, 4, 5) + elif scale == "l": + idx = (6, 7, 8) + else: + raise KeyError("Invalid scale value for DetectionBlock") + self.anchors = mindspore.Tensor( + [self.config.anchor_scales[i] for i in idx], mindspore.float32 + ) + self.ignore_threshold = mindspore.Tensor( + self.config.ignore_threshold, mindspore.float32 + ) + self.concat = ops.Concat(axis=-1) + self.iou = Iou() + self.reduce_max = ops.ReduceMax(keep_dims=False) + self.confidence_loss = ConfidenceLoss() + self.class_loss = ClassLoss() + + self.reduce_sum = ops.ReduceSum() + self.select = ops.Select() + self.equal = ops.Equal() + self.reshape = ops.Reshape() + self.expand_dims = ops.ExpandDims() + self.ones_like = ops.OnesLike() + self.log = ops.Log() + self.tuple_to_array = ops.TupleToArray() + # self.g_iou = GIou() + self.g_iou = WIoU() + + def construct(self, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape): + """ + prediction : origin output from yolo + pred_xy: (sigmoid(xy)+grid)/grid_size + pred_wh: (exp(wh)*anchors)/input_shape + y_true : after normalize + gt_box: [batch, maxboxes, xyhw] after normalize + """ + object_mask = y_true[:, :, :, :, 4:5] + class_probs = y_true[:, :, :, :, 5:] + true_boxes = y_true[:, :, :, :, :4] + + grid_shape = prediction.shape[1:3] + grid_shape = ops.cast(self.tuple_to_array(grid_shape[::-1]), mindspore.float32) + + pred_boxes = self.concat((pred_xy, pred_wh)) + true_wh = y_true[:, :, :, :, 2:4] + true_wh = self.select( + self.equal(true_wh, 0.0), self.ones_like(true_wh), true_wh + ) + true_wh = self.log(true_wh / self.anchors * input_shape) + # 2-w*h for large picture, use small scale, since small obj need more precise + box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4] + + gt_shape = gt_box.shape + gt_box = self.reshape(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2])) + + # add one more dimension for broadcast + iou = self.iou(self.expand_dims(pred_boxes, -2), gt_box) + # gt_box is x,y,h,w after normalize + # [batch, grid[0], grid[1], num_anchor, num_gt] + best_iou = self.reduce_max(iou, -1) + # [batch, grid[0], grid[1], num_anchor] + + # ignore_mask IOU too small + ignore_mask = best_iou < self.ignore_threshold + ignore_mask = ops.cast(ignore_mask, mindspore.float32) + ignore_mask = self.expand_dims(ignore_mask, -1) + # ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume. 
+ # so we turn off its gradient + ignore_mask = ops.stop_gradient(ignore_mask) + + confidence_loss = self.confidence_loss( + object_mask, prediction[:, :, :, :, 4:5], ignore_mask + ) + class_loss = self.class_loss( + object_mask, prediction[:, :, :, :, 5:], class_probs + ) + + object_mask_me = self.reshape(object_mask, (-1, 1)) # [8, 72, 72, 3, 1] + box_loss_scale_me = self.reshape(box_loss_scale, (-1, 1)) + pred_boxes_me = xywh2x1y1x2y2(pred_boxes) + pred_boxes_me = self.reshape(pred_boxes_me, (-1, 4)) + true_boxes_me = xywh2x1y1x2y2(true_boxes) + true_boxes_me = self.reshape(true_boxes_me, (-1, 4)) + c_iou = self.g_iou(pred_boxes_me, true_boxes_me) + c_iou_loss = object_mask_me * box_loss_scale_me * (1 - c_iou) + c_iou_loss_me = self.reduce_sum(c_iou_loss, ()) + loss = c_iou_loss_me * 4 + confidence_loss + class_loss + batch_size = prediction.shape[0] + return loss / batch_size + + +class YOLOV5(nn.Cell): + """ + YOLOV5 network. + + Args: + is_training: Bool. Whether train or not. + + Returns: + Cell, cell instance of YOLOV5 neural network. + + Examples: + YOLOV5s(True) + """ + + def __init__(self, is_training, version=0): + super(YOLOV5, self).__init__() + self.config = default_config + + # YOLOv5 network + self.shape = self.config.input_shape[version] + self.feature_map = YOLO( + backbone=YOLOv5Backbone(shape=self.shape), shape=self.shape + ) + + # prediction on the default anchor boxes + self.detect_1 = DetectionBlock("l", is_training=is_training) + self.detect_2 = DetectionBlock("m", is_training=is_training) + self.detect_3 = DetectionBlock("s", is_training=is_training) + self.mean = mindspore.Tensor( + np.array([0.485 * 255, 0.456 * 255, 0.406 * 255], dtype=np.float32) + ).reshape((1, 1, 1, 3)) + self.std = mindspore.Tensor( + np.array([0.229 * 255, 0.224 * 255, 0.225 * 255], dtype=np.float32) + ).reshape((1, 1, 1, 3)) + + def construct(self, x, input_shape): + """construct.""" + x = (x - self.mean) / self.std + x = ops.transpose(x, (0, 3, 1, 2)) + x = ops.concat( + ( + x[:, :, ::2, ::2], + x[:, :, 1::2, ::2], + x[:, :, ::2, 1::2], + x[:, :, 1::2, 1::2], + ), + 1, + ) + small_object_output, medium_object_output, big_object_output = self.feature_map( + x + ) + output_big = self.detect_1(big_object_output, input_shape) + output_me = self.detect_2(medium_object_output, input_shape) + output_small = self.detect_3(small_object_output, input_shape) + # big is the final output which has smallest feature map + return output_big, output_me, output_small + + +class YOLOV5s_Infer(nn.Cell): + """ + YOLOV5 Infer. 
+ """ + + def __init__(self, input_shape, version=0): + super(YOLOV5s_Infer, self).__init__() + self.network = YOLOV5(is_training=False, version=version) + self.input_shape = input_shape + + def construct(self, x): + return self.network(x, self.input_shape) + + +class YoloWithLossCell(nn.Cell): + """YOLOV5 loss.""" + + def __init__(self, network): + super(YoloWithLossCell, self).__init__() + self.yolo_network = network + self.config = default_config + self.loss_big = YoloLossBlock("l", self.config) + self.loss_me = YoloLossBlock("m", self.config) + self.loss_small = YoloLossBlock("s", self.config) + self.tenser_to_array = ops.TupleToArray() + + def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape): + yolo_out = self.yolo_network(x, input_shape) + loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) + loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) + loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) + return loss_l + loss_m + loss_s * 0.2 + + +class GIou(nn.Cell): + """Calculating giou""" + + def __init__(self): + super(GIou, self).__init__() + self.reshape = ops.Reshape() + self.min = ops.Minimum() + self.max = ops.Maximum() + self.concat = ops.Concat(axis=1) + self.mean = ops.ReduceMean() + self.div = ops.RealDiv() + self.eps = 0.000001 + + def construct(self, box_p, box_gt): + """construct method""" + print("**************************************GIOU***************************************************") + box_p_area = (box_p[..., 2:3] - box_p[..., 0:1]) * (box_p[..., 3:4] - box_p[..., 1:2]) + box_gt_area = (box_gt[..., 2:3] - box_gt[..., 0:1]) * (box_gt[..., 3:4] - box_gt[..., 1:2]) + x_1 = self.max(box_p[..., 0:1], box_gt[..., 0:1]) + x_2 = self.min(box_p[..., 2:3], box_gt[..., 2:3]) + y_1 = self.max(box_p[..., 1:2], box_gt[..., 1:2]) + y_2 = self.min(box_p[..., 3:4], box_gt[..., 3:4]) + intersection = (y_2 - y_1) * (x_2 - x_1) + xc_1 = self.min(box_p[..., 0:1], box_gt[..., 0:1]) + xc_2 = self.max(box_p[..., 2:3], box_gt[..., 2:3]) + yc_1 = self.min(box_p[..., 1:2], box_gt[..., 1:2]) + yc_2 = self.max(box_p[..., 3:4], box_gt[..., 3:4]) + c_area = (xc_2 - xc_1) * (yc_2 - yc_1) + union = box_p_area + box_gt_area - intersection + union = union + self.eps + c_area = c_area + self.eps + iou = self.div( + ops.cast(intersection, mindspore.float32), + ops.cast(union, mindspore.float32), + ) + res_mid0 = c_area - union + res_mid1 = self.div( + ops.cast(res_mid0, mindspore.float32), ops.cast(c_area, mindspore.float32) + ) + giou = iou - res_mid1 + giou = ops.clip_by_value(giou, -1.0, 1.0) + return giou + + +class CIou(nn.Cell): + """Calculating CIoU loss.""" + + def __init__(self): + super(CIou, self).__init__() + self.min = ops.Minimum() + self.max = ops.Maximum() + self.clip = ops.clip_by_value + self.atan = ops.Atan() + self.stop_gradient = ops.stop_gradient + self.eps = 1e-6 + + def construct(self, box_p, box_gt): + """Construct method to compute CIoU.""" + # 计算预测框和真实框的面积 + box_p_area = (box_p[..., 2] - box_p[..., 0]) * (box_p[..., 3] - box_p[..., 1]) + box_gt_area = (box_gt[..., 2] - box_gt[..., 0]) * ( + box_gt[..., 3] - box_gt[..., 1] + ) + + # 计算交集的坐标 + x1 = self.max(box_p[..., 0], box_gt[..., 0]) + y1 = self.max(box_p[..., 1], box_gt[..., 1]) + x2 = self.min(box_p[..., 2], box_gt[..., 2]) + y2 = self.min(box_p[..., 3], box_gt[..., 3]) + + # 计算交集的宽度和高度,并裁剪为非负值 + inter_w = self.clip(x2 - x1, 0.0, None) + inter_h = self.clip(y2 - y1, 0.0, None) + intersection = inter_w * inter_h + + # 计算并集的面积 + union = box_p_area + 
box_gt_area - intersection + self.eps + + # 计算IoU + iou = intersection / union + + box_p_center_x = (box_p[..., 0] + box_p[..., 2]) / 2 + box_p_center_y = (box_p[..., 1] + box_p[..., 3]) / 2 + box_gt_center_x = (box_gt[..., 0] + box_gt[..., 2]) / 2 + box_gt_center_y = (box_gt[..., 1] + box_gt[..., 3]) / 2 + + center_dist = (box_p_center_x - box_gt_center_x) ** 2 + ( + box_p_center_y - box_gt_center_y + ) ** 2 + + enclose_x1 = self.min(box_p[..., 0], box_gt[..., 0]) + enclose_y1 = self.min(box_p[..., 1], box_gt[..., 1]) + enclose_x2 = self.max(box_p[..., 2], box_gt[..., 2]) + enclose_y2 = self.max(box_p[..., 3], box_gt[..., 3]) + enclose_diag = ( + (enclose_x2 - enclose_x1) ** 2 + (enclose_y2 - enclose_y1) ** 2 + self.eps + ) + + distance_term = center_dist / enclose_diag + + box_p_w = box_p[..., 2] - box_p[..., 0] + self.eps + box_p_h = box_p[..., 3] - box_p[..., 1] + self.eps + box_gt_w = box_gt[..., 2] - box_gt[..., 0] + self.eps + box_gt_h = box_gt[..., 3] - box_gt[..., 1] + self.eps + + v = (4 / (math.pi**2)) * ( + self.atan(box_gt_w / box_gt_h) - self.atan(box_p_w / box_p_h) + ) ** 2 + + with ops.stop_gradient(): + S = 1 - iou + v + self.eps + alpha = v / S + + ciou = iou - (distance_term + alpha * v) + + ciou = self.clip(ciou, -1.0, 1.0) + + return ciou + + +def xywh2x1y1x2y2(box_xywh): + boxes_x1 = box_xywh[..., 0:1] - box_xywh[..., 2:3] / 2 + boxes_y1 = box_xywh[..., 1:2] - box_xywh[..., 3:4] / 2 + boxes_x2 = box_xywh[..., 0:1] + box_xywh[..., 2:3] / 2 + boxes_y2 = box_xywh[..., 1:2] + box_xywh[..., 3:4] / 2 + boxes_x1y1x2y2 = ops.Concat(-1)((boxes_x1, boxes_y1, boxes_x2, boxes_y2)) + + return boxes_x1y1x2y2 + + +class WIoU(nn.Cell): + """Calculating WIoU""" + + def __init__(self): + super(WIoU, self).__init__() + self.reshape = ops.Reshape() + self.min = ops.Minimum() + self.max = ops.Maximum() + self.concat = ops.Concat(axis=1) + self.mean = ops.ReduceMean() + self.div = ops.RealDiv() + self.eps = 0.000001 + + def construct(self, box_p, box_gt): + # print("*******************************************WIoU**********************************************************") + """construct method""" + box_p_area = (box_p[..., 2:3] - box_p[..., 0:1]) * ( + box_p[..., 3:4] - box_p[..., 1:2] + ) + box_gt_area = (box_gt[..., 2:3] - box_gt[..., 0:1]) * ( + box_gt[..., 3:4] - box_gt[..., 1:2] + ) + + x_1 = self.max(box_p[..., 0:1], box_gt[..., 0:1]) + y_1 = self.max(box_p[..., 1:2], box_gt[..., 1:2]) + x_2 = self.min(box_p[..., 2:3], box_gt[..., 2:3]) + y_2 = self.min(box_p[..., 3:4], box_gt[..., 3:4]) + + intersection = (x_2 - x_1).clip(0, None) * (y_2 - y_1).clip(0, None) + + union = box_p_area + box_gt_area - intersection + self.eps + + iou = self.div( + ops.cast(intersection, mindspore.float32), + ops.cast(union, mindspore.float32), + ) + + x_p_center = (box_p[..., 0:1] + box_p[..., 2:3]) / 2 + y_p_center = (box_p[..., 1:2] + box_p[..., 3:4]) / 2 + x_gt_center = (box_gt[..., 0:1] + box_gt[..., 2:3]) / 2 + y_gt_center = (box_gt[..., 1:2] + box_gt[..., 3:4]) / 2 + + rho2 = (x_p_center - x_gt_center) ** 2 + (y_p_center - y_gt_center) ** 2 + + xc_1 = self.min(box_p[..., 0:1], box_gt[..., 0:1]) + yc_1 = self.min(box_p[..., 1:2], box_gt[..., 1:2]) + xc_2 = self.max(box_p[..., 2:3], box_gt[..., 2:3]) + yc_2 = self.max(box_p[..., 3:4], box_gt[..., 3:4]) + c2 = (xc_2 - xc_1) ** 2 + (yc_2 - yc_1) ** 2 + self.eps + + wiou = iou - self.div( + ops.cast(rho2, mindspore.float32), ops.cast(c2, mindspore.float32) + ) + wiou = ops.clip_by_value(wiou, -1.0, 1.0) + return wiou + + +def ciou(boxes1, boxes2): + """ 
+ cal CIOU of two boxes or batch boxes + :param boxes1:[xmin,ymin,xmax,ymax] or + [[xmin,ymin,xmax,ymax],[xmin,ymin,xmax,ymax],...] + :param boxes2:[xmin,ymin,xmax,ymax] + :return: + """ + + # cal the box's area of boxes1 and boxess + boxes1Area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1]) + boxes2Area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1]) + + # cal Intersection + left_up = np.maximum(boxes1[..., :2], boxes2[..., :2]) + right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:]) + + inter_section = np.maximum(right_down - left_up, 0.0) + inter_area = inter_section[..., 0] * inter_section[..., 1] + union_area = boxes1Area + boxes2Area - inter_area + ious = np.maximum(1.0 * inter_area / union_area, np.finfo(np.float32).eps) + + # cal outer boxes + outer_left_up = np.minimum(boxes1[..., :2], boxes2[..., :2]) + outer_right_down = np.maximum(boxes1[..., 2:], boxes2[..., 2:]) + outer = np.maximum(outer_right_down - outer_left_up, 0.0) + outer_diagonal_line = np.square(outer[..., 0]) + np.square(outer[..., 1]) + + # cal center distance + boxes1_center = (boxes1[..., :2] + boxes1[..., 2:]) * 0.5 + boxes2_center = (boxes2[..., :2] + boxes2[..., 2:]) * 0.5 + center_dis = np.square(boxes1_center[..., 0] - boxes2_center[..., 0]) + np.square( + boxes1_center[..., 1] - boxes2_center[..., 1] + ) + + # cal penalty term + # cal width,height + boxes1_size = np.maximum(boxes1[..., 2:] - boxes1[..., :2], 0.0) + boxes2_size = np.maximum(boxes2[..., 2:] - boxes2[..., :2], 0.0) + v = (4.0 / np.square(np.pi)) * np.square( + ( + np.arctan((boxes1_size[..., 0] / boxes1_size[..., 1])) + - np.arctan((boxes2_size[..., 0] / boxes2_size[..., 1])) + ) + ) + alpha = v / (1 - ious + v) + # cal ciou + cious = ious - (center_dis / outer_diagonal_line + alpha * v) + + return cious diff --git a/community/cv/ADCAM/src/yolo_dataset.py b/community/cv/ADCAM/src/yolo_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1dbb4ec63c9d03e733bd8db872c04d34639c3d --- /dev/null +++ b/community/cv/ADCAM/src/yolo_dataset.py @@ -0,0 +1,290 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""YOLOV5 Dataset """ +import os +import multiprocessing +import random +import numpy as np +import cv2 +from PIL import Image +from pycocotools.coco import COCO +import mindspore.dataset as ds +from src.distributed_sampler import DistributedSampler +from src.transforms import reshape_fn, MultiScaleTrans, PreprocessTrueBox + + +min_keypoints_per_image = 10 +GENERATOR_PARALLEL_WORKER = 8 + +def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + +def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + +def has_valid_annotation(anno): + """Check annotation file.""" + # if it's empty, there is no annotation + if not anno: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different criteria for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= min_keypoints_per_image: + return True + return False + + +class COCOYoloDataset: + """YOLOV5 Dataset for COCO.""" + def __init__(self, root, ann_file, remove_images_without_annotations=True, + filter_crowd_anno=True, is_training=True): + self.coco = COCO(ann_file) + self.root = root + self.img_ids = list(sorted(self.coco.imgs.keys())) + self.filter_crowd_anno = filter_crowd_anno + self.is_training = is_training + self.mosaic = True + # filter images without any annotations + if remove_images_without_annotations: + img_ids = [] + for img_id in self.img_ids: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + img_ids.append(img_id) + self.img_ids = img_ids + + self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()} + + self.cat_ids_to_continuous_ids = { + v: i for i, v in enumerate(self.coco.getCatIds()) + } + self.continuous_ids_cat_ids = { + v: k for k, v in self.cat_ids_to_continuous_ids.items() + } + self.count = 0 + + def _mosaic_preprocess(self, index, input_size): + """Mosaic preprocess.""" + labels4 = [] + s = 384 + self.mosaic_border = [-s // 2, -s // 2] + yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] + indices = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)] + for i, img_ids_index in enumerate(indices): + coco = self.coco + img_id = self.img_ids[img_ids_index] + # print("================") + img_path = coco.loadImgs(img_id)[0]["file_name"] + # print(coco.loadImgs(img_id)[0]) #{'id': 20180004325, 'file_name': 'IP025000201', 'width': 539, 'height': 356} + + # print(os.path.join(self.root, img_path)+"===========") + if not img_path.lower().endswith('.jpg'): + img_path += '.jpg' + img = Image.open(os.path.join(self.root, img_path)).convert("RGB") + img = np.array(img) + h, w = img.shape[:2] + + if i == 0: # top left + img4 = np.full((s * 2, s * 2, img.shape[2]), 128, dtype=np.uint8) # base image with 4 tiles + x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image) + x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image) + elif i == 1: # top right + x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc + x1b, y1b, 
x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h + elif i == 2: # bottom left + x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h) + x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h) + elif i == 3: # bottom right + x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h) + x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h) + + img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] + + padw = x1a - x1b + padh = y1a - y1b + + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + # filter crowd annotations + if self.filter_crowd_anno: + annos = [anno for anno in target if anno["iscrowd"] == 0] + else: + annos = [anno for anno in target] + + target = {} + boxes = [anno["bbox"] for anno in annos] + target["bboxes"] = boxes + + classes = [anno["category_id"] for anno in annos] + classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes] + target["labels"] = classes + + bboxes = target['bboxes'] + labels = target['labels'] + out_target = [] + + for bbox, label in zip(bboxes, labels): + tmp = [] + # convert to [x_min y_min x_max y_max] + bbox = self._convetTopDown(bbox) + tmp.extend(bbox) + tmp.append(int(label)) + # tmp [x_min y_min x_max y_max, label] + out_target.append(tmp) + + labels = out_target.copy() + labels = np.array(labels) + out_target = np.array(out_target) + + labels[:, 0] = out_target[:, 0] + padw + labels[:, 1] = out_target[:, 1] + padh + labels[:, 2] = out_target[:, 2] + padw + labels[:, 3] = out_target[:, 3] + padh + labels4.append(labels) + + if labels4: + labels4 = np.concatenate(labels4, 0) + np.clip(labels4[:, :4], 0, 2 * s, out=labels4[:, :4]) # use with random_perspective + flag = np.array([1]) + return img4, labels4, input_size, flag + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + (img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints", + generated by the image's annotation. img is a PIL image. 
+ """ + coco = self.coco + img_id = self.img_ids[index] + img_path = coco.loadImgs(img_id)[0]["file_name"] + + if not img_path.lower().endswith(('.jpg', '.png')): + img_path += '.jpg' + + if not self.is_training: + img = Image.open(os.path.join(self.root, img_path)).convert("RGB") + return img, img_id + + input_size = [640, 640] + if self.mosaic and random.random() < 0.5: + return self._mosaic_preprocess(index, input_size) + if not img_path.lower().endswith(('.jpg', '.png')): + img_path += '.jpg' + img = np.fromfile(os.path.join(self.root, img_path), dtype='int8') + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + # filter crowd annotations + if self.filter_crowd_anno: + annos = [anno for anno in target if anno["iscrowd"] == 0] + else: + annos = [anno for anno in target] + + target = {} + boxes = [anno["bbox"] for anno in annos] + target["bboxes"] = boxes + + classes = [anno["category_id"] for anno in annos] + classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes] + target["labels"] = classes + + bboxes = target['bboxes'] + labels = target['labels'] + out_target = [] + for bbox, label in zip(bboxes, labels): + tmp = [] + # convert to [x_min y_min x_max y_max] + bbox = self._convetTopDown(bbox) + tmp.extend(bbox) + tmp.append(int(label)) + # tmp [x_min y_min x_max y_max, label] + out_target.append(tmp) + flag = np.array([0]) + return img, out_target, input_size, flag + + def __len__(self): + return len(self.img_ids) + + def _convetTopDown(self, bbox): + x_min = bbox[0] + y_min = bbox[1] + w = bbox[2] + h = bbox[3] + return [x_min, y_min, x_min+w, y_min+h] + + +def create_yolo_dataset(image_dir, anno_path, batch_size, device_num, rank, + config=None, is_training=True, shuffle=True): + """Create dataset for YOLOV5.""" + cv2.setNumThreads(0) + ds.config.set_enable_shared_mem(True) + if is_training: + filter_crowd = True + remove_empty_anno = True + else: + filter_crowd = False + remove_empty_anno = False + + yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd, + remove_images_without_annotations=remove_empty_anno, is_training=is_training) + distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle) + yolo_dataset.size = len(distributed_sampler) + + config.dataset_size = len(yolo_dataset) + cores = multiprocessing.cpu_count() + num_parallel_workers = int(cores / device_num) + if is_training: + multi_scale_trans = MultiScaleTrans(config, device_num) + yolo_dataset.transforms = multi_scale_trans + + dataset_column_names = ["image", "annotation", "input_size", "mosaic_flag"] + output_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3", + "gt_box1", "gt_box2", "gt_box3"] + map1_out_column_names = ["image", "annotation", "size"] + map2_in_column_names = ["annotation", "size"] + map2_out_column_names = ["annotation", "bbox1", "bbox2", "bbox3", + "gt_box1", "gt_box2", "gt_box3"] + + dataset = ds.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler, + python_multiprocessing=True, num_parallel_workers=min(4, num_parallel_workers)) + dataset = dataset.map(operations=multi_scale_trans, input_columns=dataset_column_names, + output_columns=map1_out_column_names, + num_parallel_workers=min(12, num_parallel_workers), python_multiprocessing=True) + dataset = dataset.map(operations=PreprocessTrueBox(config), input_columns=map2_in_column_names, + output_columns=map2_out_column_names, + num_parallel_workers=min(4, num_parallel_workers), 
python_multiprocessing=False) + dataset = dataset.project(output_column_names) + dataset = dataset.batch(batch_size, num_parallel_workers=min(4, num_parallel_workers), drop_remainder=True) + else: + dataset = ds.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], + sampler=distributed_sampler) + compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config)) + dataset = dataset.map(operations=compose_map_func, input_columns=["image", "img_id"], + output_columns=["image", "image_shape", "img_id"], + num_parallel_workers=8) + dataset = dataset.batch(batch_size, drop_remainder=True) + return dataset diff --git a/community/cv/ADCAM/train.py b/community/cv/ADCAM/train.py new file mode 100644 index 0000000000000000000000000000000000000000..af0d4721ec216bc5b9d1faffa3c988a987c2e5e3 --- /dev/null +++ b/community/cv/ADCAM/train.py @@ -0,0 +1,253 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train.""" +import os +import time +from collections import deque +import mindspore +import mindspore.nn as nn +import mindspore.communication as comm +from mindspore import load_checkpoint, Parameter, save_checkpoint + +from model_utils.config import config +from model_utils.device_adapter import get_device_id +from model_utils.moxing_adapter import ( + moxing_wrapper, + modelarts_pre_process, + modelarts_post_process, +) + +from src.yolo import YOLOV5, YoloWithLossCell +from src.logger import get_logger +from src.util import ( + AverageMeter, + get_param_groups, + cpu_affinity, + EvalWrapper, + DetectionEngine, +) +from src.lr_scheduler import get_lr +from src.yolo_dataset import create_yolo_dataset +from src.initializer import default_recurisive_init, load_yolov5_params + + +mindspore.set_seed(1) + + +def init_distribute(): + """init_distribute.""" + comm.init() + config.rank = comm.get_rank() + config.group_size = comm.get_group_size() + mindspore.set_auto_parallel_context( + parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, + gradients_mean=True, + device_num=config.group_size, + ) + + +def train_preprocess(): + """train_preprocess.""" + if config.lr_scheduler == "cosine_annealing" and config.max_epoch > config.T_max: + config.T_max = config.max_epoch + + config.lr_epochs = list(map(int, config.lr_epochs.split(","))) + config.train_img_dir = os.path.join(config.data_dir, config.train_img_dir) + config.train_ann_file = os.path.join(config.data_dir, config.train_ann_file) + device_id = get_device_id() + if config.device_target == "Ascend": + device_id = get_device_id() + mindspore.set_context( + mode=0, device_target=config.device_target, device_id=device_id + ) + else: + mindspore.set_context(mode=0, device_target=config.device_target) + + if config.is_distributed: + init_distribute() + + if config.device_target == "GPU" and config.bind_cpu: + cpu_affinity(config.rank, min(config.group_size, config.device_num)) + + config.logger = 
get_logger(config.output_dir, config.rank) + config.logger.save_args(config) + + +def get_val_dataset(): + """get_val_dataset.""" + config.val_img_dir = os.path.join(config.data_dir, config.val_img_dir) + config.val_ann_file = os.path.join(config.data_dir, config.val_ann_file) + ds_val = create_yolo_dataset( + config.val_img_dir, + config.val_ann_file, + is_training=False, + batch_size=config.per_batch_size, + device_num=config.group_size, + rank=config.rank, + config=config, + ) + config.logger.info("Finish loading val dataset!") + return ds_val + + +def load_parameters(val_network, train_network): + """load_parameters.""" + config.logger.info("Load parameters of train network") + param_dict_new = {} + for key, values in train_network.parameters_and_names(): + if key.startswith("moments."): + continue + elif key.startswith("yolo_network."): + param_dict_new[key[13:]] = values + else: + param_dict_new[key] = values + mindspore.load_param_into_net(val_network, param_dict_new) + config.logger.info("Load train network success") + + +def load_best_results(): + best_ckpt_path = os.path.join(config.output_dir, "best.ckpt") + if os.path.exists(best_ckpt_path): + param_dict = load_checkpoint(best_ckpt_path) + best_result = param_dict["best_result"].asnumpy().item() + best_epoch = param_dict["best_epoch"].asnumpy().item() + config.logger.info("cur best result %s at epoch %s", best_result, best_epoch) + return best_result, best_epoch + return 0.0, 0 + + +def save_best_checkpoint(network, best_result, best_epoch): + param_list = [ + {"name": "best_result", "data": Parameter(best_result)}, + {"name": "best_epoch", "data": Parameter(best_epoch)}, + ] + for name, param in network.parameters_and_names(): + param_list.append({"name": name, "data": param}) + save_checkpoint(param_list, os.path.join(config.output_dir, "best.ckpt")) + + +def is_val_epoch(epoch_idx: int): + epoch = epoch_idx + 1 + return (epoch >= config.eval_start_epoch) and ( + (epoch_idx + 1) % config.eval_epoch_interval == 0 + or (epoch_idx + 1) == config.max_epoch + ) + + +@moxing_wrapper( + pre_process=modelarts_pre_process, + post_process=modelarts_post_process, + pre_args=[config], +) +def run_train(): + """run_train.""" + train_preprocess() + config.eval_parallel = (config.run_eval and config.is_distributed and config.eval_parallel) + loss_meter = AverageMeter("loss") + dict_version = {"yolov5s": 0, "yolov5m": 1, "yolov5l": 2, "yolov5x": 3} + network = YOLOV5(is_training=True, version=dict_version[config.yolov5_version]) + val_network = YOLOV5(is_training=False, version=dict_version[config.yolov5_version]) + default_recurisive_init(network) + load_yolov5_params(config, network) + network = YoloWithLossCell(network) + + ds = create_yolo_dataset( + image_dir=config.train_img_dir, + anno_path=config.train_ann_file, + is_training=True, + batch_size=config.per_batch_size, + device_num=config.group_size, + rank=config.rank, + config=config, + ) + config.logger.info("Finish loading train dataset") + ds_val = get_val_dataset() + + steps_per_epoch = ds.get_dataset_size() + lr = get_lr(config, steps_per_epoch) + opt = nn.Momentum( + params=get_param_groups(network), + momentum=config.momentum, + learning_rate=mindspore.Tensor(lr), + weight_decay=config.weight_decay, + loss_scale=config.loss_scale, + ) + network = nn.TrainOneStepCell(network, opt, config.loss_scale // 2) + network.set_train() + + data_loader = ds.create_tuple_iterator(do_copy=False) + first_step = True + t_end = time.time() + best_result, best_epoch = load_best_results() + 
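+    # Periodic COCO evaluation (descriptive note): DetectionEngine gathers per-class
+    # boxes and applies DIoU-NMS, while EvalWrapper runs inference over the validation
+    # set and, when eval_parallel is enabled, synchronizes partial results across
+    # ranks before computing mAP.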
engine = DetectionEngine(config, config.test_ignore_threshold) + eval_wrapper = EvalWrapper(config, val_network, ds_val, engine) + ckpt_queue = deque() + for epoch_idx in range(config.max_epoch): + for step_idx, data in enumerate(data_loader): + images = data[0] + input_shape = images.shape[1:3] + input_shape = mindspore.Tensor(input_shape, mindspore.float32) + loss = network(images, data[2], data[3], data[4], data[5], data[6], data[7], input_shape,) + loss_meter.update(loss.asnumpy()) + + if (epoch_idx * steps_per_epoch + step_idx) % config.log_interval == 0: + time_used = time.time() - t_end + if first_step: + fps = config.per_batch_size * config.group_size / time_used + per_step_time = time_used * 1000 + first_step = False + else: + fps = (config.per_batch_size * config.log_interval * config.group_size / time_used) + per_step_time = time_used / config.log_interval * 1000 + config.logger.info( + "epoch[{}], iter[{}], {}, fps:{:.2f} imgs/sec, " + "lr:{}, per step time: {}ms".format(epoch_idx + 1, step_idx + 1, loss_meter, + fps, lr[step_idx], per_step_time,) + ) + t_end = time.time() + loss_meter.reset() + if config.rank == 0 and (epoch_idx % config.save_ckpt_interval == 0): + ckpt_name = os.path.join( + config.output_dir, + "yolov5_{}_{}.ckpt".format(epoch_idx + 1, steps_per_epoch), + ) + mindspore.save_checkpoint(network, ckpt_name) + if len(ckpt_queue) == config.save_ckpt_max_num: + ckpt_to_remove = ckpt_queue.popleft() + os.remove(ckpt_to_remove) + ckpt_queue.append(ckpt_name) + + if is_val_epoch(epoch_idx): + load_parameters(val_network, train_network=network) + eval_wrapper.inference() + eval_result, mAP = eval_wrapper.get_results( + cur_epoch=epoch_idx + 1, cur_step=steps_per_epoch + ) + if mAP >= best_result: + best_result = mAP + best_epoch = epoch_idx + 1 + if config.rank == 0: + save_best_checkpoint(network, best_result, best_epoch) + config.logger.info( + "Best result %s at %s epoch", best_result, best_epoch + ) + config.logger.info(eval_result) + config.logger.info("Ending inference...") + + config.logger.info("==========end training===============") + + +if __name__ == "__main__": + run_train()
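As a closing, self-contained sketch (illustrative only, not part of the patched sources): the NumPy function below mirrors the box-regression term computed by the `WIoU` cell in `src/yolo.py` on corner-format boxes (the format produced by `xywh2x1y1x2y2`), i.e. IoU minus the squared center distance normalized by the enclosing-box diagonal, clipped to [-1, 1]. The helper name `wiou_numpy` and the toy boxes are assumptions for illustration, useful only for sanity-checking the loss math outside MindSpore.

```python
import numpy as np

def wiou_numpy(box_p, box_gt, eps=1e-6):
    """IoU minus the squared center distance normalized by the enclosing-box diagonal."""
    box_p = np.asarray(box_p, dtype=np.float32)
    box_gt = np.asarray(box_gt, dtype=np.float32)
    area_p = (box_p[..., 2] - box_p[..., 0]) * (box_p[..., 3] - box_p[..., 1])
    area_gt = (box_gt[..., 2] - box_gt[..., 0]) * (box_gt[..., 3] - box_gt[..., 1])
    # intersection area
    x1 = np.maximum(box_p[..., 0], box_gt[..., 0])
    y1 = np.maximum(box_p[..., 1], box_gt[..., 1])
    x2 = np.minimum(box_p[..., 2], box_gt[..., 2])
    y2 = np.minimum(box_p[..., 3], box_gt[..., 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    iou = inter / (area_p + area_gt - inter + eps)
    # squared distance between the two box centers
    rho2 = ((box_p[..., 0] + box_p[..., 2]) - (box_gt[..., 0] + box_gt[..., 2])) ** 2 / 4 \
         + ((box_p[..., 1] + box_p[..., 3]) - (box_gt[..., 1] + box_gt[..., 3])) ** 2 / 4
    # squared diagonal of the smallest enclosing box
    c2 = (np.maximum(box_p[..., 2], box_gt[..., 2]) - np.minimum(box_p[..., 0], box_gt[..., 0])) ** 2 \
       + (np.maximum(box_p[..., 3], box_gt[..., 3]) - np.minimum(box_p[..., 1], box_gt[..., 1])) ** 2 + eps
    return np.clip(iou - rho2 / c2, -1.0, 1.0)

if __name__ == "__main__":
    pred = np.array([[10.0, 10.0, 50.0, 50.0]])  # [xmin, ymin, xmax, ymax]
    gt = np.array([[12.0, 12.0, 48.0, 52.0]])
    print(wiou_numpy(pred, gt))  # ~0.82: large overlap, tiny center-distance penalty
```

In `YoloLossBlock.construct` this term enters the loss as `object_mask * (2 - w*h) * (1 - wiou)`, so well-localized boxes (value near 1) contribute little regression loss, while misplaced boxes are penalized both by low IoU and by their center offset.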