From 4ca34d3e89b7c045cdf0be51bafa52ee0b8fad93 Mon Sep 17 00:00:00 2001 From: wan-zutao <1025494833@qq.com> Date: Thu, 21 Mar 2024 10:00:51 +0800 Subject: [PATCH 01/18] add Detection Sample --- .../dataset/.gitkeep | 0 Samples/DetectionRetrainingAndInfer/export.py | 46 ++ Samples/DetectionRetrainingAndInfer/main.cpp | 129 ++++++ .../omInfer/scripts/sample_build.sh | 40 ++ .../omInfer/scripts/sample_run.sh | 17 + .../omInfer/src/CMakeLists.txt | 56 +++ .../omInfer/src/main.cpp | 137 ++++++ .../DetectionRetrainingAndInfer/predata.py | 47 ++ Samples/DetectionRetrainingAndInfer/train.py | 223 ++++++++++ .../vision/__init__.py | 0 .../vision/dataset.py | 160 +++++++ .../vision/nn/__init__.py | 0 .../vision/nn/mobilenet.py | 52 +++ .../vision/nn/multibox_loss.py | 47 ++ .../vision/nn/scaled_l2_norm.py | 19 + .../vision/nn/squeezenet.py | 130 ++++++ .../vision/ssd/__init__.py | 0 .../vision/ssd/config/__init__.py | 0 .../ssd/config/mobilenetv1_ssd_config.py | 76 ++++ .../vision/ssd/data_preprocessing.py | 62 +++ .../vision/ssd/mobilenetv1_ssd.py | 74 ++++ .../vision/ssd/predictor.py | 82 ++++ .../vision/ssd/ssd.py | 177 ++++++++ .../vision/transforms/__init__.py | 0 .../vision/transforms/transforms.py | 409 ++++++++++++++++++ .../vision/utils/__init__.py | 1 + .../vision/utils/box_utils.py | 295 +++++++++++++ .../vision/utils/box_utils_numpy.py | 238 ++++++++++ .../vision/utils/measurements.py | 32 ++ .../vision/utils/misc.py | 45 ++ .../vision/utils/model_book.py | 81 ++++ 31 files changed, 2675 insertions(+) create mode 100644 Samples/DetectionRetrainingAndInfer/dataset/.gitkeep create mode 100644 Samples/DetectionRetrainingAndInfer/export.py create mode 100644 Samples/DetectionRetrainingAndInfer/main.cpp create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_build.sh create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_run.sh create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/src/CMakeLists.txt create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp create mode 100644 Samples/DetectionRetrainingAndInfer/predata.py create mode 100644 Samples/DetectionRetrainingAndInfer/train.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/__init__.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/dataset.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/nn/__init__.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/nn/mobilenet.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/nn/multibox_loss.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/nn/scaled_l2_norm.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/ssd/__init__.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/ssd/config/__init__.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/ssd/config/mobilenetv1_ssd_config.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/ssd/data_preprocessing.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/ssd/mobilenetv1_ssd.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/ssd/predictor.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/ssd/ssd.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/transforms/__init__.py create mode 100644 Samples/DetectionRetrainingAndInfer/vision/transforms/transforms.py create mode 100644 
Samples/DetectionRetrainingAndInfer/vision/utils/__init__.py
 create mode 100644 Samples/DetectionRetrainingAndInfer/vision/utils/box_utils.py
 create mode 100644 Samples/DetectionRetrainingAndInfer/vision/utils/box_utils_numpy.py
 create mode 100644 Samples/DetectionRetrainingAndInfer/vision/utils/measurements.py
 create mode 100644 Samples/DetectionRetrainingAndInfer/vision/utils/misc.py
 create mode 100644 Samples/DetectionRetrainingAndInfer/vision/utils/model_book.py
diff --git a/Samples/DetectionRetrainingAndInfer/dataset/.gitkeep b/Samples/DetectionRetrainingAndInfer/dataset/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/Samples/DetectionRetrainingAndInfer/export.py b/Samples/DetectionRetrainingAndInfer/export.py
new file mode 100644
index 0000000..d71b92f
--- /dev/null
+++ b/Samples/DetectionRetrainingAndInfer/export.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# converts a saved PyTorch checkpoint to ONNX format
+import os
+import sys
+import argparse
+
+import torch.onnx
+
+from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd
+from vision.ssd.config import mobilenetv1_ssd_config
+
+# set the device
+device = torch.device('cpu')
+print(f"=> running on device {device}")
+
+# path of the trained checkpoint to convert
+input = "models/best.pth"
+
+num_classes = 4
+resolution = 300
+net_name = 'ssd-mobilenet'
+# construct the network architecture
+print(f"=> creating network: {net_name}")
+print(f"=> num classes: {num_classes}")
+print(f"=> resolution: {resolution}x{resolution}")
+
+mobilenetv1_ssd_config.set_image_size(300)
+net = create_mobilenetv1_ssd(num_classes, is_test=True)
+# load the model checkpoint
+print(f"=> loading checkpoint: {input}")
+
+net.load(input)
+net.to(device)
+net.eval()
+
+# create example image data
+dummy_input = torch.randn(1, 3, resolution, resolution)
+output = 'mobilenet-ssd.onnx'
+
+# export to ONNX
+input_names = ['input_0']
+output_names = ['scores', 'boxes']
+
+print("=> exporting model to ONNX...")
+torch.onnx.export(net, dummy_input, output, verbose=True, input_names=input_names, output_names=output_names)
+print(f"model exported to: {output}")
diff --git a/Samples/DetectionRetrainingAndInfer/main.cpp b/Samples/DetectionRetrainingAndInfer/main.cpp
new file mode 100644
index 0000000..f74d510
--- /dev/null
+++ b/Samples/DetectionRetrainingAndInfer/main.cpp
@@ -0,0 +1,129 @@
+#include <iostream>
+#include <vector>
+#include <string>
+#include <algorithm>
+#include <opencv2/opencv.hpp>
+#include "acllite_dvpp_lite/ImageProc.h"
+#include "acllite_om_execute/ModelProc.h"
+
+using namespace std;
+using namespace acllite;
+typedef struct BoundBox {
+    float x;
+    float y;
+    float width;
+    float height;
+    float score;
+    int classIndex;
+} BoundBox;
+float iou(BoundBox box1, BoundBox box2)
+{
+    float xLeft = max(box1.x, box2.x);
+    float yTop = max(box1.y, box2.y);
+    float xRight = min(box1.x + box1.width, box2.x + box2.width);
+    float yBottom = min(box1.y + box1.height, box2.y + box2.height);
+    float width = max(0.0f, xRight - xLeft);
+    float hight = max(0.0f, yBottom - yTop);
+    float area = width * hight;
+    float iou = area / (box1.width * box1.height + box2.width * box2.height - area);
+    return iou;
+}
+bool sortScore(BoundBox box1, BoundBox box2)
+{
+    return box1.score > box2.score;
+}
+int main()
+{
+    vector<string> labels = { {"BACKGROUND"}, {"with_mask"}, {"mask_weared_incorrect"}, {"without_mask"} };
+    AclLiteResource aclResource;
+    bool ret = aclResource.Init();
+    CHECK_RET(ret, LOG_PRINT("[ERROR] InitACLResource failed."); return 1);
+
+    ImageProc imageProc;
+    ModelProc modelProc;
+    ret = modelProc.Load("../model/ssd-mobilenet.om");
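+    // the OM model converted from mobilenet-ssd.onnx is expected to output two tensors:
+    // per-prior class scores (3000 x 4, BACKGROUND first) and corner-form boxes
+    // (3000 x 4, normalized to [0, 1]); see output_names in export.py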
+    CHECK_RET(ret, LOG_PRINT("[ERROR] load model ssd-mobilenet.om failed."); return 1);
+    string imagePath = "../data/8.jpg";
+    ImageData src = imageProc.Read(imagePath);
+    CHECK_RET(src.size, LOG_PRINT("[ERROR] ImRead image failed."); return 1);
+
+    ImageData dst;
+    ImageSize dsize(300, 300);
+
+    imageProc.Resize(src, dst, dsize);
+    ret = modelProc.CreateInput(static_cast<uint8_t*>(dst.data.get()), dst.size);
+    CHECK_RET(ret, LOG_PRINT("[ERROR] Create model input failed."); return 1);
+    vector<InferenceOutput> inferOutputs;
+    ret = modelProc.Execute(inferOutputs);
+    CHECK_RET(ret, LOG_PRINT("[ERROR] model execute failed."); return 1);
+
+    uint32_t dataSize = inferOutputs[0].size;
+    // get result from output data set
+    float* scores = static_cast<float*>(inferOutputs[0].data.get());
+    float* boxes = static_cast<float*>(inferOutputs[1].data.get());
+    if (scores == nullptr || boxes == nullptr) {
+        LOG_PRINT("get result from output data set failed.");
+        return 1;
+    }
+    size_t classNum = 4;
+    size_t boxes_nums = 3000;
+    size_t candidate_size = 200;
+    size_t top_k = 20;
+    float prob_threshold = 0.4;
+    float iou_threshold = 0.4;
+    const double fountScale = 0.5;
+    const uint32_t lineSolid = 2;
+    const uint32_t labelOffset = 11;
+    const cv::Scalar fountColor(0, 0, 255);
+    const vector<cv::Scalar> colors{
+        cv::Scalar(237, 149, 100), cv::Scalar(0, 215, 255),
+        cv::Scalar(50, 205, 50), cv::Scalar(139, 85, 26)};
+    cv::Mat srcImage = cv::imread(imagePath);
+    int width = srcImage.cols;
+    int height = srcImage.rows;
+    // skip class 0 (BACKGROUND); boxes are corner-form (x1, y1, x2, y2) in normalized coordinates
+    for (size_t index = 1; index < classNum; index++) {
+        vector<BoundBox> box_scores;
+        vector<BoundBox> result;
+        for (size_t j = 0; j < boxes_nums; ++j) {
+            if (scores[j * classNum + index] > prob_threshold) {
+                BoundBox box;
+                box.score = scores[j * classNum + index];
+                box.x = boxes[4 * j] * width;
+                box.y = boxes[4 * j + 1] * height;
+                box.width = (boxes[4 * j + 2] - boxes[4 * j]) * width;
+                box.height = (boxes[4 * j + 3] - boxes[4 * j + 1]) * height;
+                box.classIndex = index;
+                box_scores.push_back(box);
+            }
+        }
+        std::sort(box_scores.begin(), box_scores.end(), sortScore);
+        if (box_scores.size() > candidate_size) {
+            box_scores.erase(box_scores.begin() + candidate_size, box_scores.end());
+        }
+        // greedy NMS: keep the highest-scoring box, drop later boxes overlapping it above iou_threshold
+        for (size_t i = 0; i < box_scores.size(); i++) {
+            if (result.size() == top_k) break;
+            result.push_back(box_scores[i]);
+            for (size_t j = i + 1; j < box_scores.size(); j++) {
+                float iou_t = iou(box_scores[i], box_scores[j]);
+                if (iou_t > iou_threshold) {
+                    box_scores.erase(box_scores.begin() + j);
+                    j--;
+                }
+            }
+        }
+        for (size_t i = 0; i < result.size(); ++i) {
+            cv::Point leftUpPoint, rightBottomPoint;
+            leftUpPoint.x = result[i].x;
+            leftUpPoint.y = result[i].y;
+            rightBottomPoint.x = result[i].x + result[i].width;
+            rightBottomPoint.y = result[i].y + result[i].height;
+            cv::rectangle(srcImage, leftUpPoint, rightBottomPoint, colors[i % colors.size()], lineSolid);
+            string className = labels[result[i].classIndex];
+            string markString = to_string(result[i].score) + ":" + className;
+            cv::putText(srcImage, markString, cv::Point(leftUpPoint.x, leftUpPoint.y + labelOffset),
+                        cv::FONT_HERSHEY_COMPLEX, fountScale, fountColor);
+        }
+    }
+    string savePath = "../output/out_0.jpg";
+    cv::imwrite(savePath, srcImage);
+    return 0;
+}
diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_build.sh b/Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_build.sh
new file mode 100644
index 0000000..d5837a4
--- /dev/null
+++ b/Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_build.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+ScriptPath="$( cd "$(dirname "$BASH_SOURCE")" ; pwd -P )"
+
+function build()
+{
+    if [ -d
${ScriptPath}/../out ];then + rm -rf ${ScriptPath}/../out + fi + + if [ -d ${ScriptPath}/../build/intermediates/host ];then + rm -rf ${ScriptPath}/../build/intermediates/host + fi + + mkdir -p ${ScriptPath}/../build/intermediates/host + cd ${ScriptPath}/../build/intermediates/host + + cmake ../../../src -DCMAKE_CXX_COMPILER=g++ -DCMAKE_SKIP_RPATH=TRUE + if [ $? -ne 0 ];then + echo "[ERROR] cmake error, Please check your environment!" + return 1 + fi + make + if [ $? -ne 0 ];then + echo "[ERROR] build failed, Please check your environment!" + return 1 + fi + cd - > /dev/null +} + +function main() +{ + echo "[INFO] Sample preparation" + build + if [ $? -ne 0 ];then + return 1 + fi + echo "[INFO] Sample preparation is complete" +} +main + diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_run.sh b/Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_run.sh new file mode 100644 index 0000000..2fe8dad --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/omInfer/scripts/sample_run.sh @@ -0,0 +1,17 @@ +#!/bin/bash +ScriptPath="$( cd "$(dirname "$BASH_SOURCE")" ; pwd -P )" + +function main() +{ + echo "[INFO] The sample starts to run" + running_command="./main" + cd ${ScriptPath}/../out + ${running_command} + if [ $? -ne 0 ];then + echo "[INFO] The program runs failed" + else + echo "[INFO] The program runs successfully" + fi +} +main + diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/src/CMakeLists.txt b/Samples/DetectionRetrainingAndInfer/omInfer/src/CMakeLists.txt new file mode 100644 index 0000000..b448950 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/omInfer/src/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2019. All rights reserved. + +cmake_minimum_required(VERSION 3.5.1) + +project(sampleUsbCamera) + +add_compile_options(-std=c++11) + +add_definitions(-DENABLE_DVPP_INTERFACE) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "../../../out") +set(CMAKE_CXX_FLAGS_DEBUG "-fPIC -O0 -g -Wall") +set(CMAKE_CXX_FLAGS_RELEASE "-fPIC -O2 -Wall") + +set(INC_PATH $ENV{DDK_PATH}) +if (NOT DEFINED ENV{DDK_PATH}) + set(INC_PATH "/usr/local/Ascend/ascend-toolkit/latest") + message(STATUS "set default INC_PATH: ${INC_PATH}") +else() + message(STATUS "set INC_PATH: ${INC_PATH}") +endif () + +set(LIB_PATH $ENV{NPU_HOST_LIB}) +if (NOT DEFINED ENV{NPU_HOST_LIB}) + set(LIB_PATH "/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64/stub") + message(STATUS "set default LIB_PATH: ${LIB_PATH}") +else() + message(STATUS "set LIB_PATH: ${LIB_PATH}") +endif () + +find_package(OpenCV REQUIRED) +find_path(AVCODEC_INCLUDE_DIR libavcodec/avcodec.h) +find_library(AVCODEC_LIBRARY avcodec) + +include_directories( + ${OpenCV_INCLUDE_DIRS} + ${AVCODEC_INCLUDE_DIR} + ${INC_PATH}/runtime/include/ + ./ +) + +link_directories( + ${OpenCV_LIB_DIRS} + ${AVCODEC_LIBRARY} + ${LIB_PATH} +) + +add_executable(main + main.cpp) + +if(target STREQUAL "Simulator_Function") + target_link_libraries(main funcsim) +else() + target_link_libraries(main ascendcl acl_dvpp stdc++ dl rt pthread acllite_dvpp_lite acllite_media acllite_om_execute acllite_common ${AVCODEC_LIBRARY} ${OpenCV_LIBS}) +endif() + +install(TARGETS main DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp b/Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp new file mode 100644 index 0000000..6fda9c4 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp @@ -0,0 +1,137 @@ +#include +#include +#include +#include +#include 
"acllite_dvpp_lite/ImageProc.h" +#include "acllite_om_execute/ModelProc.h" +#include "opencv2/opencv.hpp" + +using namespace std; +using namespace acllite; +using namespace cv; + +typedef struct BoundBox { + float x; + float y; + float width; + float height; + float score; + int classIndex; +} BoundBox; +float iou(BoundBox box1, BoundBox box2) +{ + float xLeft = max(box1.x, box2.x); + float yTop = max(box1.y, box2.y); + float xRight = min(box1.x + box1.width, box2.x + box2.width); + float yBottom = min(box1.y + box1.height, box2.y + box2.height); + float width = max(0.0f, xRight - xLeft); + float hight = max(0.0f, yBottom - yTop); + float area = width * hight; + float iou = area / (box1.width * box1.height + box2.width * box2.height - area); + return iou; +} +bool sortScore(BoundBox box1, BoundBox box2) +{ + return box1.score > box2.score; +} +int main() +{ + vector labels = { {"BACKGROUND"},{"with_mask"},{"mask_weared_incorrect"},{"without_mask"}}; + AclLiteResource aclResource; + bool ret = aclResource.Init(); + CHECK_RET(ret, LOG_PRINT("[ERROR] InitACLResource failed."); return 1); + + ImageProc imageProc; + ModelProc modelProc; + ret = modelProc.Load("../model/mobilenet-ssd.om"); + CHECK_RET(ret, LOG_PRINT("[ERROR] load model mobilenet-ssd.om failed."); return 1); + string imagePath = "../data/mask3.jpg"; + ImageData src = imageProc.Read(imagePath); + + CHECK_RET(src.size, LOG_PRINT("[ERROR] ImRead image failed."); return 1); + ImageData dst; + ImageSize dsize(300, 300); + + imageProc.Resize(src, dst, dsize); + ret = modelProc.CreateInput(static_cast(dst.data.get()), dst.size); + CHECK_RET(ret, LOG_PRINT("[ERROR] Create model input failed."); return 1); + vector inferOutputs; + ret = modelProc.Execute(inferOutputs); + CHECK_RET(ret, LOG_PRINT("[ERROR] model execute failed."); return 1); + uint32_t dataSize = inferOutputs[0].size; + uint32_t size = inferOutputs[1].size; + // get result from output data set + float* scores = static_cast(inferOutputs[0].data.get()); + float* boxes = static_cast(inferOutputs[1].data.get()); + if (scores == nullptr || boxes == nullptr) { + LOG_PRINT("get result from output data set failed."); + return 1; + } + size_t classNum = 4; + size_t boxes_nums = 3000; + size_t candidate_size = 200; + size_t top_k = 20; + float prob_threshold = 0.7; + float iou_threshold = 0.45; + const double fountScale = 0.5; + const uint32_t lineSolid = 2; + const uint32_t labelOffset = 11; + const cv::Scalar fountColor(0, 0, 255); + const vector colors{ + cv::Scalar(237, 149, 100), cv::Scalar(0, 215, 255), + cv::Scalar(50, 205, 50), cv::Scalar(139, 85, 26)}; + cv::Mat srcImage = cv::imread(imagePath); + int width = srcImage.cols; + int height = srcImage.rows; + for(int index = 1; index < classNum; index++) { + vector box_scores; + vector result; + for(int j = 0; j < boxes_nums; ++j){ + if(scores[j * classNum + index] > prob_threshold){ + BoundBox box; + box.score = scores[j * classNum + index]; + box.width = (boxes[4 * j + 2]- boxes[4 * j]) * width ; + box.height = (boxes[4 * j + 3] - boxes[4 * j + 1]) * height; + box.x = boxes[4 * j] * width; + box.y = boxes[4 * j + 1] * height; + box.classIndex = index; + box_scores.push_back(box); + } + } + std::sort(box_scores.begin(),box_scores.end(),sortScore); + if(box_scores.size() > candidate_size){ + box_scores.erase(box_scores.begin() + candidate_size + 1, box_scores.end()); + } + int len = box_scores.size(); + if(len > 0){ + for(int i = 0;i < box_scores.size(); i++){ + if(result.size() == top_k) break; + 
result.push_back(box_scores[i]); + for(int j = i + 1; j < box_scores.size();j++){ + float iou_t = iou(box_scores[i],box_scores[j]); + if(iou_t > iou_threshold){ + box_scores.erase(box_scores.begin() + j); + j--; + } + } + } + } + for (size_t i = 0; i < result.size(); ++i) { + cv::Point leftUpPoint, rightBottomPoint; + leftUpPoint.x = result[i].x ; + leftUpPoint.y = result[i].y; + rightBottomPoint.x = result[i].x + result[i].width; + rightBottomPoint.y = result[i].y + result[i].height; + cv::rectangle(srcImage, leftUpPoint, rightBottomPoint, colors[i % colors.size()], lineSolid); + string className = labels[result[i].classIndex]; + string markString = to_string(result[i].score) + ":" + className; + cv::putText(srcImage, markString, cv::Point(leftUpPoint.x, leftUpPoint.y + labelOffset), + cv::FONT_HERSHEY_COMPLEX, fountScale, fountColor); + } + + } + string savePath = "../output/out_0.jpg"; + cv::imwrite(savePath, srcImage); + return 0; +} + diff --git a/Samples/DetectionRetrainingAndInfer/predata.py b/Samples/DetectionRetrainingAndInfer/predata.py new file mode 100644 index 0000000..1eecee1 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/predata.py @@ -0,0 +1,47 @@ +import os +import shutil +def preparinbdata(main_xml_file, main_img_file, train_size, val_size): + for i in range(0, train_size): + source_xml = main_xml_file + "/" + metarial[i] + ".xml" + source_img = main_img_file + "/" + metarial[i] + ".png" + + mstring = metarial[i] + train_destination_xml = "./data/mask/train/labels" + "/" + metarial[i] + ".xml" + train_destination_png = "./data/mask/train/images" + "/" + metarial[i] + ".png" + + shutil.copy(source_xml, train_destination_xml) + shutil.copy(source_img, train_destination_png) + for n in range(train_size , train_size + val_size): + + source_xml = main_xml_file + "/" + metarial[n] + ".xml" + source_img = main_img_file + "/" + metarial[n] + ".png" + + mstring = metarial[n] + val_destination_xml = "./data/mask/val/labels" + "/" + metarial[n] + ".xml" + val_destination_png = "./data/mask/val/images" + "/" + metarial[n] + ".png" + + shutil.copy(source_xml, val_destination_xml) + shutil.copy(source_img, val_destination_png) + +if __name__ == '__main__': + metarial = [] + for i in os.listdir("./data/images"): + str = i[:-4] + metarial.append(str) + train_size = int(len(metarial) * 0.7) + val_size = int(len(metarial) * 0.3) + print("Sum of image: ", len(metarial)) + print("Sum of the train size: ", train_size) + print("Sum of the val size: ", val_size) + if not os.path.exists("./data/mask"): + os.mkdir('./data/mask') + os.mkdir('./data/mask/train') + os.mkdir('./data/mask/val') + os.mkdir('./data/mask/train/images') + os.mkdir('./data/mask/train/labels') + os.mkdir('./data/mask/val/images') + os.mkdir('./data/mask/val/labels') + preparinbdata(main_xml_file = "./data/annotations", + main_img_file = "./data/images", + train_size = train_size, + val_size = val_size) diff --git a/Samples/DetectionRetrainingAndInfer/train.py b/Samples/DetectionRetrainingAndInfer/train.py new file mode 100644 index 0000000..0f3cb0d --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/train.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +# +# train an SSD detection model on Pascal VOC or Open Images datasets +# https://github.com/dusty-nv/jetson-inference/blob/master/docs/pytorch-ssd.md +# +import os +import sys +import logging +import datetime +import torch +import torch_npu +from torch_npu.npu import amp + +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter 
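+# torch_npu's amp module mirrors the torch.cuda.amp API (autocast + GradScaler) and is used
+# below for mixed-precision training on the Ascend NPU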
+from torch.optim.lr_scheduler import CosineAnnealingLR + +from vision.utils.misc import Timer +from vision.ssd.ssd import MatchPrior +from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd +from vision.dataset import VOCDataset +from vision.nn.multibox_loss import MultiboxLoss +from vision.ssd.config import mobilenetv1_ssd_config +from vision.ssd.data_preprocessing import TrainAugmentation, TestTransform + + +DEFAULT_PRETRAINED_MODEL='models/mobilenet-v1-ssd-mp-0_675.pth' + +logging.basicConfig(stream=sys.stdout, level=getattr(logging, "INFO", logging.INFO), + format='%(asctime)s - %(message)s', datefmt="%Y-%m-%d %H:%M:%S") +# make sure that the checkpoint output dir exists +checkpoint_folder = "models" +checkpoint_folder = os.path.expanduser(checkpoint_folder) +if not os.path.exists(checkpoint_folder): + os.mkdir(checkpoint_folder) +tensorboard = SummaryWriter(log_dir=os.path.join(checkpoint_folder, "tensorboard", f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")) + +DEVICE = torch.device("npu:0") + +def train(loader, net, criterion, optimizer, device, scaler, debug_steps=100, epoch=-1): + net.train(True) + + train_loss = 0.0 + train_regression_loss = 0.0 + train_classification_loss = 0.0 + + running_loss = 0.0 + running_regression_loss = 0.0 + running_classification_loss = 0.0 + + num_batches = 0 + + for i, data in enumerate(loader): + images, boxes, labels = data + images = images.to(device) + boxes = boxes.to(device) + labels = labels.to(device) + + optimizer.zero_grad() + with amp.autocast(): + confidence, locations = net(images) + regression_loss, classification_loss = criterion(confidence, locations, labels, boxes) + loss = regression_loss + classification_loss + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + train_loss += loss.item() + train_regression_loss += regression_loss.item() + train_classification_loss += classification_loss.item() + + running_loss += loss.item() + running_regression_loss += regression_loss.item() + running_classification_loss += classification_loss.item() + + if i and i % debug_steps == 0: + avg_loss = running_loss / debug_steps + avg_reg_loss = running_regression_loss / debug_steps + avg_clf_loss = running_classification_loss / debug_steps + logging.info( + f"Epoch: {epoch}, Step: {i}/{len(loader)}, " + + f"Avg Loss: {avg_loss:.4f}, " + + f"Avg Regression Loss {avg_reg_loss:.4f}, " + + f"Avg Classification Loss: {avg_clf_loss:.4f}" + ) + running_loss = 0.0 + running_regression_loss = 0.0 + running_classification_loss = 0.0 + + num_batches += 1 + + train_loss /= num_batches + train_regression_loss /= num_batches + train_classification_loss /= num_batches + + logging.info( + f"Epoch: {epoch}, " + + f"Training Loss: {train_loss:.4f}, " + + f"Training Regression Loss {train_regression_loss:.4f}, " + + f"Training Classification Loss: {train_classification_loss:.4f}" + ) + + tensorboard.add_scalar('Loss/train', train_loss, epoch) + tensorboard.add_scalar('Regression Loss/train', train_regression_loss, epoch) + tensorboard.add_scalar('Classification Loss/train', train_classification_loss, epoch) + +def test(loader, net, criterion, device): + net.eval() + running_loss = 0.0 + running_regression_loss = 0.0 + running_classification_loss = 0.0 + num = 0 + for _, data in enumerate(loader): + images, boxes, labels = data + images = images.to(device) + boxes = boxes.to(device) + labels = labels.to(device) + num += 1 + with torch.no_grad(): + with amp.autocast(): + confidence, locations = net(images) + regression_loss, 
classification_loss = criterion(confidence, locations, labels, boxes) + loss = regression_loss + classification_loss + + running_loss += loss.item() + running_regression_loss += regression_loss.item() + running_classification_loss += classification_loss.item() + + return running_loss / num, running_regression_loss / num, running_classification_loss / num + +if __name__ == '__main__': + + timer = Timer() + create_net = create_mobilenetv1_ssd + config = mobilenetv1_ssd_config + config.set_image_size(300) + + # create data transforms for train/test/val + train_transform = TrainAugmentation(config.image_size, config.image_mean, config.image_std) + target_transform = MatchPrior(config.priors, config.center_variance, + config.size_variance, 0.5) + + test_transform = TestTransform(config.image_size, config.image_mean, config.image_std) + dataset_path = "dataset" + batch_size = 4 + num_workers = 3 + # load datasets (could be multiple) + logging.info("Prepare training datasets.") + train_dataset = VOCDataset(dataset_path, transform=train_transform, + target_transform=target_transform) + num_classes = len(train_dataset.class_names) + # create training dataset + logging.info("Train dataset size: {}".format(len(train_dataset))) + train_loader = DataLoader(train_dataset, batch_size, + num_workers=num_workers, + shuffle=True) + + # create validation dataset + val_dataset = VOCDataset(dataset_path, transform=test_transform, + target_transform=target_transform, is_test=True) + val_loader = DataLoader(val_dataset, batch_size, + num_workers=num_workers, + shuffle=False) + + # create the network + logging.info("Build network.") + net = create_net(num_classes) + last_epoch = -1 + + # load a previous model checkpoint (if requested) + timer.start("Load Model") + + logging.info(f"Init from pretrained SSD {DEFAULT_PRETRAINED_MODEL}") + + if not os.path.exists(DEFAULT_PRETRAINED_MODEL): + os.system(f"wget --quiet --show-progress --progress=bar:force:noscroll --no-check-certificate https://nvidia.box.com/shared/static/djf5w54rjvpqocsiztzaandq1m3avr7c.pth -O {DEFAULT_PRETRAINED_MODEL}") + + net.init_from_pretrained_ssd(DEFAULT_PRETRAINED_MODEL) + + logging.info(f'Took {timer.end("Load Model"):.2f} seconds to load the model.') + + # move the model to GPU + net.to(DEVICE) + + # define loss function and optimizer + criterion = MultiboxLoss(config.priors, iou_threshold=0.5, neg_pos_ratio=3, + center_variance=0.1, size_variance=0.2, device=DEVICE) + + optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, + weight_decay=5e-4) + scaler = amp.GradScaler() + logging.info("Uses CosineAnnealingLR scheduler.") + scheduler = CosineAnnealingLR(optimizer, 100, last_epoch=last_epoch) + + # train for the desired number of epochs + logging.info(f"Start training from epoch {last_epoch + 1}.") + num_epochs = 100 + best_loss = 10000 + model_path = os.path.join(checkpoint_folder, "best.pth") + for epoch in range(last_epoch + 1, num_epochs): + train(train_loader, net, criterion, optimizer, device=DEVICE, scaler=scaler, debug_steps=10, epoch=epoch) + scheduler.step() + val_loss, val_regression_loss, val_classification_loss = test(val_loader, net, criterion, DEVICE) + + logging.info( + f"Epoch: {epoch}, " + + f"Validation Loss: {val_loss:.4f}, " + + f"Validation Regression Loss {val_regression_loss:.4f}, " + + f"Validation Classification Loss: {val_classification_loss:.4f}" + ) + + tensorboard.add_scalar('Loss/val', val_loss, epoch) + tensorboard.add_scalar('Regression Loss/val', val_regression_loss, epoch) + 
tensorboard.add_scalar('Classification Loss/val', val_classification_loss, epoch) + + if val_loss < best_loss: + best_loss = val_loss + net.save(model_path) + logging.info(f"Saved model {model_path}") + + + logging.info("Task done, exiting program.") + tensorboard.close() \ No newline at end of file diff --git a/Samples/DetectionRetrainingAndInfer/vision/__init__.py b/Samples/DetectionRetrainingAndInfer/vision/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Samples/DetectionRetrainingAndInfer/vision/dataset.py b/Samples/DetectionRetrainingAndInfer/vision/dataset.py new file mode 100644 index 0000000..e5cffe9 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/dataset.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +import os +import logging + +import torch +import numpy as np +import xml.etree.ElementTree as ET + +from PIL import Image + + +class VOCDataset(torch.utils.data.Dataset): + """ + Object detection dataset for Pascal VOC (http://host.robots.ox.ac.uk/pascal/VOC/) + """ + def __init__(self, root, transform=None, target_transform=None, is_test=False, keep_difficult=False, label_file=None): + """ + Dataset for VOC data. + + Parameters: + root (string) -- path to the VOC2007 or VOC2012 dataset, containing the following sub-directories: + Annotations, ImageSets, JPEGImages, SegmentationClass, SegmentationObject + + is_test (bool) -- if true, then use the data subset from `ImageSets/Main/test.txt` + if false, then use the data subset from `ImageSets/Main/trainval.txt` + if these files don't exist, then `ImageSets/Main/default.txt` will be used + """ + self.root = root + self.is_test = is_test + self.transform = transform + self.target_transform = target_transform + if not os.path.exists(os.path.join(self.root, 'mask')): + logging.info("No dataset, please prepare dataset") + # determine the image set file to use + if is_test: + self.image_sets_file = os.path.join(self.root, 'mask/val') + else: + self.image_sets_file = os.path.join(self.root, 'mask/train') + + + # read the image set ID's + self.ids = self._read_image_ids() + self.keep_difficult = keep_difficult + + self.class_names = ('BACKGROUND','with_mask', + 'mask_weared_incorrect','without_mask') + + self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)} + + def __getitem__(self, index): + image_id = self.ids[index] + boxes, labels, is_difficult = self._get_annotation(image_id) + + if not self.keep_difficult: + boxes = boxes[is_difficult == 0] + labels = labels[is_difficult == 0] + + if logging.root.level is logging.DEBUG: + logging.debug(f"voc_dataset image_id={image_id}" + ' \n boxes=' + str(boxes) + ' \n labels=' + str(labels)) + + image = self._read_image(image_id) + + if self.transform: + image, boxes, labels = self.transform(image, boxes, labels) + if self.target_transform: + boxes, labels = self.target_transform(boxes, labels) + + return image, boxes, labels + + def get_image(self, index): + image_id = self.ids[index] + image = self._read_image(image_id) + if self.transform: + image, _ = self.transform(image) + return image + + def get_annotation(self, index): + image_id = self.ids[index] + return image_id, self._get_annotation(image_id) + + def __len__(self): + return len(self.ids) + + def _read_image_ids(self): + ids = [] + for i in os.listdir(os.path.join(self.image_sets_file,"images")): + image_id = i[:-4] + + if self._get_num_annotations(image_id) > 0: + if self._find_image(image_id) is not None: + ids.append(image_id) + else: + print('warning - could not find image 
{:s} - ignoring from dataset'.format(image_id)) + else: + print('warning - image {:s} has no box/labels annotations, ignoring from dataset'.format(image_id)) + + return ids + + def _get_num_annotations(self, image_id): + annotation_file = os.path.join(self.image_sets_file, f'labels/{image_id}.xml') + objects = ET.parse(annotation_file).findall("object") + return len(objects) + + def _get_annotation(self, image_id): + annotation_file = os.path.join(self.image_sets_file, f'labels/{image_id}.xml') + objects = ET.parse(annotation_file).findall("object") + boxes = [] + labels = [] + is_difficult = [] + for object in objects: + class_name = object.find('name').text.strip() #.lower().strip() + # we're only concerned with clases in our list + if class_name in self.class_dict: + bbox = object.find('bndbox') + + # VOC dataset format follows Matlab, in which indexes start from 0 + x1 = float(bbox.find('xmin').text) - 1 + y1 = float(bbox.find('ymin').text) - 1 + x2 = float(bbox.find('xmax').text) - 1 + y2 = float(bbox.find('ymax').text) - 1 + boxes.append([x1, y1, x2, y2]) + + labels.append(self.class_dict[class_name]) + + # retrieve element + is_difficult_obj = object.find('difficult') + is_difficult_str = '0' + + if is_difficult_obj is not None: + is_difficult_str = object.find('difficult').text + + is_difficult.append(int(is_difficult_str) if is_difficult_str else 0) + else: + print(f"warning - image {image_id} has object with unknown class '{class_name}'") + + return (np.array(boxes, dtype=np.float32), + np.array(labels, dtype=np.int64), + np.array(is_difficult, dtype=np.uint8)) + + def _find_image(self, image_id): + image_file = os.path.join(self.image_sets_file, f'images/{image_id}.png') + if os.path.exists(image_file): + return image_file + return None + + def _read_image(self, image_id): + image_file = self._find_image(image_id) + + if image_file is None: + raise IOError(f"failed to find {image_file}") + + image = Image.open(image_file).convert('RGB') + + if image is None or image.size == 0: + raise IOError(f"invalid/corrupt image {image_file}") + + return np.asarray(image) + + + diff --git a/Samples/DetectionRetrainingAndInfer/vision/nn/__init__.py b/Samples/DetectionRetrainingAndInfer/vision/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Samples/DetectionRetrainingAndInfer/vision/nn/mobilenet.py b/Samples/DetectionRetrainingAndInfer/vision/nn/mobilenet.py new file mode 100644 index 0000000..98300df --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/nn/mobilenet.py @@ -0,0 +1,52 @@ +# borrowed from "https://github.com/marvis/pytorch-mobilenet" + +import torch.nn as nn +import torch.nn.functional as F + + +class MobileNetV1(nn.Module): + def __init__(self, num_classes=1024): + super(MobileNetV1, self).__init__() + + def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + + def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ) + + self.model = nn.Sequential( + conv_bn(3, 32, 2), + conv_dw(32, 64, 1), + conv_dw(64, 128, 2), + conv_dw(128, 128, 1), + conv_dw(128, 256, 2), + conv_dw(256, 256, 1), + conv_dw(256, 512, 2), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 1024, 
2), + conv_dw(1024, 1024, 1), + ) + self.fc = nn.Linear(1024, num_classes) + + def forward(self, x): + x = self.model(x) + x = F.avg_pool2d(x, 7) + x = x.view(-1, 1024) + x = self.fc(x) + return x \ No newline at end of file diff --git a/Samples/DetectionRetrainingAndInfer/vision/nn/multibox_loss.py b/Samples/DetectionRetrainingAndInfer/vision/nn/multibox_loss.py new file mode 100644 index 0000000..2351c76 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/nn/multibox_loss.py @@ -0,0 +1,47 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch + + +from ..utils import box_utils + + +class MultiboxLoss(nn.Module): + def __init__(self, priors, iou_threshold, neg_pos_ratio, + center_variance, size_variance, device): + """Implement SSD Multibox Loss. + + Basically, Multibox loss combines classification loss + and Smooth L1 regression loss. + """ + super(MultiboxLoss, self).__init__() + self.iou_threshold = iou_threshold + self.neg_pos_ratio = neg_pos_ratio + self.center_variance = center_variance + self.size_variance = size_variance + self.priors = priors + self.priors.to(device) + + def forward(self, confidence, predicted_locations, labels, gt_locations): + """Compute classification loss and smooth l1 loss. + + Args: + confidence (batch_size, num_priors, num_classes): class predictions. + locations (batch_size, num_priors, 4): predicted locations. + labels (batch_size, num_priors): real labels of all the priors. + boxes (batch_size, num_priors, 4): real boxes corresponding all the priors. + """ + num_classes = confidence.size(2) + with torch.no_grad(): + # derived from cross_entropy=sum(log(p)) + loss = -F.log_softmax(confidence, dim=2)[:, :, 0] + mask = box_utils.hard_negative_mining(loss, labels, self.neg_pos_ratio) + + confidence = confidence[mask, :] + classification_loss = F.cross_entropy(confidence.reshape(-1, num_classes), labels[mask], size_average=False) + pos_mask = labels > 0 + predicted_locations = predicted_locations[pos_mask, :].reshape(-1, 4) + gt_locations = gt_locations[pos_mask, :].reshape(-1, 4) + smooth_l1_loss = F.smooth_l1_loss(predicted_locations, gt_locations, size_average=False) + num_pos = gt_locations.size(0) + return smooth_l1_loss/num_pos, classification_loss/num_pos diff --git a/Samples/DetectionRetrainingAndInfer/vision/nn/scaled_l2_norm.py b/Samples/DetectionRetrainingAndInfer/vision/nn/scaled_l2_norm.py new file mode 100644 index 0000000..c1fd642 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/nn/scaled_l2_norm.py @@ -0,0 +1,19 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F + + +class ScaledL2Norm(nn.Module): + def __init__(self, in_channels, initial_scale): + super(ScaledL2Norm, self).__init__() + self.in_channels = in_channels + self.scale = nn.Parameter(torch.Tensor(in_channels)) + self.initial_scale = initial_scale + self.reset_parameters() + + def forward(self, x): + return (F.normalize(x, p=2, dim=1) + * self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3)) + + def reset_parameters(self): + self.scale.data.fill_(self.initial_scale) \ No newline at end of file diff --git a/Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py b/Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py new file mode 100644 index 0000000..d961678 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py @@ -0,0 +1,130 @@ +import math +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.utils.model_zoo as model_zoo + + +__all__ = ['SqueezeNet', 
'squeezenet1_0', 'squeezenet1_1'] + + +model_urls = { + 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', + 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', +} + + +class Fire(nn.Module): + + def __init__(self, inplanes, squeeze_planes, + expand1x1_planes, expand3x3_planes): + super(Fire, self).__init__() + self.inplanes = inplanes + self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) + self.squeeze_activation = nn.ReLU(inplace=True) + self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, + kernel_size=1) + self.expand1x1_activation = nn.ReLU(inplace=True) + self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, + kernel_size=3, padding=1) + self.expand3x3_activation = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.squeeze_activation(self.squeeze(x)) + return torch.cat([ + self.expand1x1_activation(self.expand1x1(x)), + self.expand3x3_activation(self.expand3x3(x)) + ], 1) + + +class SqueezeNet(nn.Module): + + def __init__(self, version=1.0, num_classes=1000): + super(SqueezeNet, self).__init__() + if version not in [1.0, 1.1]: + raise ValueError("Unsupported SqueezeNet version {version}:" + "1.0 or 1.1 expected".format(version=version)) + self.num_classes = num_classes + if version == 1.0: + self.features = nn.Sequential( + nn.Conv2d(3, 96, kernel_size=7, stride=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(96, 16, 64, 64), + Fire(128, 16, 64, 64), + Fire(128, 32, 128, 128), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(256, 32, 128, 128), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(512, 64, 256, 256), + ) + else: + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + Fire(64, 16, 64, 64), + Fire(128, 16, 64, 64), + nn.MaxPool2d(kernel_size=3, stride=2), + Fire(128, 32, 128, 128), + Fire(256, 32, 128, 128), + nn.MaxPool2d(kernel_size=3, stride=2), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + Fire(512, 64, 256, 256), + ) + # Final convolution is initialized differently form the rest + final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) + self.classifier = nn.Sequential( + nn.Dropout(p=0.5), + final_conv, + nn.ReLU(inplace=True), + nn.AvgPool2d(13, stride=1) + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m is final_conv: + init.normal_(m.weight, mean=0.0, std=0.01) + else: + init.kaiming_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + + def forward(self, x): + x = self.features(x) + x = self.classifier(x) + return x.view(x.size(0), self.num_classes) + + +def squeezenet1_0(pretrained=False, **kwargs): + r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level + accuracy with 50x fewer parameters and <0.5MB model size" + `_ paper. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = SqueezeNet(version=1.0, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) + return model + + +def squeezenet1_1(pretrained=False, **kwargs): + r"""SqueezeNet 1.1 model from the `official SqueezeNet repo + `_. + SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters + than SqueezeNet 1.0, without sacrificing accuracy. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = SqueezeNet(version=1.1, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) + return model diff --git a/Samples/DetectionRetrainingAndInfer/vision/ssd/__init__.py b/Samples/DetectionRetrainingAndInfer/vision/ssd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Samples/DetectionRetrainingAndInfer/vision/ssd/config/__init__.py b/Samples/DetectionRetrainingAndInfer/vision/ssd/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Samples/DetectionRetrainingAndInfer/vision/ssd/config/mobilenetv1_ssd_config.py b/Samples/DetectionRetrainingAndInfer/vision/ssd/config/mobilenetv1_ssd_config.py new file mode 100644 index 0000000..2267671 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/ssd/config/mobilenetv1_ssd_config.py @@ -0,0 +1,76 @@ +import numpy as np + +from vision.utils.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors + +image_size = 300 +image_mean = np.array([127, 127, 127]) # RGB layout +image_std = 128.0 +iou_threshold = 0.45 +center_variance = 0.1 +size_variance = 0.2 + +specs = [ + SSDSpec(19, 16, SSDBoxSizes(60, 105), [2, 3]), + SSDSpec(10, 32, SSDBoxSizes(105, 150), [2, 3]), + SSDSpec(5, 64, SSDBoxSizes(150, 195), [2, 3]), + SSDSpec(3, 100, SSDBoxSizes(195, 240), [2, 3]), + SSDSpec(2, 150, SSDBoxSizes(240, 285), [2, 3]), + SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3]) +] + +priors = generate_ssd_priors(specs, image_size) + + +def set_image_size(size=300, min_ratio=20, max_ratio=90): + global image_size + global specs + global priors + + from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd + + import torch + import math + import logging + + image_size = size + ssd = create_mobilenetv1_ssd(num_classes=3) # TODO does num_classes matter here? 
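+    # the num_classes passed here does not matter: the dummy forward pass below only asks the
+    # network for the spatial size of each SSD feature map (get_feature_map_size=True)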
+ x = torch.randn(1, 3, image_size, image_size) + + feature_maps = ssd(x, get_feature_map_size=True) + + steps = [ + math.ceil(image_size * 1.0 / feature_map) for feature_map in feature_maps + ] + step = int(math.floor((max_ratio - min_ratio) / (len(feature_maps) - 2))) + min_sizes = [] + max_sizes = [] + for ratio in range(min_ratio, max_ratio + 1, step): + min_sizes.append(image_size * ratio / 100.0) + max_sizes.append(image_size * (ratio + step) / 100.0) + min_sizes = [image_size * (min_ratio / 2) / 100.0] + min_sizes + max_sizes = [image_size * min_ratio / 100.0] + max_sizes + + # this update logic makes different boxes than the original for 300x300 (but better for power-of-two) + # for backwards-compatibility, keep the default 300x300 config if that's what's being called for + if image_size != 300: + specs = [] + + for i in range(len(feature_maps)): + specs.append( SSDSpec(feature_maps[i], steps[i], SSDBoxSizes(min_sizes[i], max_sizes[i]), [2, 3]) ) # ssd-mobilenet-* aspect ratio is [2,3] + + logging.info(f'model resolution {image_size}x{image_size}') + for spec in specs: + logging.info(str(spec)) + + priors = generate_ssd_priors(specs, image_size) + +#print(' ') +#print('SSD-Mobilenet-v1 priors:') +#print(priors.shape) +#print(priors) +#print(' ') + +#import torch +#torch.save(priors, 'mb1-ssd-priors.pt') + +#np.savetxt('mb1-ssd-priors.txt', priors.numpy()) diff --git a/Samples/DetectionRetrainingAndInfer/vision/ssd/data_preprocessing.py b/Samples/DetectionRetrainingAndInfer/vision/ssd/data_preprocessing.py new file mode 100644 index 0000000..ca79fed --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/ssd/data_preprocessing.py @@ -0,0 +1,62 @@ +from ..transforms.transforms import * + + +class TrainAugmentation: + def __init__(self, size, mean=0, std=1.0): + """ + Args: + size: the size the of final image. + mean: mean pixel value per channel. + """ + self.mean = mean + self.size = size + self.augment = Compose([ + ConvertFromInts(), + PhotometricDistort(), + Expand(self.mean), + RandomSampleCrop(), + RandomMirror(), + ToPercentCoords(), + Resize(self.size), + SubtractMeans(self.mean), + lambda img, boxes=None, labels=None: (img / std, boxes, labels), + ToTensor(), + ]) + + def __call__(self, img, boxes, labels): + """ + + Args: + img: the output of cv.imread in RGB layout. + boxes: boundding boxes in the form of (x1, y1, x2, y2). + labels: labels of boxes. 
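+        Returns:
+            the transformed image tensor together with the (augmented) boxes and labels.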
+ """ + return self.augment(img, boxes, labels) + + +class TestTransform: + def __init__(self, size, mean=0.0, std=1.0): + self.transform = Compose([ + ToPercentCoords(), + Resize(size), + SubtractMeans(mean), + lambda img, boxes=None, labels=None: (img / std, boxes, labels), + ToTensor(), + ]) + + def __call__(self, image, boxes, labels): + return self.transform(image, boxes, labels) + + +class PredictionTransform: + def __init__(self, size, mean=0.0, std=1.0): + self.transform = Compose([ + Resize(size), + SubtractMeans(mean), + lambda img, boxes=None, labels=None: (img / std, boxes, labels), + ToTensor() + ]) + + def __call__(self, image): + image, _, _ = self.transform(image) + return image \ No newline at end of file diff --git a/Samples/DetectionRetrainingAndInfer/vision/ssd/mobilenetv1_ssd.py b/Samples/DetectionRetrainingAndInfer/vision/ssd/mobilenetv1_ssd.py new file mode 100644 index 0000000..22e8a86 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/ssd/mobilenetv1_ssd.py @@ -0,0 +1,74 @@ +import torch +from torch.nn import Conv2d, Sequential, ModuleList, ReLU +from ..nn.mobilenet import MobileNetV1 + +from .ssd import SSD +from .predictor import Predictor +from .config import mobilenetv1_ssd_config as config + + +def create_mobilenetv1_ssd(num_classes, is_test=False): + base_net = MobileNetV1(1001).model # disable dropout layer + + source_layer_indexes = [ + 12, + 14, + ] + extras = ModuleList([ + Sequential( + Conv2d(in_channels=1024, out_channels=256, kernel_size=1), + ReLU(), + Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1), + ReLU() + ), + Sequential( + Conv2d(in_channels=512, out_channels=128, kernel_size=1), + ReLU(), + Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), + ReLU() + ), + Sequential( + Conv2d(in_channels=256, out_channels=128, kernel_size=1), + ReLU(), + Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), + ReLU() + ), + Sequential( + Conv2d(in_channels=256, out_channels=128, kernel_size=1), + ReLU(), + Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), + ReLU() + ) + ]) + + regression_headers = ModuleList([ + Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), + Conv2d(in_channels=1024, out_channels=6 * 4, kernel_size=3, padding=1), + Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), + Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), + Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), + Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? + ]) + + classification_headers = ModuleList([ + Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), + Conv2d(in_channels=1024, out_channels=6 * num_classes, kernel_size=3, padding=1), + Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), + Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), + Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), + Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? 
+ ]) + + return SSD(num_classes, base_net, source_layer_indexes, + extras, classification_headers, regression_headers, is_test=is_test, config=config) + + +def create_mobilenetv1_ssd_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=None): + predictor = Predictor(net, config.image_size, config.image_mean, + config.image_std, + nms_method=nms_method, + iou_threshold=config.iou_threshold, + candidate_size=candidate_size, + sigma=sigma, + device=device) + return predictor diff --git a/Samples/DetectionRetrainingAndInfer/vision/ssd/predictor.py b/Samples/DetectionRetrainingAndInfer/vision/ssd/predictor.py new file mode 100644 index 0000000..1afb5cb --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/ssd/predictor.py @@ -0,0 +1,82 @@ +import torch + +from ..utils import box_utils +from .data_preprocessing import PredictionTransform +from ..utils.misc import Timer + + +class Predictor: + def __init__(self, net, size, mean=0.0, std=1.0, nms_method=None, + iou_threshold=0.45, filter_threshold=0.01, candidate_size=200, sigma=0.5, device=None): + self.net = net + self.transform = PredictionTransform(size, mean, std) + self.iou_threshold = iou_threshold + self.filter_threshold = filter_threshold + self.candidate_size = candidate_size + self.nms_method = nms_method + self.sigma = sigma + + if device: + self.device = device + else: + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.net.to(self.device) + self.timer = Timer() + + def predict(self, image, top_k=-1, prob_threshold=None): + cpu_device = torch.device("cpu") + height, width, _ = image.shape + + image = self.transform(image) + images = image.unsqueeze(0) + images = images.to(self.device) + + self.net.eval() + + with torch.no_grad(): + self.timer.start() + scores, boxes = self.net.forward(images) + #print("Inference time: ", self.timer.end()) + + boxes = boxes[0] + scores = scores[0] + + if not prob_threshold: + prob_threshold = self.filter_threshold + + # this version of nms is slower on GPU, so we move data to CPU. 
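+        # below: per-class score thresholding followed by NMS, then the surviving normalized
+        # boxes are scaled back to the original image width/height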
+ boxes = boxes.to(cpu_device) + scores = scores.to(cpu_device) + picked_box_probs = [] + picked_labels = [] + + for class_index in range(1, scores.size(1)): + probs = scores[:, class_index] + mask = probs > prob_threshold + probs = probs[mask] + + if probs.size(0) == 0: + continue + + subset_boxes = boxes[mask, :] + box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1) + box_probs = box_utils.nms(box_probs, self.nms_method, + score_threshold=prob_threshold, + iou_threshold=self.iou_threshold, + sigma=self.sigma, + top_k=top_k, + candidate_size=self.candidate_size) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * box_probs.size(0)) + + if not picked_box_probs: + return torch.tensor([]), torch.tensor([]), torch.tensor([]) + + picked_box_probs = torch.cat(picked_box_probs) + picked_box_probs[:, 0] *= width + picked_box_probs[:, 1] *= height + picked_box_probs[:, 2] *= width + picked_box_probs[:, 3] *= height + + return picked_box_probs[:, :4], torch.tensor(picked_labels), picked_box_probs[:, 4] diff --git a/Samples/DetectionRetrainingAndInfer/vision/ssd/ssd.py b/Samples/DetectionRetrainingAndInfer/vision/ssd/ssd.py new file mode 100644 index 0000000..88bec1f --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/ssd/ssd.py @@ -0,0 +1,177 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from typing import List, Tuple +from collections import namedtuple + +from ..utils import box_utils + +GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1']) + + +class SSD(nn.Module): + def __init__(self, num_classes: int, base_net: nn.ModuleList, source_layer_indexes: List[int], + extras: nn.ModuleList, classification_headers: nn.ModuleList, + regression_headers: nn.ModuleList, is_test=False, config=None, device=None): + """ + Compose a SSD model using the given components. 
+ """ + super(SSD, self).__init__() + + self.num_classes = num_classes + self.base_net = base_net + self.source_layer_indexes = source_layer_indexes + self.extras = extras + self.classification_headers = classification_headers + self.regression_headers = regression_headers + self.is_test = is_test + self.config = config + + # register layers in source_layer_indexes by adding them to a module list + self.source_layer_add_ons = nn.ModuleList([t[1] for t in source_layer_indexes + if isinstance(t, tuple) and not isinstance(t, GraphPath)]) + if device: + self.device = device + else: + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.config = config + self.priors = config.priors.to(self.device) + + def forward(self, x: torch.Tensor, get_feature_map_size: bool=False) -> Tuple[torch.Tensor, torch.Tensor]: + confidences = [] + locations = [] + start_layer_index = 0 + header_index = 0 + if get_feature_map_size: + feature_maps = [] + for end_layer_index in self.source_layer_indexes: + if isinstance(end_layer_index, GraphPath): + path = end_layer_index + end_layer_index = end_layer_index.s0 + added_layer = None + elif isinstance(end_layer_index, tuple): + added_layer = end_layer_index[1] + end_layer_index = end_layer_index[0] + path = None + else: + added_layer = None + path = None + for layer in self.base_net[start_layer_index: end_layer_index]: + x = layer(x) + if added_layer: + y = added_layer(x) + else: + y = x + if path: + sub = getattr(self.base_net[end_layer_index], path.name) + for layer in sub[:path.s1]: + x = layer(x) + y = x + for layer in sub[path.s1:]: + x = layer(x) + end_layer_index += 1 + start_layer_index = end_layer_index + confidence, location = self.compute_header(header_index, y) + if get_feature_map_size: + feature_maps.append(y.shape[-1]) + header_index += 1 + confidences.append(confidence) + locations.append(location) + + for layer in self.base_net[end_layer_index:]: + x = layer(x) + + for layer in self.extras: + x = layer(x) + confidence, location = self.compute_header(header_index, x) + if get_feature_map_size: + feature_maps.append(x.shape[-1]) + header_index += 1 + confidences.append(confidence) + locations.append(location) + + if get_feature_map_size: + return feature_maps + + confidences = torch.cat(confidences, 1) + locations = torch.cat(locations, 1) + + if self.is_test: + confidences = F.softmax(confidences, dim=2) + boxes = box_utils.convert_locations_to_boxes( + locations, self.priors, self.config.center_variance, self.config.size_variance + ) + boxes = box_utils.center_form_to_corner_form(boxes) + return confidences, boxes + else: + return confidences, locations + + def compute_header(self, i, x): + confidence = self.classification_headers[i](x) + confidence = confidence.permute(0, 2, 3, 1).contiguous() + confidence = confidence.view(confidence.size(0), -1, self.num_classes) + + location = self.regression_headers[i](x) + location = location.permute(0, 2, 3, 1).contiguous() + location = location.view(location.size(0), -1, 4) + + return confidence, location + + def init_from_base_net(self, model): + self.base_net.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage), strict=True) + self.source_layer_add_ons.apply(_xavier_init_) + self.extras.apply(_xavier_init_) + self.classification_headers.apply(_xavier_init_) + self.regression_headers.apply(_xavier_init_) + + def init_from_pretrained_ssd(self, model): + state_dict = torch.load(model, map_location=lambda storage, loc: storage) + state_dict = {k: v for k, v 
in state_dict.items() if not (k.startswith("classification_headers") or k.startswith("regression_headers"))} + model_dict = self.state_dict() + model_dict.update(state_dict) + self.load_state_dict(model_dict) + self.classification_headers.apply(_xavier_init_) + self.regression_headers.apply(_xavier_init_) + + def init(self): + self.base_net.apply(_xavier_init_) + self.source_layer_add_ons.apply(_xavier_init_) + self.extras.apply(_xavier_init_) + self.classification_headers.apply(_xavier_init_) + self.regression_headers.apply(_xavier_init_) + + def load(self, model): + self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage)) + + def save(self, model_path): + torch.save(self.state_dict(), model_path) + + +class MatchPrior(object): + def __init__(self, center_form_priors, center_variance, size_variance, iou_threshold): + self.center_form_priors = center_form_priors + self.corner_form_priors = box_utils.center_form_to_corner_form(center_form_priors) + self.center_variance = center_variance + self.size_variance = size_variance + self.iou_threshold = iou_threshold + + def __call__(self, gt_boxes, gt_labels): + if type(gt_boxes) is np.ndarray: + gt_boxes = torch.from_numpy(gt_boxes) + if type(gt_labels) is np.ndarray: + gt_labels = torch.from_numpy(gt_labels) + boxes, labels = box_utils.assign_priors(gt_boxes, gt_labels, + self.corner_form_priors, self.iou_threshold) + boxes = box_utils.corner_form_to_center_form(boxes) + locations = box_utils.convert_boxes_to_locations(boxes, self.center_form_priors, self.center_variance, self.size_variance) + return locations, labels + + +def _xavier_init_(m: nn.Module): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight) diff --git a/Samples/DetectionRetrainingAndInfer/vision/transforms/__init__.py b/Samples/DetectionRetrainingAndInfer/vision/transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Samples/DetectionRetrainingAndInfer/vision/transforms/transforms.py b/Samples/DetectionRetrainingAndInfer/vision/transforms/transforms.py new file mode 100644 index 0000000..753a628 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/transforms/transforms.py @@ -0,0 +1,409 @@ +# from https://github.com/amdegroot/ssd.pytorch + + +import torch +from torchvision import transforms +import cv2 +import numpy as np +import types +from numpy import random + + +def intersect(box_a, box_b): + max_xy = np.minimum(box_a[:, 2:], box_b[2:]) + min_xy = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. + E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: Multiple bounding boxes, Shape: [num_boxes,4] + box_b: Single bounding box, Shape: [4] + Return: + jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2]-box_a[:, 0]) * + (box_a[:, 3]-box_a[:, 1])) # [A,B] + area_b = ((box_b[2]-box_b[0]) * + (box_b[3]-box_b[1])) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +class Compose(object): + """Composes several augmentations together. + Args: + transforms (List[Transform]): list of transforms to compose. 
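+            (note: unlike torchvision.transforms.Compose, each transform here takes
+            and returns the triple (img, boxes, labels))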
+ Example: + >>> augmentations.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, boxes=None, labels=None): + for t in self.transforms: + img, boxes, labels = t(img, boxes, labels) + return img, boxes, labels + + +class Lambda(object): + """Applies a lambda as a transform.""" + + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, img, boxes=None, labels=None): + return self.lambd(img, boxes, labels) + + +class ConvertFromInts(object): + def __call__(self, image, boxes=None, labels=None): + return image.astype(np.float32), boxes, labels + + +class SubtractMeans(object): + def __init__(self, mean): + self.mean = np.array(mean, dtype=np.float32) + + def __call__(self, image, boxes=None, labels=None): + image = image.astype(np.float32) + image -= self.mean + return image.astype(np.float32), boxes, labels + + +class ToAbsoluteCoords(object): + def __call__(self, image, boxes=None, labels=None): + height, width, channels = image.shape + boxes[:, 0] *= width + boxes[:, 2] *= width + boxes[:, 1] *= height + boxes[:, 3] *= height + + return image, boxes, labels + + +class ToPercentCoords(object): + def __call__(self, image, boxes=None, labels=None): + height, width, channels = image.shape + boxes[:, 0] /= width + boxes[:, 2] /= width + boxes[:, 1] /= height + boxes[:, 3] /= height + + return image, boxes, labels + + +class Resize(object): + def __init__(self, size=300): + self.size = size + + def __call__(self, image, boxes=None, labels=None): + image = cv2.resize(image, (self.size, + self.size)) + return image, boxes, labels + + +class RandomSaturation(object): + def __init__(self, lower=0.5, upper=1.5): + self.lower = lower + self.upper = upper + assert self.upper >= self.lower, "contrast upper must be >= lower." + assert self.lower >= 0, "contrast lower must be non-negative." 
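+        # Note: used on float HSV images (see PhotometricDistort below), where
+        # channel 1 is saturation; the "contrast" wording in the asserts appears
+        # to be carried over from RandomContrast, which uses the same bounds.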
+ + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + image[:, :, 1] *= random.uniform(self.lower, self.upper) + + return image, boxes, labels + + +class RandomHue(object): + def __init__(self, delta=18.0): + assert delta >= 0.0 and delta <= 360.0 + self.delta = delta + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + image[:, :, 0] += random.uniform(-self.delta, self.delta) + image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 + image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 + return image, boxes, labels + + +class RandomLightingNoise(object): + def __init__(self): + self.perms = ((0, 1, 2), (0, 2, 1), + (1, 0, 2), (1, 2, 0), + (2, 0, 1), (2, 1, 0)) + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + swap = self.perms[random.randint(len(self.perms))] + shuffle = SwapChannels(swap) # shuffle channels + image = shuffle(image) + return image, boxes, labels + + +class ConvertColor(object): + def __init__(self, current, transform): + self.transform = transform + self.current = current + + def __call__(self, image, boxes=None, labels=None): + if self.current == 'BGR' and self.transform == 'HSV': + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + elif self.current == 'RGB' and self.transform == 'HSV': + image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + elif self.current == 'BGR' and self.transform == 'RGB': + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + elif self.current == 'HSV' and self.transform == 'BGR': + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + elif self.current == 'HSV' and self.transform == "RGB": + image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB) + else: + raise NotImplementedError + return image, boxes, labels + + +class RandomContrast(object): + def __init__(self, lower=0.5, upper=1.5): + self.lower = lower + self.upper = upper + assert self.upper >= self.lower, "contrast upper must be >= lower." + assert self.lower >= 0, "contrast lower must be non-negative." + + # expects float image + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + alpha = random.uniform(self.lower, self.upper) + image *= alpha + return image, boxes, labels + + +class RandomBrightness(object): + def __init__(self, delta=32): + assert delta >= 0.0 + assert delta <= 255.0 + self.delta = delta + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + delta = random.uniform(-self.delta, self.delta) + image += delta + return image, boxes, labels + + +class ToCV2Image(object): + def __call__(self, tensor, boxes=None, labels=None): + return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels + + +class ToTensor(object): + def __call__(self, cvimage, boxes=None, labels=None): + return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels + + +class RandomSampleCrop(object): + """Crop + Arguments: + img (Image): the image being input during training + boxes (Tensor): the original bounding boxes in pt form + labels (Tensor): the class labels for each bbox + mode (float tuple): the min and max jaccard overlaps + Return: + (img, boxes, classes) + img (Image): the cropped image + boxes (Tensor): the adjusted bounding boxes in pt form + labels (Tensor): the class labels for each bbox + """ + def __init__(self): + self.sample_options = ( + # using entire original input image + None, + # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9 + (0.1, None), + (0.3, None), + (0.7, None), + (0.9, None), + # randomly sample a patch + (None, None), + ) + + def __call__(self, image, boxes=None, labels=None): + height, width, _ = image.shape + while True: + # randomly choose a mode + #mode = random.choice(self.sample_options) # throws numpy deprecation warning + mode = self.sample_options[random.randint(len(self.sample_options))] + + if mode is None: + return image, boxes, labels + + min_iou, max_iou = mode + if min_iou is None: + min_iou = float('-inf') + if max_iou is None: + max_iou = float('inf') + + # max trails (50) + for _ in range(50): + current_image = image + + w = random.uniform(0.3 * width, width) + h = random.uniform(0.3 * height, height) + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = random.uniform(width - w) + top = random.uniform(height - h) + + # convert to integer rect x1,y1,x2,y2 + rect = np.array([int(left), int(top), int(left+w), int(top+h)]) + + # calculate IoU (jaccard overlap) b/t the cropped and gt boxes + overlap = jaccard_numpy(boxes, rect) + + # is min and max overlap constraint satisfied? if not try again + if overlap.min() < min_iou and max_iou < overlap.max(): + continue + + # cut the crop from the image + current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], + :] + + # keep overlap with gt box IF center in sampled patch + centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 + + # mask in all gt boxes that above and to the left of centers + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + + # mask in all gt boxes that under and to the right of centers + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # take only matching gt boxes + current_boxes = boxes[mask, :].copy() + + # take only matching gt labels + current_labels = labels[mask] + + # should we use the box left and top corner or the crop's + current_boxes[:, :2] = np.maximum(current_boxes[:, :2], + rect[:2]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, :2] -= rect[:2] + + current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], + rect[2:]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, 2:] -= rect[:2] + + return current_image, current_boxes, current_labels + + +class Expand(object): + def __init__(self, mean): + self.mean = mean + + def __call__(self, image, boxes, labels): + if random.randint(2): + return image, boxes, labels + + height, width, depth = image.shape + ratio = random.uniform(1, 4) + left = random.uniform(0, width*ratio - width) + top = random.uniform(0, height*ratio - height) + + expand_image = np.zeros( + (int(height*ratio), int(width*ratio), depth), + dtype=image.dtype) + expand_image[:, :, :] = self.mean + expand_image[int(top):int(top + height), + int(left):int(left + width)] = image + image = expand_image + + boxes = boxes.copy() + boxes[:, :2] += (int(left), int(top)) + boxes[:, 2:] += (int(left), int(top)) + + return image, boxes, labels + + +class RandomMirror(object): + def __call__(self, image, boxes, classes): + _, width, _ = image.shape + if random.randint(2): + image = image[:, ::-1] + boxes = boxes.copy() + boxes[:, 0::2] = width - boxes[:, 2::-2] + return image, boxes, classes + + +class SwapChannels(object): + """Transforms a tensorized image by swapping the channels in the order + specified in the swap tuple. 
+ Args: + swaps (int triple): final order of channels + eg: (2, 1, 0) + """ + + def __init__(self, swaps): + self.swaps = swaps + + def __call__(self, image): + """ + Args: + image (Tensor): image tensor to be transformed + Return: + a tensor with channels swapped according to swap + """ + # if torch.is_tensor(image): + # image = image.data.cpu().numpy() + # else: + # image = np.array(image) + image = image[:, :, self.swaps] + return image + + +class PhotometricDistort(object): + def __init__(self): + self.pd = [ + RandomContrast(), # RGB + ConvertColor(current="RGB", transform='HSV'), # HSV + RandomSaturation(), # HSV + RandomHue(), # HSV + ConvertColor(current='HSV', transform='RGB'), # RGB + RandomContrast() # RGB + ] + self.rand_brightness = RandomBrightness() + self.rand_light_noise = RandomLightingNoise() + + def __call__(self, image, boxes, labels): + im = image.copy() + im, boxes, labels = self.rand_brightness(im, boxes, labels) + if random.randint(2): + distort = Compose(self.pd[:-1]) + else: + distort = Compose(self.pd[1:]) + im, boxes, labels = distort(im, boxes, labels) + return self.rand_light_noise(im, boxes, labels) + diff --git a/Samples/DetectionRetrainingAndInfer/vision/utils/__init__.py b/Samples/DetectionRetrainingAndInfer/vision/utils/__init__.py new file mode 100644 index 0000000..0789bdb --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/utils/__init__.py @@ -0,0 +1 @@ +from .misc import * diff --git a/Samples/DetectionRetrainingAndInfer/vision/utils/box_utils.py b/Samples/DetectionRetrainingAndInfer/vision/utils/box_utils.py new file mode 100644 index 0000000..42f2469 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/utils/box_utils.py @@ -0,0 +1,295 @@ +import collections +import torch +import itertools +from typing import List +import math + +SSDBoxSizes = collections.namedtuple('SSDBoxSizes', ['min', 'max']) + +SSDSpec = collections.namedtuple('SSDSpec', ['feature_map_size', 'shrinkage', 'box_sizes', 'aspect_ratios']) + + +def generate_ssd_priors(specs: List[SSDSpec], image_size, clamp=True) -> torch.Tensor: + """Generate SSD Prior Boxes. + + It returns the center, height and width of the priors. The values are relative to the image size + Args: + specs: SSDSpecs about the shapes of sizes of prior boxes. i.e. + specs = [ + SSDSpec(38, 8, SSDBoxSizes(30, 60), [2]), + SSDSpec(19, 16, SSDBoxSizes(60, 111), [2, 3]), + SSDSpec(10, 32, SSDBoxSizes(111, 162), [2, 3]), + SSDSpec(5, 64, SSDBoxSizes(162, 213), [2, 3]), + SSDSpec(3, 100, SSDBoxSizes(213, 264), [2]), + SSDSpec(1, 300, SSDBoxSizes(264, 315), [2]) + ] + image_size: image size. + clamp: if true, clamp the values to make fall between [0.0, 1.0] + Returns: + priors (num_priors, 4): The prior boxes represented as [[center_x, center_y, w, h]]. All the values + are relative to the image size. 
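+        Note: each feature-map location contributes 2 + 2 * len(aspect_ratios) priors,
+            so the example specs above yield 4*38^2 + 6*19^2 + 6*10^2 + 6*5^2 + 4*3^2
+            + 4*1^2 = 8732 priors, the usual SSD300 total.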
+ """ + priors = [] + for spec in specs: + scale = image_size / spec.shrinkage + for j, i in itertools.product(range(spec.feature_map_size), repeat=2): + x_center = (i + 0.5) / scale + y_center = (j + 0.5) / scale + + # small sized square box + size = spec.box_sizes.min + h = w = size / image_size + priors.append([ + x_center, + y_center, + w, + h + ]) + + # big sized square box + size = math.sqrt(spec.box_sizes.max * spec.box_sizes.min) + h = w = size / image_size + priors.append([ + x_center, + y_center, + w, + h + ]) + + # change h/w ratio of the small sized box + size = spec.box_sizes.min + h = w = size / image_size + for ratio in spec.aspect_ratios: + ratio = math.sqrt(ratio) + priors.append([ + x_center, + y_center, + w * ratio, + h / ratio + ]) + priors.append([ + x_center, + y_center, + w / ratio, + h * ratio + ]) + + priors = torch.tensor(priors) + if clamp: + torch.clamp(priors, 0.0, 1.0, out=priors) + return priors + + +def convert_locations_to_boxes(locations, priors, center_variance, + size_variance): + """Convert regressional location results of SSD into boxes in the form of (center_x, center_y, h, w). + + The conversion: + $$predicted\_center * center_variance = \frac {real\_center - prior\_center} {prior\_hw}$$ + $$exp(predicted\_hw * size_variance) = \frac {real\_hw} {prior\_hw}$$ + We do it in the inverse direction here. + Args: + locations (batch_size, num_priors, 4): the regression output of SSD. It will contain the outputs as well. + priors (num_priors, 4) or (batch_size/1, num_priors, 4): prior boxes. + center_variance: a float used to change the scale of center. + size_variance: a float used to change of scale of size. + Returns: + boxes: priors: [[center_x, center_y, h, w]]. All the values + are relative to the image size. + """ + # priors can have one dimension less. + if priors.dim() + 1 == locations.dim(): + priors = priors.unsqueeze(0) + return torch.cat([ + locations[..., :2] * center_variance * priors[..., 2:] + priors[..., :2], + torch.exp(locations[..., 2:] * size_variance) * priors[..., 2:] + ], dim=locations.dim() - 1) + + +def convert_boxes_to_locations(center_form_boxes, center_form_priors, center_variance, size_variance): + # priors can have one dimension less + if center_form_priors.dim() + 1 == center_form_boxes.dim(): + center_form_priors = center_form_priors.unsqueeze(0) + return torch.cat([ + (center_form_boxes[..., :2] - center_form_priors[..., :2]) / center_form_priors[..., 2:] / center_variance, + torch.log(center_form_boxes[..., 2:] / center_form_priors[..., 2:]) / size_variance + ], dim=center_form_boxes.dim() - 1) + + +def area_of(left_top, right_bottom) -> torch.Tensor: + """Compute the areas of rectangles given two corners. + + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + + Returns: + area (N): return the area. + """ + hw = torch.clamp(right_bottom - left_top, min=0.0) + return hw[..., 0] * hw[..., 1] + + +def iou_of(boxes0, boxes1, eps=1e-5): + """Return intersection-over-union (Jaccard index) of boxes. + + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + eps: a small number to avoid 0 as denominator. + Returns: + iou (N): IoU values. 
+ """ + overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def assign_priors(gt_boxes, gt_labels, corner_form_priors, + iou_threshold): + """Assign ground truth boxes and targets to priors. + + Args: + gt_boxes (num_targets, 4): ground truth boxes. + gt_labels (num_targets): labels of targets. + priors (num_priors, 4): corner form priors + Returns: + boxes (num_priors, 4): real values for priors. + labels (num_priros): labels for priors. + """ + # size: num_priors x num_targets + ious = iou_of(gt_boxes.unsqueeze(0), corner_form_priors.unsqueeze(1)) + # size: num_priors + best_target_per_prior, best_target_per_prior_index = ious.max(1) + # size: num_targets + best_prior_per_target, best_prior_per_target_index = ious.max(0) + + for target_index, prior_index in enumerate(best_prior_per_target_index): + best_target_per_prior_index[prior_index] = target_index + # 2.0 is used to make sure every target has a prior assigned + best_target_per_prior.index_fill_(0, best_prior_per_target_index, 2) + # size: num_priors + labels = gt_labels[best_target_per_prior_index] + labels[best_target_per_prior < iou_threshold] = 0 # the backgournd id + boxes = gt_boxes[best_target_per_prior_index] + return boxes, labels + + +def hard_negative_mining(loss, labels, neg_pos_ratio): + """ + It used to suppress the presence of a large number of negative prediction. + It works on image level not batch level. + For any example/image, it keeps all the positive predictions and + cut the number of negative predictions to make sure the ratio + between the negative examples and positive examples is no more + the given ratio for an image. + + Args: + loss (N, num_priors): the loss for each example. + labels (N, num_priors): the labels. + neg_pos_ratio: the ratio between the negative examples and positive examples. + """ + pos_mask = labels > 0 + num_pos = pos_mask.long().sum(dim=1, keepdim=True) + num_neg = num_pos * neg_pos_ratio + + loss[pos_mask] = -math.inf + _, indexes = loss.sort(dim=1, descending=True) + _, orders = indexes.sort(dim=1) + neg_mask = orders < num_neg + return pos_mask | neg_mask + + +def center_form_to_corner_form(locations): + return torch.cat([locations[..., :2] - locations[..., 2:]/2, + locations[..., :2] + locations[..., 2:]/2], locations.dim() - 1) + + +def corner_form_to_center_form(boxes): + return torch.cat([ + (boxes[..., :2] + boxes[..., 2:]) / 2, + boxes[..., 2:] - boxes[..., :2] + ], boxes.dim() - 1) + + +def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): + """ + + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + iou_threshold: intersection over union threshold. + top_k: keep top_k results. If k <= 0, keep all the results. + candidate_size: only consider the candidates with the highest scores. 
+ Returns: + picked: a list of indexes of the kept boxes + """ + scores = box_scores[:, -1] + boxes = box_scores[:, :-1] + picked = [] + _, indexes = scores.sort(descending=True) + indexes = indexes[:candidate_size] + while len(indexes) > 0: + current = indexes[0] + picked.append(current.item()) + if 0 < top_k == len(picked) or len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[1:] + rest_boxes = boxes[indexes, :] + iou = iou_of( + rest_boxes, + current_box.unsqueeze(0), + ) + indexes = indexes[iou <= iou_threshold] + + return box_scores[picked, :] + + +def nms(box_scores, nms_method=None, score_threshold=None, iou_threshold=None, + sigma=0.5, top_k=-1, candidate_size=200): + if nms_method == "soft": + return soft_nms(box_scores, score_threshold, sigma, top_k) + else: + return hard_nms(box_scores, iou_threshold, top_k, candidate_size=candidate_size) + + +def soft_nms(box_scores, score_threshold, sigma=0.5, top_k=-1): + """Soft NMS implementation. + + References: + https://arxiv.org/abs/1704.04503 + https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/cython_nms.pyx + + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + score_threshold: boxes with scores less than value are not considered. + sigma: the parameter in score re-computation. + scores[i] = scores[i] * exp(-(iou_i)^2 / simga) + top_k: keep top_k results. If k <= 0, keep all the results. + Returns: + picked_box_scores (K, 5): results of NMS. + """ + picked_box_scores = [] + while box_scores.size(0) > 0: + max_score_index = torch.argmax(box_scores[:, 4]) + cur_box_prob = torch.tensor(box_scores[max_score_index, :]) + picked_box_scores.append(cur_box_prob) + if len(picked_box_scores) == top_k > 0 or box_scores.size(0) == 1: + break + cur_box = cur_box_prob[:-1] + box_scores[max_score_index, :] = box_scores[-1, :] + box_scores = box_scores[:-1, :] + ious = iou_of(cur_box.unsqueeze(0), box_scores[:, :-1]) + box_scores[:, -1] = box_scores[:, -1] * torch.exp(-(ious * ious) / sigma) + box_scores = box_scores[box_scores[:, -1] > score_threshold, :] + if len(picked_box_scores) > 0: + return torch.stack(picked_box_scores) + else: + return torch.tensor([]) + + + diff --git a/Samples/DetectionRetrainingAndInfer/vision/utils/box_utils_numpy.py b/Samples/DetectionRetrainingAndInfer/vision/utils/box_utils_numpy.py new file mode 100644 index 0000000..177456f --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/utils/box_utils_numpy.py @@ -0,0 +1,238 @@ +from .box_utils import SSDSpec + +from typing import List +import itertools +import math +import numpy as np + + +def generate_ssd_priors(specs: List[SSDSpec], image_size, clamp=True): + """Generate SSD Prior Boxes. + + It returns the center, height and width of the priors. The values are relative to the image size + Args: + specs: SSDSpecs about the shapes of sizes of prior boxes. i.e. + specs = [ + SSDSpec(38, 8, SSDBoxSizes(30, 60), [2]), + SSDSpec(19, 16, SSDBoxSizes(60, 111), [2, 3]), + SSDSpec(10, 32, SSDBoxSizes(111, 162), [2, 3]), + SSDSpec(5, 64, SSDBoxSizes(162, 213), [2, 3]), + SSDSpec(3, 100, SSDBoxSizes(213, 264), [2]), + SSDSpec(1, 300, SSDBoxSizes(264, 315), [2]) + ] + image_size: image size. + clamp: if true, clamp the values to make fall between [0.0, 1.0] + Returns: + priors (num_priors, 4): The prior boxes represented as [[center_x, center_y, w, h]]. All the values + are relative to the image size. 
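+        Note: NumPy counterpart of box_utils.generate_ssd_priors; the returned array
+            uses the same [[center_x, center_y, w, h]] layout as the torch version.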
+ """ + priors = [] + for spec in specs: + scale = image_size / spec.shrinkage + for j, i in itertools.product(range(spec.feature_map_size), repeat=2): + x_center = (i + 0.5) / scale + y_center = (j + 0.5) / scale + + # small sized square box + size = spec.box_sizes.min + h = w = size / image_size + priors.append([ + x_center, + y_center, + w, + h + ]) + + # big sized square box + size = math.sqrt(spec.box_sizes.max * spec.box_sizes.min) + h = w = size / image_size + priors.append([ + x_center, + y_center, + w, + h + ]) + + # change h/w ratio of the small sized box + size = spec.box_sizes.min + h = w = size / image_size + for ratio in spec.aspect_ratios: + ratio = math.sqrt(ratio) + priors.append([ + x_center, + y_center, + w * ratio, + h / ratio + ]) + priors.append([ + x_center, + y_center, + w / ratio, + h * ratio + ]) + + priors = np.array(priors, dtype=np.float32) + if clamp: + np.clip(priors, 0.0, 1.0, out=priors) + return priors + + +def convert_locations_to_boxes(locations, priors, center_variance, + size_variance): + """Convert regressional location results of SSD into boxes in the form of (center_x, center_y, h, w). + + The conversion: + $$predicted\_center * center_variance = \frac {real\_center - prior\_center} {prior\_hw}$$ + $$exp(predicted\_hw * size_variance) = \frac {real\_hw} {prior\_hw}$$ + We do it in the inverse direction here. + Args: + locations (batch_size, num_priors, 4): the regression output of SSD. It will contain the outputs as well. + priors (num_priors, 4) or (batch_size/1, num_priors, 4): prior boxes. + center_variance: a float used to change the scale of center. + size_variance: a float used to change of scale of size. + Returns: + boxes: priors: [[center_x, center_y, h, w]]. All the values + are relative to the image size. + """ + # priors can have one dimension less. + if len(priors.shape) + 1 == len(locations.shape): + priors = np.expand_dims(priors, 0) + return np.concatenate([ + locations[..., :2] * center_variance * priors[..., 2:] + priors[..., :2], + np.exp(locations[..., 2:] * size_variance) * priors[..., 2:] + ], axis=len(locations.shape) - 1) + + +def convert_boxes_to_locations(center_form_boxes, center_form_priors, center_variance, size_variance): + # priors can have one dimension less + if len(center_form_priors.shape) + 1 == len(center_form_boxes.shape): + center_form_priors = np.expand_dims(center_form_priors, 0) + return np.concatenate([ + (center_form_boxes[..., :2] - center_form_priors[..., :2]) / center_form_priors[..., 2:] / center_variance, + np.log(center_form_boxes[..., 2:] / center_form_priors[..., 2:]) / size_variance + ], axis=len(center_form_boxes.shape) - 1) + + +def area_of(left_top, right_bottom): + """Compute the areas of rectangles given two corners. + + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + + Returns: + area (N): return the area. + """ + hw = np.clip(right_bottom - left_top, 0.0, None) + return hw[..., 0] * hw[..., 1] + + +def iou_of(boxes0, boxes1, eps=1e-5): + """Return intersection-over-union (Jaccard index) of boxes. + + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + eps: a small number to avoid 0 as denominator. + Returns: + iou (N): IoU values. 
+ """ + overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def center_form_to_corner_form(locations): + return np.concatenate([locations[..., :2] - locations[..., 2:]/2, + locations[..., :2] + locations[..., 2:]/2], len(locations.shape) - 1) + + +def corner_form_to_center_form(boxes): + return np.concatenate([ + (boxes[..., :2] + boxes[..., 2:]) / 2, + boxes[..., 2:] - boxes[..., :2] + ], len(boxes.shape) - 1) + + +def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): + """ + + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + iou_threshold: intersection over union threshold. + top_k: keep top_k results. If k <= 0, keep all the results. + candidate_size: only consider the candidates with the highest scores. + Returns: + picked: a list of indexes of the kept boxes + """ + scores = box_scores[:, -1] + boxes = box_scores[:, :-1] + picked = [] + #_, indexes = scores.sort(descending=True) + indexes = np.argsort(scores) + #indexes = indexes[:candidate_size] + indexes = indexes[-candidate_size:] + while len(indexes) > 0: + #current = indexes[0] + current = indexes[-1] + picked.append(current) + if 0 < top_k == len(picked) or len(indexes) == 1: + break + current_box = boxes[current, :] + #indexes = indexes[1:] + indexes = indexes[:-1] + rest_boxes = boxes[indexes, :] + iou = iou_of( + rest_boxes, + np.expand_dims(current_box, axis=0), + ) + indexes = indexes[iou <= iou_threshold] + + return box_scores[picked, :] + + +# def nms(box_scores, nms_method=None, score_threshold=None, iou_threshold=None, +# sigma=0.5, top_k=-1, candidate_size=200): +# if nms_method == "soft": +# return soft_nms(box_scores, score_threshold, sigma, top_k) +# else: +# return hard_nms(box_scores, iou_threshold, top_k, candidate_size=candidate_size) + +# +# def soft_nms(box_scores, score_threshold, sigma=0.5, top_k=-1): +# """Soft NMS implementation. +# +# References: +# https://arxiv.org/abs/1704.04503 +# https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/cython_nms.pyx +# +# Args: +# box_scores (N, 5): boxes in corner-form and probabilities. +# score_threshold: boxes with scores less than value are not considered. +# sigma: the parameter in score re-computation. +# scores[i] = scores[i] * exp(-(iou_i)^2 / simga) +# top_k: keep top_k results. If k <= 0, keep all the results. +# Returns: +# picked_box_scores (K, 5): results of NMS. 
+# """ +# picked_box_scores = [] +# while box_scores.size(0) > 0: +# max_score_index = torch.argmax(box_scores[:, 4]) +# cur_box_prob = torch.tensor(box_scores[max_score_index, :]) +# picked_box_scores.append(cur_box_prob) +# if len(picked_box_scores) == top_k > 0 or box_scores.size(0) == 1: +# break +# cur_box = cur_box_prob[:-1] +# box_scores[max_score_index, :] = box_scores[-1, :] +# box_scores = box_scores[:-1, :] +# ious = iou_of(cur_box.unsqueeze(0), box_scores[:, :-1]) +# box_scores[:, -1] = box_scores[:, -1] * torch.exp(-(ious * ious) / sigma) +# box_scores = box_scores[box_scores[:, -1] > score_threshold, :] +# if len(picked_box_scores) > 0: +# return torch.stack(picked_box_scores) +# else: +# return torch.tensor([]) diff --git a/Samples/DetectionRetrainingAndInfer/vision/utils/measurements.py b/Samples/DetectionRetrainingAndInfer/vision/utils/measurements.py new file mode 100644 index 0000000..5cc590c --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/utils/measurements.py @@ -0,0 +1,32 @@ +import numpy as np + + +def compute_average_precision(precision, recall): + """ + It computes average precision based on the definition of Pascal Competition. It computes the under curve area + of precision and recall. Recall follows the normal definition. Precision is a variant. + pascal_precision[i] = typical_precision[i:].max() + """ + # identical but faster version of new_precision[i] = old_precision[i:].max() + precision = np.concatenate([[0.0], precision, [0.0]]) + for i in range(len(precision) - 1, 0, -1): + precision[i - 1] = np.maximum(precision[i - 1], precision[i]) + + # find the index where the value changes + recall = np.concatenate([[0.0], recall, [1.0]]) + changing_points = np.where(recall[1:] != recall[:-1])[0] + + # compute under curve area + areas = (recall[changing_points + 1] - recall[changing_points]) * precision[changing_points + 1] + return areas.sum() + + +def compute_voc2007_average_precision(precision, recall): + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(recall >= t) == 0: + p = 0 + else: + p = np.max(precision[recall >= t]) + ap = ap + p / 11. 
+ return ap diff --git a/Samples/DetectionRetrainingAndInfer/vision/utils/misc.py b/Samples/DetectionRetrainingAndInfer/vision/utils/misc.py new file mode 100644 index 0000000..e795458 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/utils/misc.py @@ -0,0 +1,45 @@ +import time +import torch + + +def str2bool(s): + return s.lower() in ('true', '1') + + +class Timer: + def __init__(self): + self.clock = {} + + def start(self, key="default"): + self.clock[key] = time.time() + + def end(self, key="default"): + if key not in self.clock: + raise Exception(f"{key} is not in the clock.") + interval = time.time() - self.clock[key] + del self.clock[key] + return interval + + +def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path): + torch.save({ + 'epoch': epoch, + 'model': net_state_dict, + 'optimizer': optimizer_state_dict, + 'best_score': best_score + }, checkpoint_path) + torch.save(net_state_dict, model_path) + + +def load_checkpoint(checkpoint_path): + return torch.load(checkpoint_path) + + +def freeze_net_layers(net): + for param in net.parameters(): + param.requires_grad = False + + +def store_labels(path, labels): + with open(path, "w") as f: + f.write("\n".join(labels)) diff --git a/Samples/DetectionRetrainingAndInfer/vision/utils/model_book.py b/Samples/DetectionRetrainingAndInfer/vision/utils/model_book.py new file mode 100644 index 0000000..b1e9d17 --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/vision/utils/model_book.py @@ -0,0 +1,81 @@ +from collections import OrderedDict +import torch.nn as nn + + +class ModelBook: + """Maintain the mapping between modules and their paths. + + Example: + book = ModelBook(model_ft) + for p, m in book.conv2d_modules(): + print('path:', p, 'num of filters:', m.out_channels) + assert m is book.get_module(p) + """ + + def __init__(self, model): + self._model = model + self._modules = OrderedDict() + self._paths = OrderedDict() + path = [] + self._construct(self._model, path) + + def _construct(self, module, path): + if not module._modules: + return + for name, m in module._modules.items(): + cur_path = tuple(path + [name]) + self._paths[m] = cur_path + self._modules[cur_path] = m + self._construct(m, path + [name]) + + def conv2d_modules(self): + return self.modules(nn.Conv2d) + + def linear_modules(self): + return self.modules(nn.Linear) + + def modules(self, module_type=None): + for p, m in self._modules.items(): + if not module_type or isinstance(m, module_type): + yield p, m + + def num_of_conv2d_modules(self): + return self.num_of_modules(nn.Conv2d) + + def num_of_conv2d_filters(self): + """Return the sum of out_channels of all conv2d layers. + + Here we treat the sub weight with size of [in_channels, h, w] as a single filter. 
+ """ + num_filters = 0 + for _, m in self.conv2d_modules(): + num_filters += m.out_channels + return num_filters + + def num_of_linear_modules(self): + return self.num_of_modules(nn.Linear) + + def num_of_linear_filters(self): + num_filters = 0 + for _, m in self.linear_modules(): + num_filters += m.out_features + return num_filters + + def num_of_modules(self, module_type=None): + num = 0 + for p, m in self._modules.items(): + if not module_type or isinstance(m, module_type): + num += 1 + return num + + def get_module(self, path): + return self._modules.get(path) + + def get_path(self, module): + return self._paths.get(module) + + def update(self, path, module): + old_module = self._modules[path] + del self._paths[old_module] + self._paths[module] = path + self._modules[path] = module -- Gitee From fe1f7315acccef93e06cbce45e051636e1b88b99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:01:50 +0000 Subject: [PATCH 02/18] add Samples/DetectionRetrainingAndInfer/README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 Samples/DetectionRetrainingAndInfer/README.md diff --git a/Samples/DetectionRetrainingAndInfer/README.md b/Samples/DetectionRetrainingAndInfer/README.md new file mode 100644 index 0000000..e69de29 -- Gitee From dd5186ff61c3bb3bfd3ed8c66f60b735b48e2172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:02:15 +0000 Subject: [PATCH 03/18] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20models?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Samples/DetectionRetrainingAndInfer/models/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 Samples/DetectionRetrainingAndInfer/models/.keep diff --git a/Samples/DetectionRetrainingAndInfer/models/.keep b/Samples/DetectionRetrainingAndInfer/models/.keep new file mode 100644 index 0000000..e69de29 -- Gitee From b4928cd9e33687ce44ea7f0a2f3edd71cd29acab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:03:16 +0000 Subject: [PATCH 04/18] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20data?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Samples/DetectionRetrainingAndInfer/omInfer/data/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/data/.keep diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/data/.keep b/Samples/DetectionRetrainingAndInfer/omInfer/data/.keep new file mode 100644 index 0000000..e69de29 -- Gitee From 439d57d7dedc3325d37753f3f8b1ac61edbc72ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:03:31 +0000 Subject: [PATCH 05/18] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Samples/DetectionRetrainingAndInfer/omInfer/model/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/model/.keep diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/model/.keep b/Samples/DetectionRetrainingAndInfer/omInfer/model/.keep new file mode 100644 index 0000000..e69de29 
-- Gitee From ffb553b04732597a521aaee30c0fdf93ce587a9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:03:53 +0000 Subject: [PATCH 06/18] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20output?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Samples/DetectionRetrainingAndInfer/omInfer/output/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/output/.keep diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/output/.keep b/Samples/DetectionRetrainingAndInfer/omInfer/output/.keep new file mode 100644 index 0000000..e69de29 -- Gitee From 0c35b85517287644ec47cc5913f0fec837209b90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:34:47 +0000 Subject: [PATCH 07/18] update Samples/DetectionRetrainingAndInfer/README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/README.md | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/Samples/DetectionRetrainingAndInfer/README.md b/Samples/DetectionRetrainingAndInfer/README.md index e69de29..04e4e8f 100644 --- a/Samples/DetectionRetrainingAndInfer/README.md +++ b/Samples/DetectionRetrainingAndInfer/README.md @@ -0,0 +1,165 @@
+# 口罩识别目标检测训练和推理
+
+#### 样例介绍
+
+本样例基于预训练ssd-mobilenet模型使用口罩识别数据集实现了检测口罩佩戴识别的功能
+
+### 标题,包含训练到om推理全过程。
+
+#### 样例下载
+
+可以使用以下两种方式下载,请选择其中一种进行源码准备。
+
+- 命令行方式下载(**下载时间较长,但步骤简单**)。
+
+    ```
+    # 登录开发板,HwHiAiUser用户命令行中执行以下命令下载源码仓。
+    cd ${HOME}
+    git clone https://gitee.com/ascend/EdgeAndRobotics.git
+    # 切换到样例目录
+    cd EdgeAndRobotics/Samples/DetectionRetrainingAndInfer
+    ```
+
+- 压缩包方式下载(**下载时间较短,但步骤稍微复杂**)。
+
+    ```
+    # 1. 仓右上角选择 【克隆/下载】 下拉框并选择 【下载ZIP】。
+    # 2. 将ZIP包上传到开发板的普通用户家目录中,【例如:${HOME}/EdgeAndRobotics-master.zip】。
+    # 3. 开发环境中,执行以下命令,解压zip包。
+    cd ${HOME}
+    chmod +x EdgeAndRobotics-master.zip
+    unzip EdgeAndRobotics-master.zip
+    # 4. 切换到样例目录
+    cd EdgeAndRobotics-master/Samples/DetectionRetrainingAndInfer
+    ```
+
+#### 执行准备
+
+- 本样例中的模型支持PyTorch2.1.0、torchvision0.16.0版本,请参考[安装PyTorch](https://www.hiascend.com/document/detail/zh/canncommercial/700/envdeployment/instg/instg_0046.html)章节安装PyTorch以及torch_npu插件。
+    ```
+    # torch_npu由于需要源码编译,速度可能较慢,本样例提供 python3.9,torch2.1版本的torch_npu whl包
+    wget https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/wanzutao/torch_npu-2.1.0rc1-cp39-cp39-linux_aarch64.whl
+
+    # 使用pip命令安装
+    pip3 install torch_npu-2.1.0rc1-cp39-cp39-linux_aarch64.whl
+    ```
+
+- 本样例中的模型还依赖一些其它库(具体依赖哪些库,可查看本样例目录下的requirements.txt文件),可执行以下命令安装:
+
+    ```
+    pip3 install -r requirements.txt # PyTorch2.1版本
+    ```
+
+- 配置离线推理所需的环境变量。
+
+    ```
+    # 配置程序编译依赖的头文件与库文件路径
+    export DDK_PATH=/usr/local/Ascend/ascend-toolkit/latest
+    export NPU_HOST_LIB=$DDK_PATH/runtime/lib64/stub
+    ```
+
+- 安装离线推理所需的ACLLite库。
+
+    参考[ACLLite仓](https://gitee.com/ascend/ACLLite)安装ACLLite库。
+
+
+#### 模型训练
+
+1. 以HwHiAiUser用户登录开发板,切换到样例目录下。
+2. 设置环境变量减小算子编译内存占用。
+    ```
+    export TE_PARALLEL_COMPILER=1
+    export MAX_COMPILE_CORE_NUMBER=1
+    ```
+3. 准备数据集
+    ```
+    cd dataset
+    wget https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/wanzutao/detection/mask.zip
+    unzip mask.zip
+    ```
+
+
+4. 数据集处理,分出训练集和测试集。
+    ```
+    cd ..
+    python3 predata.py
+    ```
+5. 
运行训练脚本。 + + ``` + python3 main.py + ``` + 训练完成后,权重文件保存在models目录下,并输出模型训练精度和性能信息。 + + 此处展示单Device、batch_size=8的训练结果数据: + | NAME | Loss | FPS | Epochs | AMP_Type | Torch_Version | + | :----: | :---: | :---: | :----: | :------: | :-----------: | + | 1p-NPU | 1.8480 | 2 | 10 | O2 | 2.1 | + + +#### 离线推理 + +1. 以HwHiAiUser用户登录开发板,切换到当前样例目录。 +2. 导出onnx模型 + ``` + python3 export.py + ``` + +3. 获取测试图片数据。 + + ``` + cd omInfer/data + wget https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/wanzutao/detection/mask.jpg + ``` + + **注:**若需更换测试图片,则需自行准备测试图片,并将测试图片放到omInfer/data目录下,并修改代码中图片名称。 + +4. 获取PyTorch框架的mobilenet-ssd模型(\*.onnx),并转换为昇腾AI处理器能识别的模型(\*.om)。 + - 当设备内存**小于8G**时,可设置如下两个环境变量减少atc模型转换过程中使用的进程数,减小内存占用。 + ``` + export TE_PARALLEL_COMPILER=1 + export MAX_COMPILE_CORE_NUMBER=1 + ``` + - 为了方便下载,在这里直接给出原始模型下载及模型转换命令,可以直接拷贝执行。 + ``` + # 将导出的mobilenet-ssd.onnx模型拷贝到model目录下 + cd ../model + cp ../../mobilenet-ssd.onnx ./ + + # 获取AIPP配置文件 + wget https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/wanzutao/detection/aipp.cfg + + # 模型转换 + atc --model=mobilenet-ssd.onnx --framework=5 --soc_version=Ascend310B4 --output=mobilenet-ssd --insert_op_conf=aipp.cfg + ``` + + atc命令中各参数的解释如下,详细约束说明请参见[《ATC模型转换指南》](https://hiascend.com/document/redirect/CannCommunityAtc)。 + + - --model:转换前模型文件的路径。 + - --framework:原始框架类型。5表示ONNX。 + - --output:转换后模型文件的路径。请注意,记录保存该om模型文件的路径,后续开发应用时需要使用。 + - --input\_shape:模型输入数据的shape。 + - --soc\_version:昇腾AI处理器的版本。 + + +5. 编译样例源码。 + + 执行以下命令编译样例源码。 + + ``` + cd ../scripts + bash sample_build.sh + ``` + +6. 运行样例。 + + 执行以下脚本运行样例: + + ``` + bash sample_run.sh + ``` + + 执行成功后,omInfer/output目录下会生成检测输出图片 + + +#### 相关操作 \ No newline at end of file -- Gitee From 972573320e233080a63b3be8e6c55f369f31efec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:35:49 +0000 Subject: [PATCH 08/18] update Samples/DetectionRetrainingAndInfer/README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Samples/DetectionRetrainingAndInfer/README.md b/Samples/DetectionRetrainingAndInfer/README.md index 04e4e8f..067df64 100644 --- a/Samples/DetectionRetrainingAndInfer/README.md +++ b/Samples/DetectionRetrainingAndInfer/README.md @@ -2,9 +2,7 @@ #### 样例介绍 -本样例基于预训练ssd-mobilenet模型使用口罩识别数据集实现了检测口罩佩戴识别的功能 - -### 标题,包含训练到om推理全过程。 +本样例基于预训练ssd-mobilenet模型使用口罩识别数据集实现了检测口罩佩戴识别的功能,包含训练到om推理全过程。 #### 样例下载 -- Gitee From b0ae6289ad5ed0a7adfedfd3d6644dcfa73de92d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:38:15 +0000 Subject: [PATCH 09/18] update Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp b/Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp index 6fda9c4..f18a11a 100644 --- a/Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp +++ b/Samples/DetectionRetrainingAndInfer/omInfer/src/main.cpp @@ -45,7 +45,7 @@ int main() ModelProc modelProc; ret = modelProc.Load("../model/mobilenet-ssd.om"); CHECK_RET(ret, LOG_PRINT("[ERROR] load model mobilenet-ssd.om failed."); return 1); - string imagePath = "../data/mask3.jpg"; + string imagePath = "../data/mask.jpg"; ImageData src = imageProc.Read(imagePath); CHECK_RET(src.size, LOG_PRINT("[ERROR] ImRead image failed."); return 1); -- Gitee From f9abe6d2c5083e5a5e2fa77c60eceb23c98daf11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:41:01 +0000 Subject: [PATCH 10/18] update Samples/DetectionRetrainingAndInfer/README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/README.md | 1 + .../omInfer/out_0.jpg | Bin 0 -> 43221 bytes 2 files changed, 1 insertion(+) create mode 100644 Samples/DetectionRetrainingAndInfer/omInfer/out_0.jpg diff --git a/Samples/DetectionRetrainingAndInfer/README.md b/Samples/DetectionRetrainingAndInfer/README.md index 067df64..c3a981c 100644 --- a/Samples/DetectionRetrainingAndInfer/README.md +++ b/Samples/DetectionRetrainingAndInfer/README.md @@ -158,6 +158,7 @@ ``` 执行成功后,omInfer/output目录下会生成检测输出图片 +![输入图片说明](omInfer/out_0.jpg) #### 相关操作 \ No newline at end of file diff --git a/Samples/DetectionRetrainingAndInfer/omInfer/out_0.jpg b/Samples/DetectionRetrainingAndInfer/omInfer/out_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03c559a549ea7a48b0dda1c756cc64f6536c345d GIT binary patch literal 43221 zcmbrlcT`hp_ck2If*rwv4KgYSO35QdM3JL1h=>^Jos5FgWk3-kkTYXLQON*81`wjs zq)QhFM7j_m(wh)ELMSOf`r*BqcdhUJ=UeN0-rpD3$pSe!x%a;J-q+sO-uG(X>IiDr z)ypQAQR~*BQ0u@y)aoG05cSKtwb$DFFKe$2>(|~lZrHGX!={a!HvM}^ZP~m@YKzpS zO`ErE-m-P=1%BGTZR_^6U)CP-zd!oR#`WtrZk5_3_3thJS8uCdP%@j>QP*{?Uw0Js zi_E(9GV50BQE0HvjsGqWEcSoD*8Q@6!^TZgU>V!M4JEt4^4G5ii{1!U4DJpCe@AVQ z*(iJD)WuD^&2LH__1~lUSJeB>$1WAr?6qiTDQevcc(i5fKDqq|4k{g2RylD}Tj#Xy z8NIWHzZqRNzG8CK(#qP#*3RC+)$R5jcMngmd-nq$1O?+CK7R7_S@`o85z#SkV&h2v z#3y`6Nli=7_?VehSX5k6TK2iTqPDKSp|R;}b4y2OS9ecuU;n`9*!aZc)b!7pSvH5e zw7kOO3xsQOtwXK<-(rFH|5v$WfLy<U}}YmSb8Lti87a+PCgg)E-r0uZi}rWdEIDkN!uJ{Xc^JAGrol+t;rHo3~yD z1)(J34-qP;|MB&T=%9JtP0sz=1;PO zVaFG^-?(@0o8}ig4z8kluX3gNC&R^Z^oEBz)#1`}e(a;$-ow`Oa23bgQ>Q*$Zr8AtEg*Cw3qj?;fkmY&HZ6z;cO^p?=4tn zPmASAB2({ui?>JVkI&k?A#-RJ1r^+hRn*r4eBih(n)S!@{1J6ZLTq>gc3(}7lj$l7 zqif#m=Z6+^IpT`J0pe`<(&XNuRO&3-T#`qJV^)f}1|{N(pQ|XA%PMN{Ru5Ea7aCCg z2-aOi{Q@!5e86fp&Ps|)<^y)zVvb$Xf3S+;wy2CpVudrtC34-;Gi~P*HId&ekFTNx zQzzIoIoAjzm>{!?s`(0+$kk;|#3AzQt&5|Xp``<{`AAoOrokq-rUO~u1-_w{&)4=n z#)Zdubv(qKw2JbXOc^M`WTEpbn|PTbX;HXxD|1%?^RuW4R^W|6F+ zJs)81c^sZ5FM7{7s^4i5wqtxWTt7d$!fQ~J^&38)efQ4IxtB{;um{dqoEbVgf5r9Z zhX|C)|KIE5Dk{WE=ls2E?Q_FCtFGT4?GInP9p@Ysu|7al>e^uJazu2IFpL8$c89jK zj-#Kuoava;9t@^{ZMa+Nj+8ech1goE3J?~{Kex|&$E7f-624E7pU-D|tcw%dc;AvIqGJnBElSM-Mas=4TuKjM)i{7Ajf&(*DaqGraC z3KLuuqeMz9=ZI|85ncAGc%_Tmt3gqH=VB_#-CYM)MnWP_q_Fxw$S_-}3!bqhPQjh_ 
zUO$3d6f4R&x4DX|s8Sc_W7v^Z)aG{nUu>x{x%w(gk!z#N>F&TonMXO|L8d@4@6K?Y zMcn8#^%*Q=S~dsf*cl)0=xcmBoie@Vv&IpxW6UkbH@~#2?s~R^efbrH0u~_Q3)c3S9(`bUbYtwDOkx0?l}xK zt)iSJB||2Dd(advHfj}Rb~MLxL%XA(*KryM1JkZX)2%oaX4BwC{Yl9WaL3?PRDT)L z*RhHU+d~hM5-%03qU={u1ZH}>*nGz_<&Pm3V|{;cAEZ&RRGF>k`l?~-A|x(cMa2=^ zCB>O@W9`c(iexM9yz(l_cFIUjTnw}v)vNbDJI5%~K@gPSlBO>MG=Uiw#4b8DB+Z;GrHP++pXT;F!9~ai%N^A6HgtA&6?0zuOajyvAtHjdlOi>jq4-{eYRV!^wmOrG)CR z0(q}XMD?a-1_LHaK;>_K=}Kc;jx{|o*u}pN@8Q!duDHIsrSfV}VGJvniWOji;(i*paZ5$aPmhK)Dz?*bx8F8$ot>L4oE+-5!#@w!lb30{SOwC}v1~iWT*cD!SnrnB zdh}qY3D`fnqo>Du*>S6`lvR}BOFC9`+r0nhKwODvk-Dan2#mnfS^MFRsqA}-11@RD zY49-KVhB&J%>6n~pk)BN?QdimBUw1XjMLc4weh7YRpU>eC5vQ5HhA{QS)b*_4X+j8 zc{6laR}|jgNl^0Qqbc6|4=(%sF$o6_cQt)k`% zO;&KrFsDTio=kF`-5r(%o#RErgAR{O1@EH`r~M+a0&ffB2|)lsBynGJzvRI0cg&+z z2<=EEEn~1jL|H{;!vdzfiNBfwE+)(VE5L#CM^iiJ3}_$%>CKL-a|wsTf3%pB2gU z(}gC)wa};Tn;P+_p5wX9V(JfP!wF+$aS4$Z)$AB2{MLa?A+6&?#NHIWcr5qO>qEvB zl9k7r7Y~z}zNs%IU;!o8j^|8@Jth*=b?G3Q^bErEvnU2@6&0b)80u%tnz+h0iAy-t zAZ5w@Tnzt`zAo$!_31#LhMGYjq>c|?opdV+x9VI)wNt{^zDRCiA~Jq-6}899MV_-X zWJwn(MuQlr2QkU3s6+3*4|Xn!ty7$gcZg4=0!wC;L(C+@ETeShDM?}1`Bd^!k3&1r z>Y3em_C~^Kl2eL3Wj1^2P28thg zk+=>&@VY?coK=*yQW#LmhI-)bF`-uA9gno{t3JnS7;H}Sl5LQBzVrgd%QsNhY^C%~ z&q(*yk4VLbc?;XONGrKmu8ukoKGOD-T{Ie5CO<{Yv)?ijr`smYZgP1_7WxToE&90q zN_!$kGd=$_H`;^xzqzdL{C-)>!cF1XyEr@7DF@V7HXN};-|Q)!>2>d@Leyn=GBuVD ztY_m_uI~@D;?0`NqlOoU`2%GFWWjrw;AObuEaqr@!S(aaT z8%OrBrU_=Lk9u#;5U9JYH$I0%h@B!*Hb?t`RtElmk zhv&9IMQ|E$FN%b*c$doW-c()Rh%_ds!b{GkJa~!zEi&JO4X0n6swKf7Ig6uGvi?qAVAlxj@w%_vW2-1f zH2_#mYcV|XbHy|DOJCS!qV`05z<3ArNtTei8;<`|wOKET`$6N^srPOpK)%MtdWEssFiHW!T_ z#saXK&I$a!u}H1mSh4p;Dhq(l-?O=mkGC_bI~E}(DG@2mwI3oZwt@9+fhKeA3)GIp z@Or3}$B2UfW~%>Oe|C7*PiyXBudp1N%!qYdkz+MU^(~HKk7AcC0G^+1`&`oX0cer| z#UiDcy`(MWfWucqob$1}%2#ZNqz+|G&XKsXE}!PrsgEJ=D=Ym{&QWky*{ zX+ZLTVqTkEe6N&LnihDNVEtAhlsuaUpuiT?1;C_|HmRPgq#r*+XbwXEn~A}atg>Sz z&w8M9rd`!0N?fJVIGhFug#3BeA2SaPijeq~dTbG}_^9t2&0H@QMU2nKeQI_M1I}>^ zL`9{GO5rPMiFIMy{XslbcLg#B*i84|c}e5u8-T}*_t@VCK$~Pvso>s~J!A@d0e=<+#ImMTU)% zkPUR;H1X-4Wt^rXS(ld}|BcLHqa`32N{*GN@NxBLjbGj0hTRhk-+0DpM2643bbN^q zh`u~ilRrAn>nS@ii5@n_wJVDc0)W_fDtw&&#XG-`Hs(^n2`>A>#tY>J(qK$&Y9_JF-$aM%J~1ssfEn5cwyas72v_%~9QOj~ys5ixeDPYqFg9vl_2m zi9z&~Zd)gOPn*xS6Mq6h^&_Eo0GdC>J}MnPyuK@=vVG~CD6uNUqMU?WO2{EI6^TfE z6R2_ci&hmy5yr(OfIqOJ@AB5dCtuS-I3P8%_8EKDqnlvQqrj9LPmA*ht9}_CaP=4E zgGlFP#qfGIi#@L5CYcsyWv}F)W{B7V9Me06g?^-OJ@&G`GgR*m#{{5z%An#NkBWvv zVVG79!q5N=f#=G>pJL1iB+DmCq)kC(B=lZ^!IDF~T#o(Lqi-&?08Nv`OZoYPLr5dC zmsRx^-TzR3jS^A&kbk_?_tJ}{GuI-kTu1!YBXy9hPT~d%_1SNa#(4xQawQj_zyn)y zfkNU4QMx(D8}drVF3T7x77R8iMw?86aF?RGWN9i%{xuF-%7@&@%v+JTp5x=W+c74a%|VtvL_aO`ufeLck$9uL^nN)H*ZWx`+A? z<4D_AbbI$vneg~Ht(?ds?HhKYQu1FCv>iOF)jJp#FX>knl*&2O@A}P(l0_1FaqSVb zGtB%gj)RATDdjg1nK9fh%xDiBrqQe{vRw?#2LU<@PIIJ7CsApV=~%~Al&rpVX0WMM zh@Q`oKFw|-0vc*w%Yvb0TxDL+jc?UPlQ~djG-lb~_Pcork(hf=`9{-#qNHsJc;r#b zW5v&K4E$gRcG)@w2@gca1d7UlcUr)mSfQtt@BE!lpE-BmOMZY74YBXN_*JGMM9v^c zNELX47TU>Hfj1QU=riHp_{()-4DALJWi9(@Ix&LMN!%5f*T zrdSH6i~4+&D#FBS zzc^17ex%ZceXFP?=b+PYe?{w?WShjp_nMG0GZJDCfE*IKUr@ki&@zMX0Z^en!Vb|u z+{y-sU}5W+=6%+;CmFF3j6`EJ0fd)>ni?(9j?DK?-ki{T!?g7v_iN-haz3;<^y>sj zWPoggs&8X1ifrREJ&N$Vfytm9e5jS2HOAupx(-JsqyGV0P^qxx=qk!6?hLgDW@m*) zx89sEn>JR2XOYf-!D!|!Ssd_bi7L}DWwJBV{V=@L2|}+;N=1UbdqOS8lqz8E!7){| zPv6x^)VmKnZYg+@YtV`jNV=_v_MJ+vM@iovh2Y`wX$0RRKT+{_8Md$iErf)b$0XAv z`wp+@fnj`4veP|~@qV3nw~BHmqO<(34qDC!bFK1=g8hqW`&! zHVM-q!Hl(^#`_A6I7M-H6TnhkUI5&CRvVHLm^Zm=(1o|vsGLRcQ9yqtZE}0)j3jbf zr2)dv3^buPX>Lgn8rJv@2|AE9o?_~Z!`g5Q#^<(e0V3b5w zfkQ|$TwGR9?3Y^klLaVG7Re&I*N!zLz9;ZE6;%SGnM^AtFST7A5%-1C&tl4z3>7ce zj$)*K0;-8vQLERN-1~=PJ^SJ!RC2dJicop? 
zlwSlmcms3gLIopgE~DkEkmD%V!e1bdLp>a+HfW_Tz7VsJY6up;mXWH73uC!Q3x)s zK;*B}Ogk@lTAkj}cVL0{izczd`P&EAy|0{OcjfB5Tm0>u*(*O1MD^R~*5s?g)E}vD|BD=4FN7jy$!Cm>+H<2+T(kO)I zwlm!*rZWmLbH~}NQ<6vLdXd-cg)~D!JL@MQE2sI;0YYZOLY<&@JCKqm#2E#GEu2-k z73s8iR>hh>t!;*{0`eswnxF{~eA!tPI|({BDnOBe4t3E8Ahb*dkmxh7L8b~`VeTMN zlg=sl&=YUk_8MRR8UFX?%5uw_@Uf8`ssU{mZ1r>9@El4MQq1Zn4tIGpZ9ym1iLGvl z7|y`w&Z@k({-;1d^T1c}v*s<8ww+lY#+Td`u8X3~oM}8sf6JIz zlsINki=z~v=|cyI&XCdKK2yBK)99PvPjknl+Ig&8VL&zPV9A1B&j!@bZ#6B92tyE- zDjp);FBk$l!o@O;R@cgj3tC!2{PD>bf#?$&(O2_E;@81CP5OK(kW|KDha3E}eiC|R z7q-^{>gh;4mImTW(*L&8DbgUM1N%yr-1LUfA`cu6L_{b!iVU{X`yyTCzkoqL!m_3} zv2=IjMT-LoJ)pH2XKp5StC@$+ctq?wy{M4g9WZE)`ZHP+-*{=4!Bodq!W+qm@r$KJ zm|NIFEC*^|JKpm)IOv~T-uZ&v;$G>1@)wxs-zVy@O6%h4jnsE5$0xsH%w#1HeyVb1 zV%rsO721kBQ$ec^OEoAZLc;Yk@q7ET6^i=Ye2xV}y8N-ph|D(cv}KH5=38VeeUfq+ zB)y1|8*!3R(|3yZw&{Z?NVb51tCB?4oe559ZPKmq0@oSMfg8Vjoap*??bN&2(F)}5 zSB>DluclmFya>%bRletn%`F_@5fRA6p4j0o^?sN7(}}W$`5+Sx#yxb?^(M-VLA&8m z+UNEFRYdF0lpci@ki!;VhmMIcG4d^jWcX%udul%6oxKgWq9wqLQWj;%D1VeeE`{9T zX*V>`mndHD1ZDrU8Q_$IaFL<(lDu)-+~{ibYRWy|p88AnBbn=o_=b`qd^u)tlC5TG zO)oblVm6!8DrkfwB;5dViYa1lmpAstA*|%Yu{(SmjA?fW+-97^K?Z@#^anT>pcNJs z_D~heNpMQK(+T+coTo4ptc_A0E3592M9&qx)rtBo^p*?%RLCkjk4uDgu|qr1f(IBo zqSK8PsMGj*U!GzHy?u!jGW4Qg3tQ?sGp~X=RiBqQ?Q!h3#ss_hKyQfU1D_VJ;06j_ zkexE^!8w>mt%@Fz9+-#fO^4PU1KVU`cEj?$BQ!=BuU^68^pk7!QkUGzf$X-XEgCMr zRXezoQT8L$)W~Z+a7hs|4z)-F#>=14dEbo1ZnGI^Qrk_5qPFunMYs2Csq44hl!^&! z2lnYiPX~wJSv03@tzsa4i6_*V1BCouUKRBMPGG$0Ql}mDzUy}g&XYWnls4e{uIp(U zreW5DzET7@Ngd^g^vKBTEK=-GYBdAdD}FzGm^eW&noefFq}ySkiX3ZGyun(}dYkl< zQ2QWnbDn3ju$_Y^E`XFmFc|{QnuzCg2e0Qn18aBMCXVM5KNGN}zLH$X03E|WWNikD zIOfk4JNBmvO{eV=1|2W#nI>t*U%7wa<&7CWaj*U$S8BX8iQS!!o9p3mH?u)Xo-1Qz UVcq7-621AGyw?u&x6iix2h44BWdHyG literal 0 HcmV?d00001 -- Gitee From 3c341cafe6b9476e0fca240c7397f82a07ed9d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:41:25 +0000 Subject: [PATCH 11/18] update Samples/DetectionRetrainingAndInfer/README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Samples/DetectionRetrainingAndInfer/README.md b/Samples/DetectionRetrainingAndInfer/README.md index c3a981c..ed4f9c3 100644 --- a/Samples/DetectionRetrainingAndInfer/README.md +++ b/Samples/DetectionRetrainingAndInfer/README.md @@ -158,7 +158,8 @@ ``` 执行成功后,omInfer/output目录下会生成检测输出图片 -![输入图片说明](omInfer/out_0.jpg) + + ![输入图片说明](omInfer/out_0.jpg) #### 相关操作 \ No newline at end of file -- Gitee From a412f23dbcfe7c159e2a13a6f054ba32d52f3c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:42:28 +0000 Subject: [PATCH 12/18] add Samples/DetectionRetrainingAndInfer/requirements.txt. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/requirements.txt | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 Samples/DetectionRetrainingAndInfer/requirements.txt diff --git a/Samples/DetectionRetrainingAndInfer/requirements.txt b/Samples/DetectionRetrainingAndInfer/requirements.txt new file mode 100644 index 0000000..c70a80e --- /dev/null +++ b/Samples/DetectionRetrainingAndInfer/requirements.txt @@ -0,0 +1,5 @@ +onnx +numpy +opencv-python +protobuf==3.20.2 +tensorboard \ No newline at end of file -- Gitee From ef8a350ccf88d6037012c933ededee7156cdc1f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:47:59 +0000 Subject: [PATCH 13/18] update Samples/DetectionRetrainingAndInfer/README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Samples/DetectionRetrainingAndInfer/README.md b/Samples/DetectionRetrainingAndInfer/README.md index ed4f9c3..d9ed652 100644 --- a/Samples/DetectionRetrainingAndInfer/README.md +++ b/Samples/DetectionRetrainingAndInfer/README.md @@ -77,7 +77,7 @@ ``` -4. 数据集处理,分出训练集合测试集. +4. 数据集处理,分出训练集和验证集. ``` cd .. python3 predata.py -- Gitee From b45269abde0ef2d33e4f7af6fa75ca8d27848095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:53:43 +0000 Subject: [PATCH 14/18] 11 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- .../DetectionRetrainingAndInfer/predata.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/Samples/DetectionRetrainingAndInfer/predata.py b/Samples/DetectionRetrainingAndInfer/predata.py index 1eecee1..6b1030a 100644 --- a/Samples/DetectionRetrainingAndInfer/predata.py +++ b/Samples/DetectionRetrainingAndInfer/predata.py @@ -6,8 +6,8 @@ def preparinbdata(main_xml_file, main_img_file, train_size, val_size): source_img = main_img_file + "/" + metarial[i] + ".png" mstring = metarial[i] - train_destination_xml = "./data/mask/train/labels" + "/" + metarial[i] + ".xml" - train_destination_png = "./data/mask/train/images" + "/" + metarial[i] + ".png" + train_destination_xml = "./dataset/mask/train/labels" + "/" + metarial[i] + ".xml" + train_destination_png = "./dataset/mask/train/images" + "/" + metarial[i] + ".png" shutil.copy(source_xml, train_destination_xml) shutil.copy(source_img, train_destination_png) @@ -17,15 +17,15 @@ def preparinbdata(main_xml_file, main_img_file, train_size, val_size): source_img = main_img_file + "/" + metarial[n] + ".png" mstring = metarial[n] - val_destination_xml = "./data/mask/val/labels" + "/" + metarial[n] + ".xml" - val_destination_png = "./data/mask/val/images" + "/" + metarial[n] + ".png" + val_destination_xml = "./dataset/mask/val/labels" + "/" + metarial[n] + ".xml" + val_destination_png = "./dataset/mask/val/images" + "/" + metarial[n] + ".png" shutil.copy(source_xml, val_destination_xml) shutil.copy(source_img, val_destination_png) if __name__ == '__main__': metarial = [] - for i in os.listdir("./data/images"): + for i in os.listdir("./dataset/images"): str = i[:-4] metarial.append(str) train_size = int(len(metarial) * 0.7) @@ -33,15 
+33,15 @@ if __name__ == '__main__': print("Sum of image: ", len(metarial)) print("Sum of the train size: ", train_size) print("Sum of the val size: ", val_size) - if not os.path.exists("./data/mask"): - os.mkdir('./data/mask') - os.mkdir('./data/mask/train') - os.mkdir('./data/mask/val') - os.mkdir('./data/mask/train/images') - os.mkdir('./data/mask/train/labels') - os.mkdir('./data/mask/val/images') - os.mkdir('./data/mask/val/labels') - preparinbdata(main_xml_file = "./data/annotations", - main_img_file = "./data/images", + if not os.path.exists("./dataset/mask"): + os.mkdir('./dataset/mask') + os.mkdir('./dataset/mask/train') + os.mkdir('./dataset/mask/val') + os.mkdir('./dataset/mask/train/images') + os.mkdir('./dataset/mask/train/labels') + os.mkdir('./dataset/mask/val/images') + os.mkdir('./dataset/mask/val/labels') + preparinbdata(main_xml_file = "./dataset/annotations", + main_img_file = "./dataset/images", train_size = train_size, val_size = val_size) -- Gitee From 1bd2f88e89d534174e9849b2da40bae9999a9df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:55:49 +0000 Subject: [PATCH 15/18] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20Sa?= =?UTF-8?q?mples/DetectionRetrainingAndInfer/main.cpp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Samples/DetectionRetrainingAndInfer/main.cpp | 129 ------------------- 1 file changed, 129 deletions(-) delete mode 100644 Samples/DetectionRetrainingAndInfer/main.cpp diff --git a/Samples/DetectionRetrainingAndInfer/main.cpp b/Samples/DetectionRetrainingAndInfer/main.cpp deleted file mode 100644 index f74d510..0000000 --- a/Samples/DetectionRetrainingAndInfer/main.cpp +++ /dev/null @@ -1,129 +0,0 @@ -#include -#include -#include -#include -#include "acllite_dvpp_lite/ImageProc.h" -#include "acllite_om_execute/ModelProc.h" - -using namespace std; -using namespace acllite; -typedef struct BoundBox { - float x; - float y; - float width; - float height; - float score; - int classIndex; -} BoundBox; -float iou(BoundBox box1, BoundBox box2) -{ - float xLeft = max(box1.x, box2.x); - float yTop = max(box1.y, box1.y); - float xRight = min(box1.x + box1.width, box1.x + box1.width); - float yBottom = min(box1.y + box1.height, box1.y + box1.height); - float width = max(0.0f, xRight - xLeft); - float hight = max(0.0f, yBottom - yTop); - float area = width * hight; - float iou = area / (box1.width * box1.height + box2.width * box2.height - area); - return iou; -} -bool sortScore(BoundBox box1, BoundBox box2) -{ - return box1.score > box2.score; -} -int main() -{ - vector labels = { {"with_mask"},{"mask_weared_incorrect"},{"without_mask"}}; - AclLiteResource aclResource; - bool ret = aclResource.Init(); - CHECK_RET(ret, LOG_PRINT("[ERROR] InitACLResource failed."); return 1); - - ImageProc imageProc; - ModelProc modelProc; - ret = modelProc.Load("../model/ssd-mobilenet.om"); - CHECK_RET(ret, LOG_PRINT("[ERROR] load model Resnet18.om failed."); return 1); - ImageData src = imageProc.Read("../data/8.jpg"); - CHECK_RET(src.size, LOG_PRINT("[ERROR] ImRead image failed."); return 1); - - ImageData dst; - ImageSize dsize(300, 300); - - imageProc.Resize(src, dst, dsize); - ret = modelProc.CreateInput(static_cast(dst.data.get()), dst.size); - CHECK_RET(ret, LOG_PRINT("[ERROR] Create model input failed."); return 1); - vector inferOutputs; - ret = modelProc.Execute(inferOutputs); - CHECK_RET(ret, LOG_PRINT("[ERROR] model execute failed."); return 1); - 
- uint32_t dataSize = inferOutputs[0].size; - // get result from output data set - float* scores = static_cast(inferOutputs[0].data.get()); - float* boxes = static_cast(inferOutputs[1].data.get()); - if (scores == nullptr || boxes == nullptr) { - LOG_PRINT("get result from output data set failed."); - return 1; - } - size_t classNum = 3; - size_t boxes_nums = 3000; - size_t candidate_size = 200; - size_t top_k = 20; - size_t prob_threshold = 0.4; - size_t iou_threshold = 0.4; - int half = 2; - const double fountScale = 0.5; - const uint32_t lineSolid = 2; - const uint32_t labelOffset = 11; - const cv::Scalar fountColor(0, 0, 255); - const vector colors{ - cv::Scalar(237, 149, 100), cv::Scalar(0, 215, 255), - cv::Scalar(50, 205, 50), cv::Scalar(139, 85, 26)}; - cv::Mat srcImage = cv::imread(imagePath); - for(int class = 0; i < classNum; class++) { - vector box_scores; - vector result; - for(int j = 0; j < boxes_nums; ++j){ - if(scores[j * classNum + class]) > prob_threshold{ - BoundBox box; - box.score = scores[j * classNum + class]; - box.x = boxes[4 * boxes_nums]; - box.y = boxes[4 * boxes_nums + 1]; - box.width = boxes[4 * boxes_nums + 2]; - box.height = boxes[4 * boxes_nums + 3]; - box.classIndex = class; - box_scores.push_back(box); - } - } - std::sort(box_scores.begin(),box_scores.end(),sortScore); - box_scores.erase(box_scores.begin() + candidate_size + 1, box_scores.end() + 1); - int len = box_scores.length(); - if(len > 0){ - for(int i = 0;i < box_scores.length(); i++){ - if(result.length() == top_k) break; - result.push_back(box_scores[i]); - for(int j = i + 1; j < box_scores.length();j++){ - float iou_t = iou(box_scores[i],box_scores[j]); - if(iou_t > iout_threshold){ - box_scores.erase(box_scores.begin() + j); - } - } - } - } - for (size_t i = 0; i < result.size(); ++i) { - cv::Point leftUpPoint, rightBottomPoint; - leftUpPoint.x = result[i].x - result[i].width / half; - leftUpPoint.y = result[i].y - result[i].height / half; - rightBottomPoint.x = result[i].x + result[i].width / half; - rightBottomPoint.y = result[i].y + result[i].height / half; - cv::rectangle(srcImage, leftUpPoint, rightBottomPoint, colors[i % colors.size()], lineSolid); - string className = label[result[i].classIndex]; - string markString = to_string(result[i].score) + ":" + className; - cv::putText(srcImage, markString, cv::Point(leftUpPoint.x, leftUpPoint.y + labelOffset), - cv::FONT_HERSHEY_COMPLEX, fountScale, fountColor); - } - - } - string savePath = "../output/out_0.jpg"; - cv::imwrite(savePath, srcImage); - return 0; -} - -- Gitee From 4672ba820dd3a70e1cfae69324bbc2c742877344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 02:56:32 +0000 Subject: [PATCH 16/18] rename Samples/DetectionRetrainingAndInfer/train.py to Samples/DetectionRetrainingAndInfer/main.py. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/{train.py => main.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename Samples/DetectionRetrainingAndInfer/{train.py => main.py} (100%) diff --git a/Samples/DetectionRetrainingAndInfer/train.py b/Samples/DetectionRetrainingAndInfer/main.py similarity index 100% rename from Samples/DetectionRetrainingAndInfer/train.py rename to Samples/DetectionRetrainingAndInfer/main.py -- Gitee From 208cf6e5b0b5f776ca67523b1e021508417fd842 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 03:03:48 +0000 Subject: [PATCH 17/18] update Samples/DetectionRetrainingAndInfer/README.md. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 万祖涛 <1025494833@qq.com> --- Samples/DetectionRetrainingAndInfer/README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Samples/DetectionRetrainingAndInfer/README.md b/Samples/DetectionRetrainingAndInfer/README.md index d9ed652..f35ccfe 100644 --- a/Samples/DetectionRetrainingAndInfer/README.md +++ b/Samples/DetectionRetrainingAndInfer/README.md @@ -82,9 +82,15 @@ cd .. python3 predata.py ``` -5. 运行训练脚本。 +5. 下载预训练模型 + ``` + cd models + wget https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/wanzutao/detection/mobilenet-v1-ssd-mp-0_675.pth + ``` +6. 运行训练脚本。 ``` + cd .. python3 main.py ``` 训练完成后,权重文件保存在models目录下,并输出模型训练精度和性能信息。 -- Gitee From 8fba750e664aec847408d1028a723cfcc8a12131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=87=E7=A5=96=E6=B6=9B?= <1025494833@qq.com> Date: Thu, 21 Mar 2024 03:38:29 +0000 Subject: [PATCH 18/18] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20Sa?= =?UTF-8?q?mples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../vision/nn/squeezenet.py | 130 ------------------ 1 file changed, 130 deletions(-) delete mode 100644 Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py diff --git a/Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py b/Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py deleted file mode 100644 index d961678..0000000 --- a/Samples/DetectionRetrainingAndInfer/vision/nn/squeezenet.py +++ /dev/null @@ -1,130 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.init as init -import torch.utils.model_zoo as model_zoo - - -__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] - - -model_urls = { - 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', - 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', -} - - -class Fire(nn.Module): - - def __init__(self, inplanes, squeeze_planes, - expand1x1_planes, expand3x3_planes): - super(Fire, self).__init__() - self.inplanes = inplanes - self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) - self.squeeze_activation = nn.ReLU(inplace=True) - self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, - kernel_size=1) - self.expand1x1_activation = nn.ReLU(inplace=True) - self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, - kernel_size=3, padding=1) - self.expand3x3_activation = nn.ReLU(inplace=True) - - def forward(self, x): - x = self.squeeze_activation(self.squeeze(x)) - return torch.cat([ - self.expand1x1_activation(self.expand1x1(x)), - 
self.expand3x3_activation(self.expand3x3(x)) - ], 1) - - -class SqueezeNet(nn.Module): - - def __init__(self, version=1.0, num_classes=1000): - super(SqueezeNet, self).__init__() - if version not in [1.0, 1.1]: - raise ValueError("Unsupported SqueezeNet version {version}:" - "1.0 or 1.1 expected".format(version=version)) - self.num_classes = num_classes - if version == 1.0: - self.features = nn.Sequential( - nn.Conv2d(3, 96, kernel_size=7, stride=2), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), - Fire(96, 16, 64, 64), - Fire(128, 16, 64, 64), - Fire(128, 32, 128, 128), - nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), - Fire(256, 32, 128, 128), - Fire(256, 48, 192, 192), - Fire(384, 48, 192, 192), - Fire(384, 64, 256, 256), - nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), - Fire(512, 64, 256, 256), - ) - else: - self.features = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=3, stride=2), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=3, stride=2), - Fire(64, 16, 64, 64), - Fire(128, 16, 64, 64), - nn.MaxPool2d(kernel_size=3, stride=2), - Fire(128, 32, 128, 128), - Fire(256, 32, 128, 128), - nn.MaxPool2d(kernel_size=3, stride=2), - Fire(256, 48, 192, 192), - Fire(384, 48, 192, 192), - Fire(384, 64, 256, 256), - Fire(512, 64, 256, 256), - ) - # Final convolution is initialized differently form the rest - final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) - self.classifier = nn.Sequential( - nn.Dropout(p=0.5), - final_conv, - nn.ReLU(inplace=True), - nn.AvgPool2d(13, stride=1) - ) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - if m is final_conv: - init.normal_(m.weight, mean=0.0, std=0.01) - else: - init.kaiming_uniform_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - - def forward(self, x): - x = self.features(x) - x = self.classifier(x) - return x.view(x.size(0), self.num_classes) - - -def squeezenet1_0(pretrained=False, **kwargs): - r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level - accuracy with 50x fewer parameters and <0.5MB model size" - `_ paper. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = SqueezeNet(version=1.0, **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) - return model - - -def squeezenet1_1(pretrained=False, **kwargs): - r"""SqueezeNet 1.1 model from the `official SqueezeNet repo - `_. - SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters - than SqueezeNet 1.0, without sacrificing accuracy. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = SqueezeNet(version=1.1, **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) - return model -- Gitee