From a221ef884caaf4cd8613d7324827d19a384ca849 Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Mon, 11 Dec 2023 15:54:51 +0800 Subject: [PATCH 1/2] 1 --- AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md new file mode 100644 index 0000000000..56a6051ca2 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md @@ -0,0 +1 @@ +1 \ No newline at end of file -- Gitee From 78dc3bd3ddfc99e0ecfc6da69ed0981418cee47e Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Mon, 11 Dec 2023 15:57:42 +0800 Subject: [PATCH 2/2] psenet --- .../Post-processing/Algorithm_DetEva.py | 349 +++++++++++ .../PSENet/Post-processing/Algorithm_IoU.py | 424 ++++++++++++++ .../cv/ocr/PSENet/Post-processing/bintotxt.py | 137 +++++ .../Post-processing/rrc_evaluation_funcs.py | 394 +++++++++++++ .../cv/ocr/PSENet/Post-processing/script.py | 37 ++ .../TorchAIE/built-in/cv/ocr/PSENet/README.md | 169 +++++- .../TorchAIE/built-in/cv/ocr/PSENet/export.py | 76 +++ .../built-in/cv/ocr/PSENet/export_onnx.py | 54 ++ .../cv/ocr/PSENet/fpn_resnet_nearest.py | 540 ++++++++++++++++++ .../TorchAIE/built-in/cv/ocr/PSENet/perf.py | 114 ++++ .../ocr/PSENet/preprocess_psenet_pytorch.py | 64 +++ .../TorchAIE/built-in/cv/ocr/PSENet/pypse.py | 63 ++ .../built-in/cv/ocr/PSENet/requirements.txt | 192 +++++++ .../TorchAIE/built-in/cv/ocr/PSENet/run.py | 145 +++++ .../TorchAIE/built-in/cv/ocr/PSENet/url.ini | 6 + 15 files changed, 2763 insertions(+), 1 deletion(-) create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_DetEva.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_IoU.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/bintotxt.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/rrc_evaluation_funcs.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/script.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export_onnx.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/fpn_resnet_nearest.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/perf.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/preprocess_psenet_pytorch.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/pypse.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/requirements.txt create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/run.py create mode 100644 AscendIE/TorchAIE/built-in/cv/ocr/PSENet/url.ini diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_DetEva.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_DetEva.py new file mode 100644 index 0000000000..da3d7956d6 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_DetEva.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +''' +Algorithm named DetEval +It is slightly different from original algorithm(see https://perso.liris.cnrs.fr/christian.wolf/software/deteval/index.html) +Please read《 Object Count / Area Graphs for the Evaluation of Object Detection and Segmentation Algorithms 》for details +''' +from collections import namedtuple +import rrc_evaluation_funcs +import importlib + +def 
evaluation_imports(): + """ + evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. + """ + return { + 'math': 'math', + 'numpy': 'np' + } + + +def default_evaluation_params(): + """ + default_evaluation_params: Default parameters to use for the validation and evaluation. + """ + return { + 'AREA_RECALL_CONSTRAINT': 0.8, + 'AREA_PRECISION_CONSTRAINT': 0.4, + 'EV_PARAM_IND_CENTER_DIFF_THR': 1, + 'MTYPE_OO_O': 1., + 'MTYPE_OM_O': 0.8, + 'MTYPE_OM_M': 1., + 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', + 'CRLF': False # Lines are delimited by Windows CRLF format + } + + +def validate_data(gtFilePath, submFilePath, evaluationParams): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + # Validate format of GroundTruth + for k in gt: + rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], True, True) + + # Validate format of results + for k in subm: + if (k in gt) == False: + raise Exception("The sample %s not present in GT" % k) + + rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], True, True) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + + for module, alias in evaluation_imports().iteritems(): + globals()[alias] = importlib.import_module(module) + + def one_to_one_match(row, col): + cont = 0 + for j in range(len(recallMat[0])): + if recallMat[row, j] >= evaluationParams['AREA_RECALL_CONSTRAINT'] and precisionMat[row, j] >= \ + evaluationParams['AREA_PRECISION_CONSTRAINT']: + cont = cont + 1 + if (cont != 1): + return False + cont = 0 + for i in range(len(recallMat)): + if recallMat[i, col] >= evaluationParams['AREA_RECALL_CONSTRAINT'] and precisionMat[i, col] >= \ + evaluationParams['AREA_PRECISION_CONSTRAINT']: + cont = cont + 1 + if (cont != 1): + return False + + if recallMat[row, col] >= evaluationParams['AREA_RECALL_CONSTRAINT'] and precisionMat[row, col] >= \ + evaluationParams['AREA_PRECISION_CONSTRAINT']: + return True + return False + + def one_to_many_match(gtNum): + many_sum = 0 + detRects = [] + for detNum in range(len(recallMat[0])): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and detNum not in detDontCareRectsNum: + if precisionMat[gtNum, detNum] >= evaluationParams['AREA_PRECISION_CONSTRAINT']: + many_sum += recallMat[gtNum, detNum] + detRects.append(detNum) + if many_sum >= evaluationParams['AREA_RECALL_CONSTRAINT']: + return True, detRects + else: + return False, [] + + def many_to_one_match(detNum): + many_sum = 0 + gtRects = [] + for gtNum in range(len(recallMat)): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCareRectsNum: + if recallMat[gtNum, detNum] >= evaluationParams['AREA_RECALL_CONSTRAINT']: + many_sum += precisionMat[gtNum, detNum] + gtRects.append(gtNum) + if many_sum >= evaluationParams['AREA_PRECISION_CONSTRAINT']: + return True, gtRects + else: + return False, [] + + def area(a, b): + dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) + 1 + dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) + 1 + if (dx >= 0) and (dy >= 0): + return dx * dy + else: + return 0. + + def center(r): + x = float(r.xmin) + float(r.xmax - r.xmin + 1) / 2. + y = float(r.ymin) + float(r.ymax - r.ymin + 1) / 2. + return Point(x, y) + + def point_distance(r1, r2): + distx = math.fabs(r1.x - r2.x) + disty = math.fabs(r1.y - r2.y) + return math.sqrt(distx * distx + disty * disty) + + def center_distance(r1, r2): + return point_distance(center(r1), center(r2)) + + def diag(r): + w = (r.xmax - r.xmin + 1) + h = (r.ymax - r.ymin + 1) + return math.sqrt(h * h + w * w) + + perSampleMetrics = {} + + methodRecallSum = 0 + methodPrecisionSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + Point = namedtuple('Point', 'x y') + + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + numGt = 0 + numDet = 0 + + for resFile in gt: + + gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) + recall = 0 + precision = 0 + hmean = 0 + recallAccum = 0. + precisionAccum = 0. 
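+        # per-sample containers: GT/detection rectangles and their points, indices of don't-care regions, and the matched pairs found for this image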
+ gtRects = [] + detRects = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCareRectsNum = [] # Array of Ground Truth Rectangles' keys marked as don't Care + detDontCareRectsNum = [] # Array of Detected Rectangles' matched with a don't Care GT + pairs = [] + evaluationLog = "" + + recallMat = np.empty([1, 1]) + precisionMat = np.empty([1, 1]) + + pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, + evaluationParams[ + 'CRLF'], + True, True, + False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = transcription == "###" + + # convert x1,y1,x2,y2,x3,y3,x4,y4 to xmin,ymin,xmax,ymax + if len(points) == 8: + points_tmp = np.array(points).reshape(4, 2) + points_x = points_tmp[:, 0] + points_y = points_tmp[:, 1] + xmin = points_x[np.argmin(points_x)] + xmax = points_x[np.argmax(points_x)] + ymin = points_y[np.argmin(points_y)] + ymax = points_y[np.argmax(points_y)] + points = [xmin, ymin, xmax, ymax] + gtRect = Rectangle(*points) + gtRects.append(gtRect) + gtPolPoints.append(points) + if dontCare: + gtDontCareRectsNum.append(len(gtRects) - 1) + + evaluationLog += "GT rectangles: " + str(len(gtRects)) + ( + " (" + str(len(gtDontCareRectsNum)) + " don't care)\n" if len(gtDontCareRectsNum) > 0 else "\n") + + if resFile in subm: + detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) + pointsList, _, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, + evaluationParams['CRLF'], + True, True, False) + for n in range(len(pointsList)): + points = pointsList[n] + # print points + detRect = Rectangle(*points) + detRects.append(detRect) + detPolPoints.append(points) + if len(gtDontCareRectsNum) > 0: + for dontCareRectNum in gtDontCareRectsNum: + dontCareRect = gtRects[dontCareRectNum] + intersected_area = area(dontCareRect, detRect) + rdDimensions = ((detRect.xmax - detRect.xmin + 1) * (detRect.ymax - detRect.ymin + 1)) + if (rdDimensions == 0): + precision = 0 + else: + precision = intersected_area / rdDimensions + if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']): + detDontCareRectsNum.append(len(detRects) - 1) + break + + evaluationLog += "DET rectangles: " + str(len(detRects)) + ( + " (" + str(len(detDontCareRectsNum)) + " don't care)\n" if len(detDontCareRectsNum) > 0 else "\n") + + if len(gtRects) == 0: + recall = 1 + precision = 0 if len(detRects) > 0 else 1 + + if len(detRects) > 0: + # Calculate recall and precision matrixs + outputShape = [len(gtRects), len(detRects)] + recallMat = np.empty(outputShape) + precisionMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtRects), np.int8) + detRectMat = np.zeros(len(detRects), np.int8) + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + rG = gtRects[gtNum] + rD = detRects[detNum] + intersected_area = area(rG, rD) + rgDimensions = ((rG.xmax - rG.xmin + 1) * (rG.ymax - rG.ymin + 1)) + rdDimensions = ((rD.xmax - rD.xmin + 1) * (rD.ymax - rD.ymin + 1)) + recallMat[gtNum, detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions + precisionMat[gtNum, detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions + + # Find one-to-one matches + evaluationLog += "Find one-to-one matches\n" + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum: + match = one_to_one_match(gtNum, detNum) + if match is True: + rG 
= gtRects[gtNum] + rD = detRects[detNum] + normDist = center_distance(rG, rD) + normDist /= diag(rG) + diag(rD) + normDist *= 2.0 + if normDist < evaluationParams['EV_PARAM_IND_CENTER_DIFF_THR']: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + recallAccum += evaluationParams['MTYPE_OO_O'] + precisionAccum += evaluationParams['MTYPE_OO_O'] + pairs.append({'gt': gtNum, 'det': detNum, 'type': 'OO'}) + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str(gtNum) + " with Det #" + str( + detNum) + " normDist: " + str(normDist) + " \n" + # Find one-to-many matches + evaluationLog += "Find one-to-many matches\n" + for gtNum in range(len(gtRects)): + if gtNum not in gtDontCareRectsNum: + match, matchesDet = one_to_many_match(gtNum) + if match is True: + gtRectMat[gtNum] = 1 + recallAccum += evaluationParams['MTYPE_OM_O'] + precisionAccum += evaluationParams['MTYPE_OM_O'] * len(matchesDet) + pairs.append({'gt': gtNum, 'det': matchesDet, 'type': 'OM'}) + for detNum in matchesDet: + detRectMat[detNum] = 1 + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(matchesDet) + "\n" + + # Find many-to-one matches + evaluationLog += "Find many-to-one matches\n" + for detNum in range(len(detRects)): + if detNum not in detDontCareRectsNum: + match, matchesGt = many_to_one_match(detNum) + if match is True: + detRectMat[detNum] = 1 + recallAccum += evaluationParams['MTYPE_OM_M'] * len(matchesGt) + precisionAccum += evaluationParams['MTYPE_OM_M'] + pairs.append({'gt': matchesGt, 'det': detNum, 'type': 'MO'}) + for gtNum in matchesGt: + gtRectMat[gtNum] = 1 + evaluationLog += "Match GT #" + str(matchesGt) + " with Det #" + str(detNum) + "\n" + + numGtCare = (len(gtRects) - len(gtDontCareRectsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if len(detRects) > 0 else float(1) + else: + recall = float(recallAccum) / numGtCare + precision = float(0) if (len(detRects) - len(detDontCareRectsNum)) == 0 else float( + precisionAccum) / (len(detRects) - len(detDontCareRectsNum)) + hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) + + evaluationLog += "Recall = " + str(recall) + "\n" + evaluationLog += "Precision = " + str(precision) + "\n" + + methodRecallSum += recallAccum + methodPrecisionSum += precisionAccum + numGt += len(gtRects) - len(gtDontCareRectsNum) + numDet += len(detRects) - len(detDontCareRectsNum) + + perSampleMetrics[resFile] = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(), + 'precisionMat': [] if len(detRects) > 100 else precisionMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtDontCare': gtDontCareRectsNum, + 'detDontCare': detDontCareRectsNum, + 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + + methodRecall = 0 if numGt == 0 else methodRecallSum / numGt + methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean} + + resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics} + + return resDict diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_IoU.py 
b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_IoU.py new file mode 100644 index 0000000000..0609fdb6ac --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/Algorithm_IoU.py @@ -0,0 +1,424 @@ +# -*- coding: utf-8 -*- +from collections import namedtuple +import rrc_evaluation_funcs +from rrc_evaluation_funcs import logger +import importlib +import re + +def evaluation_imports(): + """ + evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. + """ + return { + 'Polygon': 'plg', + 'numpy': 'np' + } + + +def default_evaluation_params(): + """ + default_evaluation_params: Default parameters to use for the validation and evaluation. + """ + return { + 'IOU_CONSTRAINT': 0.5, + 'AREA_PRECISION_CONSTRAINT': 0.5, + 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID': 'img_([0-9]+).txt', + 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) + 'CRLF': False, # Lines are delimited by Windows CRLF format + 'CONFIDENCES': False, # Detections must include confidence value. AP will be calculated + 'PER_SAMPLE_RESULTS': True, # Generate per sample results and produce data for visualization + 'E2E': False #compute average edit distance for end to end evaluation + } + + +def validate_data(gtFilePath, submFilePath, evaluationParams): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + # Validate format of GroundTruth + for k in gt: + rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True) + + # Validate format of results + for k in subm: + if (k in gt) == False: + raise Exception("The sample %s not present in GT" % k) + + rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'], + evaluationParams['E2E'], evaluationParams['CONFIDENCES']) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + if evaluationParams['E2E']: + from hanziconv import HanziConv + import editdistance + + for module, alias in evaluation_imports().items(): + globals()[alias] = importlib.import_module(module) + + def polygon_from_points(points): + """ + Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 + """ + resBoxes = np.empty([1, 8], dtype='int32') + resBoxes[0, 0] = int(points[0]) + resBoxes[0, 4] = int(points[1]) + resBoxes[0, 1] = int(points[2]) + resBoxes[0, 5] = int(points[3]) + resBoxes[0, 2] = int(points[4]) + resBoxes[0, 6] = int(points[5]) + resBoxes[0, 3] = int(points[6]) + resBoxes[0, 7] = int(points[7]) + pointMat = resBoxes[0].reshape([2, 4]).T + return plg.Polygon(pointMat) + + def rectangle_to_polygon(rect): + resBoxes = np.empty([1, 8], dtype='int32') + resBoxes[0, 0] = int(rect.xmin) + resBoxes[0, 4] = int(rect.ymax) + resBoxes[0, 1] = int(rect.xmin) + resBoxes[0, 5] = int(rect.ymin) + resBoxes[0, 2] = int(rect.xmax) + resBoxes[0, 6] = int(rect.ymin) + resBoxes[0, 3] = int(rect.xmax) + resBoxes[0, 7] = int(rect.ymax) + + pointMat = resBoxes[0].reshape([2, 4]).T + + return plg.Polygon(pointMat) + + def rectangle_to_points(rect): + points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), + int(rect.xmin), int(rect.ymin)] + return points + + def get_union(pD, pG): + areaA = pD.area() + areaB = pG.area() + return areaA + areaB - get_intersection(pD, pG) + + def get_intersection_over_union(pD, pG): + try: + return get_intersection(pD, pG) / get_union(pD, pG) + except: + return 0 + + def get_intersection(pD, pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + #from RTWC17 + def normalize_txt(st): + """ + Normalize Chinese text strings by: + - remove puncutations and other symbols + - convert traditional Chinese to simplified + - convert English characters to lower cases + """ + st = ''.join(st.split(' ')) + st = re.sub("\"", "", st) + # remove any this not one of Chinese character, ascii 0-9, and ascii a-z and A-Z + new_st = re.sub(u'[^\u4e00-\u9fa5\u0041-\u005a\u0061-\u007a0-9]+', '', st) + # convert Traditional Chinese to Simplified Chinese + new_st = HanziConv.toSimplified(new_st) + # convert uppercase English letters to lowercase + new_st = new_st.lower() + return new_st + + def text_distance(str1, str2): + str1 = normalize_txt(str1) + str2 = normalize_txt(str2) + return editdistance.eval(str1, str2) + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + arrGlobalConfidences = [] + arrGlobalMatches = [] + + #total edit distance + total_dist = 0 + + for resFile in gt: + + gtFile = 
rrc_evaluation_funcs.decode_utf8(gt[resFile]) + recall = 0 + precision = 0 + hmean = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtTrans = [] + detTrans = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + detMatchedNums = [] + + arrSampleConfidences = [] + arrSampleMatch = [] + sampleAP = 0 + + example_dist = 0 + match_tuples = [] + + evaluationLog = "" + + pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True, False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = (transcription == "###") or (transcription=="?") + if evaluationParams['LTRB']: + gtRect = Rectangle(*points) + gtPol = rectangle_to_polygon(gtRect) + else: + gtPol = polygon_from_points(points) + gtPols.append(gtPol) + gtPolPoints.append(points) + gtTrans.append(transcription) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) - 1) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + ( + " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n") + + if resFile in subm: + + detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) + + pointsList, confidencesList, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],evaluationParams['E2E'],evaluationParams['CONFIDENCES']) + for n in range(len(pointsList)): + points = pointsList[n] + + if evaluationParams['LTRB']: + detRect = Rectangle(*points) + detPol = rectangle_to_polygon(detRect) + else: + detPol = polygon_from_points(points) + detPols.append(detPol) + detPolPoints.append(points) + if evaluationParams['E2E']: + transcription = transcriptionsList[n] + detTrans.append(transcription) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = detPol.area() + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']): + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + ( + " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n") + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) + + # match dt index of every gt + gtMatch = np.empty(len(gtPols), np.int8) + gtMatch.fill(-1) + # match gt index of every dt + dtMatch = np.empty(len(detPols), dtype=np.int8) + dtMatch.fill(-1) + + for gtNum in range(len(gtPols)): + max_iou = 0 + match_dt_idx = -1 + + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0\ + and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']: + gtRectMat[gtNum] = 
1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" + + if evaluationParams['E2E'] and gtMatch[gtNum] == -1 and dtMatch[detNum] == -1\ + and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT'] and iouMat[gtNum, detNum] > max_iou: + max_iou = iouMat[gtNum, detNum] + match_dt_idx = detNum + + if evaluationParams['E2E'] and match_dt_idx >= 0: + gtMatch[gtNum] = match_dt_idx + dtMatch[match_dt_idx] = gtNum + + if evaluationParams['E2E']: + for gtNum in range(len(gtPols)): + if gtNum in gtDontCarePolsNum: + continue + gt_text = gtTrans[gtNum] + if gtMatch[gtNum] >= 0: + dt_text = detTrans[gtMatch[gtNum]] + else: + dt_text = u'' + dist = text_distance(gt_text, dt_text) + example_dist += dist + match_tuples.append((gt_text, dt_text, dist)) + match_tuples.append(("===============","==============", -1)) + for detNum in range(len(detPols)): + if detNum in detDontCarePolsNum: + continue + if dtMatch[detNum] == -1: + gt_text = u'' + dt_text = detTrans[detNum] + dist = text_distance(gt_text, dt_text) + example_dist += dist + match_tuples.append((gt_text, dt_text, dist)) + + if evaluationParams['CONFIDENCES']: + for detNum in range(len(detPols)): + if detNum not in detDontCarePolsNum: + # we exclude the don't care detections + match = detNum in detMatchedNums + + arrSampleConfidences.append(confidencesList[detNum]) + arrSampleMatch.append(match) + + arrGlobalConfidences.append(confidencesList[detNum]) + arrGlobalMatches.append(match) + #avoid when det file don't exist, example_dist=0 + elif evaluationParams['E2E']: + match_tuples.append(("===============", "==============", -1)) + dt_text = u'' + for gtNum in range(len(gtPols)): + if gtNum in gtDontCarePolsNum: + continue + gt_text = gtTrans[gtNum] + dist = text_distance(gt_text, dt_text) + example_dist += dist + match_tuples.append((gt_text, dt_text, dist)) + total_dist += example_dist + + if evaluationParams['E2E']: + logger.debug('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>') + logger.debug("file:{}".format(resFile)) + for tp in match_tuples: + gt_text, dt_text, dist = tp + logger.debug(u'GT: "{}" matched to DT: "{}", distance = {}'.format(gt_text, dt_text, dist)) + logger.debug('Distance = {:f}'.format(example_dist)) + logger.debug('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<') + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare > 0 else float(1) + sampleAP = precision + else: + recall = float(detMatched) / numGtCare + precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare + if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']: + sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare) + + hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + if evaluationParams['PER_SAMPLE_RESULTS']: + perSampleMetrics[resFile] = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'AP': sampleAP, + 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtDontCare': gtDontCarePolsNum, + 'detDontCare': detDontCarePolsNum, 
+ 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + if evaluationParams['E2E']: + perSampleMetrics[resFile]['exampleDistance'] = example_dist + # print("file:{} exampleDistance:{}".format(resFile,example_dist)) + + # Compute MAP and MAR + AP = 0 + if evaluationParams['CONFIDENCES']: + AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) + + methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + methodDistance = 0 if len(gt) == 0 else float(total_dist)/len(gt) + + methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean} + print('npu_predict_map:', methodHmean) + + resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics} + + return resDict diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/bintotxt.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/bintotxt.py new file mode 100644 index 0000000000..e1bd3c7853 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/bintotxt.py @@ -0,0 +1,137 @@ +# -*- coding:utf-8 -*- +import os +import sys +import cv2 +import numpy as np +import tensorflow as tf +#from pse import pse +import subprocess + +BASE_DIR = os.path.dirname(os.path.realpath(__file__)) + +#if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value +# raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR)) + +def pse(kernals, min_area=5): + ''' + :param kernals: + :param min_area: + :return: + ''' + from pse import pse_cpp + kernal_num = len(kernals) + if not kernal_num: + return np.array([]), [] + kernals = np.array(kernals) + label_num, label = cv2.connectedComponents(kernals[kernal_num - 1].astype(np.uint8), connectivity=4) + label_values = [] + for label_idx in range(1, label_num): + if np.sum(label == label_idx) < min_area: + label[label == label_idx] = 0 + continue + label_values.append(label_idx) + + pred = pse_cpp(label, kernals, c=7) + + return pred, label_values + +image_h = 704 +image_w = 1216 +ratio_w = 0.95 +ratio_h = 0.9777777777777777 +img_path = sys.argv[1] +bin_path = sys.argv[2] +txt_path = sys.argv[3] + +def get_images(): + ''' + find image files in test data path + :return: list of files found + ''' + files = [] + exts = ['jpg', 'png', 'jpeg', 'JPG'] + + for parent, _, filenames in os.walk(img_path): + for filename in filenames: + for ext in exts: + if filename.endswith(ext): + files.append(os.path.join(parent, filename)) + break + return files + +def detect(seg_maps, image_w, image_h, min_area_thresh=10, seg_map_thresh=0.9, ratio = 1): + ''' + restore text boxes from score map and geo map + :param seg_maps: + :param timer: + :param min_area_thresh: + :param seg_map_thresh: threshhold for seg map + :param ratio: compute each seg map thresh + :return: + ''' + if len(seg_maps.shape) == 4: + seg_maps = seg_maps[0, :, :, ] + #get kernals, sequence: 0->n, max -> min + kernals = [] + one = np.ones_like(seg_maps[..., 0], dtype=np.uint8) + zero = np.zeros_like(seg_maps[..., 0], dtype=np.uint8) + thresh = seg_map_thresh + for i in range(seg_maps.shape[-1]-1, -1, -1): + kernal = np.where(seg_maps[..., i]>thresh, one, zero) + kernals.append(kernal) + thresh = seg_map_thresh*ratio + mask_res, label_values = pse(kernals, min_area_thresh) + 
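+    # pse() returns a label map of the merged text kernels; convert it to an array and resize it back to the model input resolution before extracting one rotated box per label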
mask_res = np.array(mask_res) + mask_res_resized = cv2.resize(mask_res, (image_w, image_h), interpolation=cv2.INTER_NEAREST) + boxes = [] + for label_value in label_values: + #(y,x) + points = np.argwhere(mask_res_resized==label_value) + points = points[:, (1,0)] + rect = cv2.minAreaRect(points) + box = cv2.boxPoints(rect) + boxes.append(box) + + return np.array(boxes), kernals + +im_fn_list = get_images() +for im_fn in im_fn_list[8:9]: + im = cv2.imread(im_fn)[:, :, ::-1] + idx = os.path.basename(im_fn).split('/')[-1].split('.')[0].split('_')[1] + seg_maps = np.fromfile(bin_path+"/img_{}_1.bin".format(idx), "float32") + seg_maps = np.reshape(seg_maps, (1, 7, 176, 304)) + seg_maps = np.transpose(seg_maps, [0, 2, 3, 1]) + print(seg_maps.shape) + + boxes, kernels = detect(seg_maps=seg_maps, image_w=image_w, image_h=image_h) + + if boxes is not None: + boxes = boxes.reshape((-1, 4, 2)) + boxes[:, :, 0] /= ratio_w + boxes[:, :, 1] /= ratio_h + h, w, _ = im.shape + boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, w) + boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, h) + + # save to file + if boxes is not None: + res_file = os.path.join( + txt_path, + '{}.txt'.format(os.path.splitext( + os.path.basename(im_fn))[0])) + + + with open(res_file, 'w') as f: + num =0 + for i in range(len(boxes)): + box = boxes[i] + if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5: + continue + + num += 1 + + f.write('{},{},{},{},{},{},{},{}\r\n'.format( + box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1])) + cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=2) + cv2.imshow('result', im) + cv2.waitKey() \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/rrc_evaluation_funcs.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/rrc_evaluation_funcs.py new file mode 100644 index 0000000000..680512e272 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/rrc_evaluation_funcs.py @@ -0,0 +1,394 @@ +# encoding: UTF-8 +import json +import sys + +sys.path.append('./') +import zipfile +import re +import sys +import os +import codecs +import logging + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def load_zip_file_keys(file, fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append(keyName) + + return pairs + + +def load_zip_file(file, fileNameRegExp='', allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 
+ The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: + print(file) + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append([keyName, archive.read(name)]) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' % name) + + return dict(pairs) + + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + raw = codecs.decode(raw, 'utf-8', 'replace') + # extracts BOM if exists + raw = raw.encode('utf8') + if raw.startswith(codecs.BOM_UTF8): + raw = raw.replace(codecs.BOM_UTF8, '', 1) + return raw.decode('utf-8') + except: + return None + + +def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, + imWidth=0, imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None): + raise Exception("The file %s is not UTF-8" % fileName) + + lines = utf8File.split("\r\n" if CRLF else "\n") + for line in lines: + line = line.replace("\r", "").replace("\n", "") + if (line != ""): + try: + validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight) + except Exception as e: + raise Exception( + ("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode( + 'utf-8', 'replace')) + + +def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight) + + +def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = "" + points = [] + + numPoints = 4 + + if LTRB: + + numPoints = 4 + + if withTranscription and withConfidence: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + if m == None: + raise Exception("Format incorrect. 
Should be: xmin,ymin,xmax,ymax,confidence,transcription") + elif withConfidence: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") + elif withTranscription: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,(.*)$', + line) + if m == None: + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") + else: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,?\s*$', + line) + if m == None: + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = float(m.group(1)) + ymin = float(m.group(2)) + xmax = float(m.group(3)) + ymax = float(m.group(4)) + if (xmax < xmin): + raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax)) + if (ymax < ymin): + raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax)) + + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(xmin, ymin, imWidth, imHeight) + validate_point_inside_bounds(xmax, ymax, imWidth, imHeight) + + else: + + numPoints = 8 + + if withTranscription and withConfidence: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + if m == None: + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") + elif withConfidence: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") + elif withTranscription: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,(.*)$', + line) + if m == None: + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") + else: + m = re.match( + r'^\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*,\s*(-?[0-9]+\.?[0-9]*)\s*$', + line) + if m == None: + raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + + isClockwise = validate_clockwise_points(points) + if not isClockwise: + # convert anticlockwise to clockwise sequence + points = [points[0], points[1], points[6], points[7], points[4], points[5], points[2], points[3]] + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(points[0], points[1], imWidth, imHeight) + validate_point_inside_bounds(points[2], points[3], imWidth, imHeight) + validate_point_inside_bounds(points[4], points[5], imWidth, imHeight) + validate_point_inside_bounds(points[6], points[7], imWidth, imHeight) + + if withConfidence: + try: + confidence = float(m.group(numPoints + 1)) + except ValueError: + raise Exception("Confidence value must be a float") + + if withTranscription: + posTranscription = numPoints + (2 if withConfidence else 1) + transcription = m.group(posTranscription) + m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription) + if m2 != None: # Transcription with double quotes, we extract the value and replace escaped characters + transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") + + return points, confidence, transcription + + +def validate_point_inside_bounds(x, y, imWidth, imHeight): + if (x < 0 or x > imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight)) + if (y < 0 or y > imHeight): + raise Exception( + "Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" % (y, imWidth, imHeight)) + + +def validate_clockwise_points(points): + """ + Validates that the points that the 4 points that dlimite a polygon are in clockwise order. + """ + + if len(points) != 8: + raise Exception("Points list not valid." + str(len(points))) + + point = [ + [int(points[0]), int(points[1])], + [int(points[2]), int(points[3])], + [int(points[4]), int(points[5])], + [int(points[6]), int(points[7])] + ] + edge = [ + (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]), + (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]), + (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]), + (point[0][0] - point[3][0]) * (point[0][1] + point[3][1]) + ] + + summatory = edge[0] + edge[1] + edge[2] + edge[3] + if summatory > 0: + logger.debug( + "Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + return False + return True + + +def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, + imWidth=0, imHeight=0, sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats:
+    xmin,ymin,xmax,ymax,[confidence],[transcription]
+    x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
+    """
+    pointsList = []
+    transcriptionsList = []
+    confidencesList = []
+
+    lines = content.split("\r\n" if CRLF else "\n")
+    for line in lines:
+        line = line.replace("\r", "").replace("\n", "")
+        if (line != ""):
+            points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence,
+                                                                    imWidth, imHeight)
+            pointsList.append(points)
+            transcriptionsList.append(transcription)
+            confidencesList.append(confidence)
+
+    if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
+        import numpy as np
+        sorted_ind = np.argsort(-np.array(confidencesList))
+        confidencesList = [confidencesList[i] for i in sorted_ind]
+        pointsList = [pointsList[i] for i in sorted_ind]
+        transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
+
+    return pointsList, confidencesList, transcriptionsList
+
+
+def main_evaluation(args, default_evaluation_params_fn, validate_data_fn, evaluate_method_fn, show_result=True,
+                    per_sample=True):
+    """
+    This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
+    Params:
+    args:
+        -g for ground truth,
+        -s for detect result,
+        -p The parameters that can be overridden are inside the function 'default_evaluation_params' located at the beginning of the evaluation script,
+        -o Path to a directory where to copy the file 'results.zip' that contains per-sample results,
+    evaluate_method_fn: points to a function that evaluates the submission and returns a dictionary with the results
+    """
+    evalParams = default_evaluation_params_fn()
+    if args.p:
+        evalParams.update(json.loads(args.p))
+
+    resDict = {'calculated': True, 'Message': '', 'method': '{}', 'per_sample': '{}'}
+    try:
+        validate_data_fn(args.g, args.s, evalParams)
+        evalData = evaluate_method_fn(args.g, args.s, evalParams)
+        resDict.update(evalData)
+
+    except Exception as e:
+        print("we are here")
+        resDict['Message'] = str(e)
+        resDict['calculated'] = False
+
+    if args.o:
+        if not os.path.exists(args.o):
+            os.makedirs(args.o)
+
+        resultsOutputname = args.o + '/results.zip'
+        outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
+        del resDict['per_sample']
+        if 'output_items' in resDict.keys():
+            del resDict['output_items']
+
+        outZip.writestr('method.json', json.dumps(resDict))
+
+    if not resDict['calculated']:
+        if show_result:
+            sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n')
+        if args.o:
+            outZip.close()
+        return resDict
+
+    if args.o:
+        if per_sample == True:
+            for k, v in evalData['per_sample'].items():
+                outZip.writestr(k + '.json', json.dumps(v))
+
+            if 'output_items' in evalData.keys():
+                for k, v in evalData['output_items'].items():
+                    outZip.writestr(k, v)
+
+        outZip.close()
+
+    if show_result:
+        sys.stdout.write("Calculated!")
+        sys.stdout.write(json.dumps(resDict['method']))
+        print()
+
+    return resDict
+
+
+def main_validation(args, default_evaluation_params_fn, validate_data_fn):
+    """
+    This process validates a method
+    Params:
+    default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
+    validate_data_fn: points to a method that validates the correct format of the submission
+    """
+    try:
+        evalParams = default_evaluation_params_fn()
+        if args.p:
+            evalParams.update(json.loads(args.p[1:-1]))
+
+        validate_data_fn(args.g, args.s, evalParams)
+        print('SUCCESS')
+        sys.exit(0)
+
except Exception as e: + print(str(e)) + sys.exit(101) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/script.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/script.py new file mode 100644 index 0000000000..f1227fdb70 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/Post-processing/script.py @@ -0,0 +1,37 @@ +#-*- coding:utf-8 -*- +''' +default parameters, you can modify them by +-p '{\"GT_SAMPLE_NAME_2_ID\":\"([0-9]+).txt\",\"DET_SAMPLE_NAME_2_ID\":\"([0-9]+).txt\",\"CONFIDENCES\":true}' + +'IOU_CONSTRAINT': 0.5, +'AREA_PRECISION_CONSTRAINT': 0.5, +'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', +'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', +'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) +'CRLF': False, # Lines are delimited by Windows CRLF format +'CONFIDENCES': False, # Detections must include confidence value. AP will be calculated +'PER_SAMPLE_RESULTS': True # Generate per sample results and produce data for visualization + +''' +import argparse +import rrc_evaluation_funcs + +def argparser(): + parse = argparse.ArgumentParser() + parse.add_argument('-g', dest='g', default='./gt.zip', help="Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task.") + parse.add_argument('-s', dest='s', default='./submit.zip', help="Path of your method's results file.") + parse.add_argument('-o', dest='o', help="Path to a directory where to copy the file 'results.zip' that containts per-sample results.") + parse.add_argument('-p', dest='p', help="JSON string parameters to override the script default parameters. The parameters that can be overrided are inside the function 'default_evaluation_params' located at the begining of the evaluation Script. 
use: -p '{\"CRLF\":true}'")
+    parse.add_argument('-c', dest='choice', default='IoU', help="choose algorithm for different tasks. (Challenges 1 and 2 use 'DetEva'; Challenge 4 uses 'IoU'; default 'IoU')")
+    args = parse.parse_args()
+    return args
+
+if __name__=='__main__':
+    args = argparser()
+    if args.choice=='DetEva':
+        from Algorithm_DetEva import default_evaluation_params,validate_data,evaluate_method
+    elif args.choice=='IoU':
+        from Algorithm_IoU import default_evaluation_params,validate_data,evaluate_method
+
+
+    rrc_evaluation_funcs.main_evaluation(args, default_evaluation_params, validate_data,evaluate_method)
diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md
index 56a6051ca2..4f664c1317 100644
--- a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md
+++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/README.md
@@ -1 +1,168 @@
-1
\ No newline at end of file
+# PSENet Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+  - [Input and Output Data](#section540883920406)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Getting the Source Code](#section4622531142816)
+  - [Preparing the Dataset](#section183221994411)
+  - [Model Inference](#section741711594517)
+
+- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+  ******
+
+
+
+
+
+# Overview
+
+PSENet (Progressive Scale Expansion Network) is a text detector that detects text of arbitrary shape in natural scenes. The approach avoids the inaccurate detection of curved text produced by bounding-box regression methods, and also avoids the poor separation of closely adjacent text produced by segmentation-based methods. Inspired by FPN, the network adopts a U-shaped architecture: it first fuses the features extracted by the backbone, then classifies pixels by segmentation on the fused features, and finally obtains the text detection results from the pixel classification through post-processing.
+
+
+- Reference implementation:
+
+  ```
+  model_name=built-in/cv/PSENet_for_Pytorch
+  ```
+
+
+
+## Input and Output Data
+
+- Input data
+
+  | Input Data | Data Type | Size | Layout |
+  | -------- | -------- | ------------------------- | ------------ |
+  | input    | RGB_FP32 | batchsize x 3 x 704 x 1216 | NCHW |
+
+
+- Output data
+
+  | Output Data | Data Type | Size | Layout |
+  | -------- | -------- | -------- | ------------ |
+  | output1  | FLOAT32  | batchsize x 7 x 704 x 1216 | NCHW |
+
+
+
+
+# Inference Environment Setup
+
+- The model requires the following plugins and drivers
+
+  **Table 1** Version compatibility
+
+  | Component | Version | Setup Guide |
+  |---------| ------- | ------------------------------------------------------------ |
+  | Firmware & Driver | 23.0.rc1 | [PyTorch framework inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 7.0.RC1.alpha003 | - |
+  | Python | 3.9.11 | - |
+  | PyTorch | 2.0.1 | - |
+  | Torch_AIE | 6.3.rc2 | - |
+
+
+
+# Quick Start
+
+## Getting the Source Code
+
+
+1. Install the dependencies.
+
+   ```
+   pip3 install -r requirements.txt
+   ```
+
+## Preparing the Dataset
+
+1. Obtain the original dataset. (For extraction, refer to tar -xvf \*.tar and unzip \*.zip.)
+
+
+
+   This model supports the ICDAR2015 dataset. Users need to obtain the dataset themselves; its directory structure is as follows:
+
+   ```
+   ICDAR2015
+   ├── gt.zip    // validation set annotations
+   └── val2015   // validation set images
+   ```
+
+2. Preprocess the data to convert the original dataset into the model input format.
+
+   Run the `preprocess_psenet_pytorch.py` script to complete the preprocessing.
+
+   ```
+   python3 preprocess_psenet_pytorch.py ./ICDAR2015/val2015 ./prep_bin
+   ```
+
+
+
+
+## Model Inference
+
+1. Model conversion.
+
+   Use PyTorch to convert the .pth weight file into a .ts file.
+
+   1. Obtain the weight file.
+
+      ```shell
+      wget https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/model/1_PyTorch_PTH/PSENet/PTH/PSENet_for_PyTorch_1.2.pth
+      ```
+
+   2. Export the ts model.
+
+      1. Use `export.py` to export the ts file.
+
+         ```
+         python3 export.py --model_path="./xx.pth"
+         ```
+
+         This produces the `psenet.ts` file.
+
+         - Parameters
+           - model_path: path to the pth file
+
+
+   3. Accuracy test
+
+      1. Use `run.py` to run model inference on the dataset.
+
+         ```
+         python3 run.py --img_path ./ICDAR2015/val2015 --ts_model ./psenet.ts --bin_path ./prep_bin
+         ```
+
+         - Parameters
+           - img_path: path to the dataset images
+           - ts_model: path to the model file
+           - bin_path: path to the binary files produced by image preprocessing
+
+   4. Performance test
+
+      1. Use `perf.py` to run the PSENet performance test.
+
+         ```
+         python3 perf.py --mode ts --ts_path ./psenet.ts --batch_size 1 --opt_level 1
+         ```
+
+         - Parameters
+           - mode: use the ts model for inference
+           - ts_path: path to the ts model file
+           - batch_size: batch size
+           - opt_level: model optimization level
+
+
+
+# Inference Performance & Accuracy (not yet updated)
+
+Inference is performed through the ACL interface; refer to the data below for performance.
+
+| SoC | Batch Size | Dataset | Accuracy | Performance |
+| -------- | ---------- | ------ | ---- | ---- |
+| 310P3 | 1 | ICDAR2015 | acc:0.805
recall:0.640 | 33 | diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export.py new file mode 100644 index 0000000000..b3ab48c96b --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export.py @@ -0,0 +1,76 @@ +import argparse +from collections import OrderedDict + +import torch + +from fpn_resnet_nearest import resnet50 + + +EXPORT_OM = False + + +def proc_nodes_module(checkpoint, AttrName): + new_state_dict = OrderedDict() + for k, v in checkpoint[AttrName].items(): + if (k[0:7] == "module."): + name = k[7:] + else: + name = k[0:] + new_state_dict[name] = v + return new_state_dict + + +def convert(model_path): + """ + 1. load model + 2. trace model + + Args: + model_path (str): path to the model file (pth) + """ + checkpoint = torch.load(model_path, map_location='cpu') + checkpoint['state_dict'] = proc_nodes_module(checkpoint, 'state_dict') + model = resnet50() + model.load_state_dict(checkpoint['state_dict']) + model.eval() + + input_data = torch.ones(1, 3, 704, 1216) + ts_model = torch.jit.trace(model, input_data) + ts_model.save("./psenet.ts") + + """ + Compile the model and export om + """ + if EXPORT_OM: + import torch_aie + from torch_aie import _enums + torch_aie.set_device(0) + print("start compile") + torchaie_model = torch_aie.compile( + ts_model, + inputs=[torch_aie.Input(input_data.shape)], + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version='Ascend310P3', + optimization_level=0 + ) + print("end compile") + torchaie_model.eval() + + torch_aie.export_engine(torchaie_model, + "forward", + "psenet_torchaie.om", + inputs=input) + + +def parse_args(): + parser = argparse.ArgumentParser(description='TorchAIE PSENet') + parser.add_argument('--model_path', type=str, required=True, help='Path to the model file (pth)') # "/onnx/psenet/PSENet_for_PyTorch_1.2.pth" + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + model_path = args.model_path + convert(model_path) + \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export_onnx.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export_onnx.py new file mode 100644 index 0000000000..bd0a62d975 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/export_onnx.py @@ -0,0 +1,54 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
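+# export_onnx.py: loads the PSENet_for_PyTorch_1.2.pth checkpoint into the fpn_resnet_nearest ResNet50 model and exports a 1 x 3 x 704 x 1216 ONNX model (opset 11) with a dynamic batch dimension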
+ +from collections import OrderedDict + +import torch +import torch.onnx +import torch._utils +import onnx + +from fpn_resnet_nearest import resnet50 + + +def proc_nodes_module(checkpoint, AttrName): + new_state_dict = OrderedDict() + for k, v in checkpoint[AttrName].items(): + if (k[0:7] == "module."): + name = k[7:] + else: + name = k[0:] + + new_state_dict[name] = v + return new_state_dict + + +def convert(): + checkpoint = torch.load("./PSENet_for_PyTorch_1.2.pth", map_location='cpu') + checkpoint['state_dict'] = proc_nodes_module(checkpoint, 'state_dict') + # model = mobilenet.mobilenet_v2(pretrained = False) + model = resnet50() + model.load_state_dict(checkpoint['state_dict']) + model.eval() + + input_names = ["actual_input_1"] + output_names = ["output1"] + dummy_input = torch.randn(1, 3, 704, 1216) + dynamic_axes = {'actual_input_1':{0:'-1'},'output1':{0:'-1'}} + print('\nStarting ONNX export with onnx %s...' % onnx.__version__) + torch.onnx.export(model, dummy_input, "PSENet_704_1216_nearest.onnx", input_names=input_names, output_names=output_names,dynamic_axes = dynamic_axes, opset_version=11) + + +if __name__ == "__main__": + convert() diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/fpn_resnet_nearest.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/fpn_resnet_nearest.py new file mode 100644 index 0000000000..fe1f758b19 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/fpn_resnet_nearest.py @@ -0,0 +1,540 @@ +# Apache License +# Version 2.0, January 2004 +# http://www.apache.org/licenses/ +# +# TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +# +# 1. Definitions. +# +# "License" shall mean the terms and conditions for use, reproduction, +# and distribution as defined by Sections 1 through 9 of this document. +# +# "Licensor" shall mean the copyright owner or entity authorized by +# the copyright owner that is granting the License. +# +# "Legal Entity" shall mean the union of the acting entity and all +# other entities that control, are controlled by, or are under common +# control with that entity. For the purposes of this definition, +# "control" means (i) the power, direct or indirect, to cause the +# direction or management of such entity, whether by contract or +# otherwise, or (ii) ownership of fifty percent (50%) or more of the +# outstanding shares, or (iii) beneficial ownership of such entity. +# +# "You" (or "Your") shall mean an individual or Legal Entity +# exercising permissions granted by this License. +# +# "Source" form shall mean the preferred form for making modifications, +# including but not limited to software source code, documentation +# source, and configuration files. +# +# "Object" form shall mean any form resulting from mechanical +# transformation or translation of a Source form, including but +# not limited to compiled object code, generated documentation, +# and conversions to other media types. +# +# "Work" shall mean the work of authorship, whether in Source or +# Object form, made available under the License, as indicated by a +# copyright notice that is included in or attached to the work +# (an example is provided in the Appendix below). +# +# "Derivative Works" shall mean any work, whether in Source or Object +# form, that is based on (or derived from) the Work and for which the +# editorial revisions, annotations, elaborations, or other modifications +# represent, as a whole, an original work of authorship. 
For the purposes +# of this License, Derivative Works shall not include works that remain +# separable from, or merely link (or bind by name) to the interfaces of, +# the Work and Derivative Works thereof. +# +# "Contribution" shall mean any work of authorship, including +# the original version of the Work and any modifications or additions +# to that Work or Derivative Works thereof, that is intentionally +# submitted to Licensor for inclusion in the Work by the copyright owner +# or by an individual or Legal Entity authorized to submit on behalf of +# the copyright owner. For the purposes of this definition, "submitted" +# means any form of electronic, verbal, or written communication sent +# to the Licensor or its representatives, including but not limited to +# communication on electronic mailing lists, source code control systems, +# and issue tracking systems that are managed by, or on behalf of, the +# Licensor for the purpose of discussing and improving the Work, but +# excluding communication that is conspicuously marked or otherwise +# designated in writing by the copyright owner as "Not a Contribution." +# +# "Contributor" shall mean Licensor and any individual or Legal Entity +# on behalf of whom a Contribution has been received by Licensor and +# subsequently incorporated within the Work. +# +# 2. Grant of Copyright License. Subject to the terms and conditions of +# this License, each Contributor hereby grants to You a perpetual, +# worldwide, non-exclusive, no-charge, royalty-free, irrevocable +# copyright license to reproduce, prepare Derivative Works of, +# publicly display, publicly perform, sublicense, and distribute the +# Work and such Derivative Works in Source or Object form. +# +# 3. Grant of Patent License. Subject to the terms and conditions of +# this License, each Contributor hereby grants to You a perpetual, +# worldwide, non-exclusive, no-charge, royalty-free, irrevocable +# (except as stated in this section) patent license to make, have made, +# use, offer to sell, sell, import, and otherwise transfer the Work, +# where such license applies only to those patent claims licensable +# by such Contributor that are necessarily infringed by their +# Contribution(s) alone or by combination of their Contribution(s) +# with the Work to which such Contribution(s) was submitted. If You +# institute patent litigation against any entity (including a +# cross-claim or counterclaim in a lawsuit) alleging that the Work +# or a Contribution incorporated within the Work constitutes direct +# or contributory patent infringement, then any patent licenses +# granted to You under this License for that Work shall terminate +# as of the date such litigation is filed. +# +# 4. Redistribution. 
You may reproduce and distribute copies of the +# Work or Derivative Works thereof in any medium, with or without +# modifications, and in Source or Object form, provided that You +# meet the following conditions: +# +# (a) You must give any other recipients of the Work or +# Derivative Works a copy of this License; and +# +# (b) You must cause any modified files to carry prominent notices +# stating that You changed the files; and +# +# (c) You must retain, in the Source form of any Derivative Works +# that You distribute, all copyright, patent, trademark, and +# attribution notices from the Source form of the Work, +# excluding those notices that do not pertain to any part of +# the Derivative Works; and +# +# (d) If the Work includes a "NOTICE" text file as part of its +# distribution, then any Derivative Works that You distribute must +# include a readable copy of the attribution notices contained +# within such NOTICE file, excluding those notices that do not +# pertain to any part of the Derivative Works, in at least one +# of the following places: within a NOTICE text file distributed +# as part of the Derivative Works; within the Source form or +# documentation, if provided along with the Derivative Works; or, +# within a display generated by the Derivative Works, if and +# wherever such third-party notices normally appear. The contents +# of the NOTICE file are for informational purposes only and +# do not modify the License. You may add Your own attribution +# notices within Derivative Works that You distribute, alongside +# or as an addendum to the NOTICE text from the Work, provided +# that such additional attribution notices cannot be construed +# as modifying the License. +# +# You may add Your own copyright statement to Your modifications and +# may provide additional or different license terms and conditions +# for use, reproduction, or distribution of Your modifications, or +# for any such Derivative Works as a whole, provided Your use, +# reproduction, and distribution of the Work otherwise complies with +# the conditions stated in this License. +# +# 5. Submission of Contributions. Unless You explicitly state otherwise, +# any Contribution intentionally submitted for inclusion in the Work +# by You to the Licensor shall be under the terms and conditions of +# this License, without any additional terms or conditions. +# Notwithstanding the above, nothing herein shall supersede or modify +# the terms of any separate license agreement you may have executed +# with Licensor regarding such Contributions. +# +# 6. Trademarks. This License does not grant permission to use the trade +# names, trademarks, service marks, or product names of the Licensor, +# except as required for reasonable and customary use in describing the +# origin of the Work and reproducing the content of the NOTICE file. +# +# 7. Disclaimer of Warranty. Unless required by applicable law or +# agreed to in writing, Licensor provides the Work (and each +# Contributor provides its Contributions) on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied, including, without limitation, any warranties or conditions +# of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +# PARTICULAR PURPOSE. You are solely responsible for determining the +# appropriateness of using or redistributing the Work and assume any +# risks associated with Your exercise of permissions under this License. +# +# 8. Limitation of Liability. 
In no event and under no legal theory, +# whether in tort (including negligence), contract, or otherwise, +# unless required by applicable law (such as deliberate and grossly +# negligent acts) or agreed to in writing, shall any Contributor be +# liable to You for damages, including any direct, indirect, special, +# incidental, or consequential damages of any character arising as a +# result of this License or out of the use or inability to use the +# Work (including but not limited to damages for loss of goodwill, +# work stoppage, computer failure or malfunction, or any and all +# other commercial damages or losses), even if such Contributor +# has been advised of the possibility of such damages. +# +# 9. Accepting Warranty or Additional Liability. While redistributing +# the Work or Derivative Works thereof, You may choose to offer, +# and charge a fee for, acceptance of support, warranty, indemnity, +# or other liability obligations and/or rights consistent with this +# License. However, in accepting such obligations, You may act only +# on Your own behalf and on Your sole responsibility, not on behalf +# of any other Contributor, and only if You agree to indemnify, +# defend, and hold each Contributor harmless for any liability +# incurred by, or claims asserted against, such Contributor by reason +# of your accepting any such warranty or additional liability. +# +# END OF TERMS AND CONDITIONS +# +# APPENDIX: How to apply the Apache License to your work. +# +# To apply the Apache License to your work, attach the following +# boilerplate notice, with the fields enclosed by brackets "[]" +# replaced with your own identifying information. (Don't include +# the brackets!) The text should be enclosed in the appropriate +# comment syntax for the file format. We also recommend that a +# file or class name and description of purpose be included on the +# same "printed page" as the copyright notice for easier +# identification within third-party archives. +# +# Copyright [yyyy] [name of copyright owner] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://spdx.org/licenses/BSD-3-Clause.html +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
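+
+# fpn_resnet_nearest.py: PSENet backbone. A standard ResNet (18/34/50/101/152) is
+# combined with an FPN-style top-down path; lateral 1x1 convs reduce C2-C5 to 256
+# channels, upsampling uses nearest-neighbour interpolation, and the fused P2-P5
+# features are concatenated and projected to num_classes=7 kernel maps.
+# Pretrained weight URLs are read from url.ini via ConfigParser.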
+import math +import time +import os +from configparser import ConfigParser +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] +config = ConfigParser() +config.read(filenames='/onnx/psenet/url.ini',encoding = 'UTF-8') +resnet18_url = config.get(section="DEFAULT", option="resnet18") +resnet34_url = config.get(section="DEFAULT", option="resnet34") +resnet50_url = config.get(section="DEFAULT", option="resnet50") +resnet101_url = config.get(section="DEFAULT", option="resnet101") +resnet152_url = config.get(section="DEFAULT", option="resnet152") +model_urls = { + 'resnet18': str(resnet18_url), + 'resnet34': str(resnet34_url), + 'resnet50': str(resnet50_url), + 'resnet101': str(resnet101_url), + 'resnet152': str(resnet152_url), +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=7, scale=1): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=False) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + # Top layer + self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) # Reduce channels + self.toplayer_bn = nn.BatchNorm2d(256) + self.toplayer_relu = nn.ReLU(inplace=True) + + # 
Smooth layers + self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.smooth1_bn = nn.BatchNorm2d(256) + self.smooth1_relu = nn.ReLU(inplace=True) + + self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.smooth2_bn = nn.BatchNorm2d(256) + self.smooth2_relu = nn.ReLU(inplace=True) + + self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.smooth3_bn = nn.BatchNorm2d(256) + self.smooth3_relu = nn.ReLU(inplace=True) + + # Lateral layers + self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.latlayer1_bn = nn.BatchNorm2d(256) + self.latlayer1_relu = nn.ReLU(inplace=True) + + self.latlayer2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0) + self.latlayer2_bn = nn.BatchNorm2d(256) + self.latlayer2_relu = nn.ReLU(inplace=True) + + self.latlayer3 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) + self.latlayer3_bn = nn.BatchNorm2d(256) + self.latlayer3_relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(256) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.scale = scale + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _upsample(self, x, y, scale=1): + _, _, H, W = y.size() + return F.interpolate(x, size=(H // scale, W // scale), mode='nearest') + + def _upsample_add(self, x, y): + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='nearest') + y + + def forward(self, x): + h = x + h = self.conv1(h) + h = self.bn1(h) + h = self.relu1(h) + h = self.maxpool(h) + + + h = self.layer1(h) + c2 = h + h = self.layer2(h) + c3 = h + h = self.layer3(h) + c4 = h + h = self.layer4(h) + c5 = h + + # Top-down + p5 = self.toplayer(c5) + p5 = self.toplayer_relu(self.toplayer_bn(p5)) + + c4 = self.latlayer1(c4) + c4 = self.latlayer1_relu(self.latlayer1_bn(c4)) + t = time.time() + p4 = self._upsample_add(p5, c4) + p4 = self.smooth1(p4) + p4 = self.smooth1_relu(self.smooth1_bn(p4)) + + c3 = self.latlayer2(c3) + c3 = self.latlayer2_relu(self.latlayer2_bn(c3)) + t = time.time() + p3 = self._upsample_add(p4, c3) + p3 = self.smooth2(p3) + p3 = self.smooth2_relu(self.smooth2_bn(p3)) + + c2 = self.latlayer3(c2) + c2 = self.latlayer3_relu(self.latlayer3_bn(c2)) + p2 = self._upsample_add(p3, c2) + p2 = self.smooth3(p2) + p2 = self.smooth3_relu(self.smooth3_bn(p2)) + + p3 = self._upsample(p3, p2) + p4 = self._upsample(p4, p2) + p5 = self._upsample(p5, p2) + + + out = torch.cat((p2, p3, p4, p5), 1) + + + out = self.conv2(out) + + + out = self.relu2(self.bn2(out)) + + out = self.conv3(out) + + out = 
self._upsample(out, x, scale=self.scale) + + return out + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + pretrained_model = model_zoo.load_url(model_urls['resnet50']) + state = model.state_dict() + for key in state.keys(): + if key in pretrained_model.keys(): + state[key] = pretrained_model[key] + model.load_state_dict(state) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + pretrained_model = model_zoo.load_url(model_urls['resnet101']) + state = model.state_dict() + for key in state.keys(): + if key in pretrained_model.keys(): + state[key] = pretrained_model[key] + model.load_state_dict(state) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + pretrained_model = model_zoo.load_url(model_urls['resnet152']) + state = model.state_dict() + for key in state.keys(): + if key in pretrained_model.keys(): + state[key] = pretrained_model[key] + model.load_state_dict(state) + return model diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/perf.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/perf.py new file mode 100644 index 0000000000..3aa8ba6621 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/perf.py @@ -0,0 +1,114 @@ +import argparse +import time +from tqdm import tqdm +from collections import OrderedDict + +import torch +import numpy as np + +from ais_bench.infer.interface import InferSession +from fpn_resnet_nearest import resnet50 +import torch_aie +from torch_aie import _enums + + +INPUT_WIDTH = 704 +INPUT_HEIGHT = 1216 + + +def proc_nodes_module(checkpoint, AttrName): + new_state_dict = OrderedDict() + for k, v in checkpoint[AttrName].items(): + if (k[0:7] == "module."): + name = k[7:] + else: + name = k[0:] + + new_state_dict[name] = v + return new_state_dict + + +def parse_args(): + args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.") + args.add_argument("--mode", choices=["om", "ts"], required=True, help="Specify the mode ('om' or 'ts').") + args.add_argument('--om_path',help='PSENet om file path', type=str, + default='./psenet.om' + ) + args.add_argument('--ts_path',help='PSENet ts file path', type=str, + default='./psenet.ts' + ) + args.add_argument("--batch_size", type=int, default=4, help="batch size.") + args.add_argument("--opt_level", type=int, default=0, help="opt level.") + return args.parse_args() + + 
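+
+# Benchmark entry point. In "om" mode the exported OM file is run through
+# ais_bench InferSession with random inputs; in "ts" mode the TorchScript module is
+# compiled with torch_aie (FP16, Ascend310P3) and timed for 100 iterations after a
+# 10-iteration warm-up, then throughput (samples/s) is printed.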
+if __name__ == '__main__': + infer_times = 100 + om_cost = 0 + pt_cost = 0 + opts = parse_args() + OM_PATH = opts.om_path + TS_PATH = opts.ts_path + BATCH_SIZE = opts.batch_size + OPTS_LEVEL = opts.opt_level + + if opts.mode == "om": + om_model = InferSession(0, OM_PATH) + for _ in tqdm(range(0, infer_times)): + dummy_input = np.random.randn(1, 3, INPUT_WIDTH, INPUT_HEIGHT).astype(np.uint8) + start = time.time() + output = om_model.infer([dummy_input], 'static', custom_sizes=90000000) # revise static + cost = time.time() - start + om_cost += cost + + if opts.mode == "ts": + torch_aie.set_device(0) + ts_model = torch.jit.load(TS_PATH) + + # revise static + input_info = [torch_aie.Input((BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT))] + + torch_aie.set_device(0) + print("start compile") + torchaie_model = torch_aie.compile( + ts_model, + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version='Ascend310P3', + optimization_level=OPTS_LEVEL, + ) + print("end compile") + torchaie_model.eval() + + dummy_input = np.zeros((BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT)) + dummy_input = torch.Tensor(dummy_input) + input_tensor = dummy_input.to("npu:0") + + loops = 100 + warm_ctr = 10 + + default_stream = torch_aie.npu.default_stream() + time_cost = 0 + + while warm_ctr: + _ = torchaie_model(input_tensor) + default_stream.synchronize() + warm_ctr -= 1 + + for i in range(loops): + t0 = time.time() + _ = torchaie_model(input_tensor) + default_stream.synchronize() + t1 = time.time() + time_cost += (t1 - t0) + print(i) + + print(f"fps: {loops} * {BATCH_SIZE} / {time_cost : .3f} samples/s") + print("torch_aie fps: ", loops * BATCH_SIZE / time_cost) + + from datetime import datetime + current_time = datetime.now() + formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S") + print("Current Time:", formatted_time) + + torch_aie.finalize() diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/preprocess_psenet_pytorch.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/preprocess_psenet_pytorch.py new file mode 100644 index 0000000000..4e6c22c08e --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/preprocess_psenet_pytorch.py @@ -0,0 +1,64 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +import numpy as np +import cv2 + + +def scale(img, long_size=2240): + h, w = img.shape[0:2] + scale = long_size * 1.0 / max(h, w) + img = cv2.resize(img, dsize=None, fx=scale, fy=scale) + # img = cv2.resize(img, (1260, 2240)) + print(img.shape) + return img + + +def psenet_onnx(file_path, bin_path): + if not os.path.exists(bin_path): + os.makedirs(bin_path) + i = 0 + in_files = os.listdir(file_path) + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + for file in in_files: + i = i + 1 + print(file, "====", i) + img = cv2.imread(os.path.join(file_path, file)) + img = img[:, :, [2, 1, 0]] # bgr -> rgb + # img = scale(img) + img = cv2.resize(img, (1216, 704)) + + img = np.array(img, dtype=np.float32) + img = img / 255. 
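+        # normalize each RGB channel with the ImageNet mean/std defined above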
+ + # 均值方差 + img[..., 0] -= mean[0] + img[..., 1] -= mean[1] + img[..., 2] -= mean[2] + img[..., 0] /= std[0] + img[..., 1] /= std[1] + img[..., 2] /= std[2] + + img = img.transpose(2, 0, 1) # HWC -> CHW + img.tofile(os.path.join(bin_path, file.split('.')[0] + '.bin')) + + +if __name__ == "__main__": + file_path = os.path.abspath(sys.argv[1]) + bin_path = os.path.abspath(sys.argv[2]) + psenet_onnx(file_path, bin_path) diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/pypse.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/pypse.py new file mode 100644 index 0000000000..d53dfd8f7f --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/pypse.py @@ -0,0 +1,63 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cv2 +import queue as Queue + +def pse(kernals, min_area): + kernal_num = len(kernals) + pred = np.zeros(kernals[0].shape, dtype='int32') + + label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4) + + for label_idx in range(1, label_num): + if np.sum(label == label_idx) < min_area: + label[label == label_idx] = 0 + + queue = Queue.Queue(maxsize = 0) + next_queue = Queue.Queue(maxsize = 0) + points = np.array(np.where(label > 0)).transpose((1, 0)) + + for point_idx in range(points.shape[0]): + x, y = points[point_idx, 0], points[point_idx, 1] + l = label[x, y] + queue.put((x, y, l)) + pred[x, y] = l + + dx = [-1, 1, 0, 0] + dy = [0, 0, -1, 1] + for kernal_idx in range(kernal_num - 2, -1, -1): + kernal = kernals[kernal_idx].copy() + while not queue.empty(): + (x, y, l) = queue.get() + + is_edge = True + for j in range(4): + tmpx = x + dx[j] + tmpy = y + dy[j] + if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]: + continue + if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: + continue + + queue.put((tmpx, tmpy, l)) + pred[tmpx, tmpy] = l + is_edge = False + if is_edge: + next_queue.put((x, y, l)) + + queue, next_queue = next_queue, queue + + return pred \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/requirements.txt b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/requirements.txt new file mode 100644 index 0000000000..be4017c1f7 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/requirements.txt @@ -0,0 +1,192 @@ +absl-py==1.4.0 +addict==2.4.0 +attrs==23.1.0 +backcall==0.2.0 +bce-python-sdk==0.8.95 +beautifulsoup4==4.12.2 +black==21.4b2 +bleach==6.1.0 +blinker==1.7.0 +cachetools==5.3.1 +certifi==2023.5.7 +cffi==1.15.1 +chainer==7.8.1 +charset-normalizer==3.1.0 +click==8.1.7 +click-aliases==1.0.1 +cloudpickle==2.2.1 +cmake==3.26.4 +coloredlogs==15.0.1 +comm==0.1.4 +contourpy==1.2.0 +custom-passes-reduce==0.0.0 +cycler==0.12.1 +datasets==2.3.0 +debugpy==1.8.0 +decorator==5.1.1 +defusedxml==0.7.1 +dill==0.3.7 +distlib==0.3.7 +easydict==1.7 +exceptiongroup==1.1.3 +executing==1.2.0 +fastjsonschema==2.18.1 +filelock==3.12.2 +fqdn==1.5.1 +frozenlist==1.4.0 +fsspec==2023.9.0 +future==0.18.3 +gast==0.4.0 +gitdb==4.0.10 
+hydra-core==1.3.2 +idna==3.4 +imageio==2.31.5 +imgaug==0.4.0 +importlib-metadata==6.8.0 +importlib-resources==6.1.1 +iopath==0.1.8 +ipykernel==6.25.2 +ipython==8.14.0 +isoduration==20.11.0 +itsdangerous==2.1.2 +jedi==0.18.2 +Jinja2==3.1.2 +joblib==1.3.2 +json-tricks==3.17.3 +json5==0.9.14 +jsonpointer==2.4 +jsonschema==4.19.1 +jsonschema-specifications==2023.7.1 +kiwisolver==1.4.5 +lazy_loader==0.3 +libclang==16.0.6 +lit==16.0.5.post0 +lmdb==1.4.1 +loguru==0.7.2 +Markdown==3.4.4 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.1 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mistune==3.0.2 +mpmath==1.3.0 +multidict==6.0.4 +multiprocess==0.70.15 +mypy-extensions==1.0.0 +nbclient==0.8.0 +nbconvert==7.9.2 +nbformat==5.9.2 +nest-asyncio==1.5.8 +netron==7.3.0 +networkx==3.1 +ninja==1.11.1.1 +notebook_shim==0.2.3 +numpy==1.26.1 +oauthlib==3.2.2 +omegaconf==2.3.0 +opencv-python==4.8.1.78 +opt-einsum==3.3.0 +overrides==7.4.0 +packaging==23.2 +pandas==2.0.2 +pandocfilters==1.5.0 +parso==0.8.3 +pathlib2==2.3.7.post1 +pathspec==0.11.2 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==10.1.0 +platformdirs==3.11.0 +Polygon3==3.0.9.1 +portalocker==2.8.2 +prettytable==3.9.0 +prometheus-client==0.17.1 +prompt-toolkit==3.0.39 +protobuf==3.20.2 +psutil==5.9.5 +ptyprocess==0.7.0 +pure-eval==0.2.2 +py-cpuinfo==9.0.0 +py3nvml==0.2.7 +pyarrow==13.0.0 +pyascendie==0.0.0 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pyclipper==1.3.0.post5 +pycocotools==2.0.7 +pycparser==2.21 +pycryptodome==3.19.0 +pydot==1.4.2 +Pygments==2.16.1 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-json-logger==2.0.7 +pytz==2023.3 +PyYAML==6.0.1 +pyzmq==25.1.1 +rarfile==4.1 +referencing==0.30.2 +regex==2023.10.3 +requests==2.31.0 +requests-oauthlib==1.3.1 +responses==0.18.0 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.6.0 +rpds-py==0.10.6 +rsa==4.9 +safetensors==0.3.3 +scikit-image==0.22.0 +scikit-learn==1.3.1 +scipy==1.10.1 +seaborn==0.12.2 +Send2Trash==1.8.2 +sentry-sdk==1.31.0 +Shapely==1.6.4 +six==1.16.0 +skl2onnx==1.15.0 +sklearn==0.0 +smmap==5.0.0 +sniffio==1.3.0 +soupsieve==2.5 +stack-data==0.6.2 +sympy==1.12 +synr==0.5.0 +tabulate==0.9.0 +threadpoolctl==3.2.0 +tifffile==2023.8.30 +timm==0.6.13 +tinycss2==1.2.1 +tokenizers==0.14.1 +toml==0.10.2 +tomli==2.0.1 +torch==2.0.1+cpu +torchaudio==2.0.2+cpu +torchsummary==1.5.1 +torchvision==0.15.2+cpu +tornado==6.3.2 +tqdm==4.64.0 +traitlets==5.9.0 +transformers==4.35.2 +triton==2.0.0 +typeguard==2.13.3 +typing_extensions==4.5.0 +uri-template==1.3.0 +urllib3==1.26.16 +virtualenv==20.24.5 +visualdl==2.5.3 +wcwidth==0.2.9 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.6.4 +websockets==12.0 +wenet==1.0.4 +Werkzeug==3.0.1 +wikiextractor==3.0.6 +wrapt==1.15.0 +xmltodict==0.13.0 +xxhash==3.4.1 +yacs==0.1.8 +yapf==0.40.2 +yarl==1.9.2 \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/run.py b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/run.py new file mode 100644 index 0000000000..cee98b794c --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/run.py @@ -0,0 +1,145 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse + +import numpy as np +import torch +import cv2 + +from pypse import pse as pypse +import torch_aie +from torch_aie import _enums + + +def parse_args(): + args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.") + args.add_argument('--img_path',help='dataset img path', type=str, + default='/home/ascend/ICDAR2015/val2015' + ) + args.add_argument('--ts_model',help='ts model path', type=str, + default='/onnx/psenet/psenet.ts' + ) + args.add_argument('--txt_path',help='output txt path', type=str, + default='./txt' + ) + args.add_argument('--bin_path',help='data bin path', type=str, + default='./prep_bin' + ) + return args.parse_args() + + +def get_images(img_path): + ''' + find image files in test data path + :return: list of files found + ''' + files = [] + exts = ['jpg', 'png', 'jpeg', 'JPG'] + + for parent, _, filenames in os.walk(img_path): + for filename in filenames: + for ext in exts: + if filename.endswith(ext): + files.append(os.path.join(parent, filename)) + break + files.sort(key=lambda x: int(x.split('/')[-1].split('.')[0].split('_')[1])) + return files + + +def main(): + opts = parse_args() + img_path = opts.img_path + ts_model_path = opts.ts_model + txt_path = opts.txt_path + bin_path = opts.bin_path + + if not os.path.exists(txt_path): + os.makedirs(txt_path) + + kernel_num=7 + min_kernel_area=5.0 + scale=1 + min_score = 0.9 + min_area = 600 + + im_fn_list = get_images(img_path) + ts_model = torch.jit.load(ts_model_path) + input_info = [torch_aie.Input((1, 3, 704, 1216))] + torch_aie.set_device(0) + print("Start compiling...") + torchaie_model = torch_aie.compile( + ts_model, + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version='Ascend310P3' + ) + print("Compile finished!") + torchaie_model.eval() + + for im_fn in im_fn_list: # '/home/ascend/ICDAR2015/val2015/img_307.jpg' + im = cv2.imread(im_fn) + idx = os.path.basename(im_fn).split('/')[-1].split('.')[0].split('_')[1] + print(idx) # 307 + img_bin_path = os.path.join(bin_path, 'img_{}.bin'.format(idx)) + input_np_arr = np.fromfile(img_bin_path, dtype=np.float32).reshape((1, 3, 704, 1216)) + input_tensor = torch.tensor(input_np_arr, dtype=torch.float32) + input_tensor = input_tensor.to('npu:0') + seg_maps = torchaie_model(input_tensor) + seg_maps = seg_maps.to('cpu') + # print("seg_maps shape: ", seg_maps.shape) + + score = torch.sigmoid(seg_maps[:, 0, :, :]) + outputs = (torch.sign(seg_maps - 1.0) + 1) / 2 + + text = outputs[:, 0, :, :] + kernels = outputs[:, 0:kernel_num, :, :] * text + + score = score.data.numpy()[0].astype(np.float32) + text = text.data.numpy()[0].astype(np.uint8) + kernels = kernels.numpy()[0].astype(np.uint8) + + # python version pse + pred = pypse(kernels, min_kernel_area / (scale * scale)) + + img_scale = (im.shape[1] * 1.0 / pred.shape[1], im.shape[0] * 1.0 / pred.shape[0]) + label = pred + label_num = np.max(label) + 1 + bboxes = [] + + for i in range(1, label_num): + points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1] + + if points.shape[0] < min_area: + continue + + score_i = 
np.mean(score[label == i]) + if score_i < min_score: + continue + + rect = cv2.minAreaRect(points) + bbox = cv2.boxPoints(rect) * img_scale + bbox = bbox.astype('int32') + bboxes.append(bbox.reshape(-1)) + + # save txt + res_file = os.path.join(txt_path,'{}.txt'.format(os.path.splitext(os.path.basename(im_fn))[0])) + print("im_fn:", im_fn) + with open(res_file, 'w') as f: + for _, bbox in enumerate(bboxes): + values = [int(v) for v in bbox] + line = "%d, %d, %d, %d, %d, %d, %d, %d\n" % tuple(values) + print(" line:", line) + f.write(line) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/url.ini b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/url.ini new file mode 100644 index 0000000000..c30e61fb17 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/ocr/PSENet/url.ini @@ -0,0 +1,6 @@ +[DEFAULT] +resnet18=https://download.pytorch.org/models/resnet18-5c106cde.pth +resnet34=https://download.pytorch.org/models/resnet34-333f7ec4.pth +resnet50=https://download.pytorch.org/models/resnet50-19c8e357.pth +resnet101=https://download.pytorch.org/models/resnet101-5d3b4d8f.pth +resnet152=https://download.pytorch.org/models/resnet152-b121ed2d.pth \ No newline at end of file -- Gitee