diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c031d1413a5ac52bbfeb5c5de663ff5c1fe83884 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch @@ -0,0 +1,881 @@ +diff --git a/demo/demo.py b/demo/demo.py +index 36433c45..6f28620f 100644 +--- a/demo/demo.py ++++ b/demo/demo.py +@@ -86,7 +86,7 @@ def do_parse( + image_dir = str(os.path.basename(local_image_dir)) + content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) + md_writer.write_string( +- f"{pdf_file_name}_content_list.json", ++ f"{pdf_file_name}_content.json", + json.dumps(content_list, ensure_ascii=False, indent=4), + ) + +@@ -142,7 +142,8 @@ def do_parse( + image_dir = str(os.path.basename(local_image_dir)) + content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) + md_writer.write_string( +- f"{pdf_file_name}_content_list.json", ++ # f"{pdf_file_name}_content_list.json", ++ f"{pdf_file_name}_content.json", ## 文件名太长了,linux文件系统ext4超过255字节无法保存 + json.dumps(content_list, ensure_ascii=False, indent=4), + ) + +diff --git a/mineru/backend/pipeline/batch_analyze.py b/mineru/backend/pipeline/batch_analyze.py +index c88a52a3..b0b79a80 100644 +--- a/mineru/backend/pipeline/batch_analyze.py ++++ b/mineru/backend/pipeline/batch_analyze.py +@@ -3,6 +3,9 @@ from loguru import logger + from tqdm import tqdm + from collections import defaultdict + import numpy as np ++import time ++import torch ++import torch_npu + + from .model_init import AtomModelSingleton + from ...utils.config_reader import get_formula_enable, get_table_enable +@@ -95,6 +98,7 @@ class BatchAnalyze: + }) + + # OCR检测处理 ++ from concurrent.futures import ThreadPoolExecutor, as_completed + if self.enable_ocr_det_batch: + # 批处理模式 - 按语言和分辨率分组 + # 收集所有需要OCR检测的裁剪图像 +@@ -139,79 +143,73 @@ class BatchAnalyze: + ) + + # 按分辨率分组并同时完成padding ++ stride = 64 + resolution_groups = defaultdict(list) + for crop_info in lang_crop_list: + cropped_img = crop_info[0] + h, w = cropped_img.shape[:2] + # 使用更大的分组容差,减少分组数量 + # 将尺寸标准化到32的倍数 +- normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数 +- normalized_w = ((w + 32) // 32) * 32 ++ normalized_h = ((h + stride) // stride) * stride # 向上取整到stride的倍数 ++ normalized_w = ((w + stride) // stride) * stride + group_key = (normalized_h, normalized_w) + resolution_groups[group_key].append(crop_info) + +- # 对每个分辨率组进行批处理 +- for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"): +- +- # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数) +- max_h = max(crop_info[0].shape[0] for crop_info in group_crops) +- max_w = max(crop_info[0].shape[1] for crop_info in group_crops) +- target_h = ((max_h + 32 - 1) // 32) * 32 +- target_w = ((max_w + 32 - 1) // 32) * 32 +- +- # 对所有图像进行padding到统一尺寸 +- batch_images = [] +- for crop_info in group_crops: +- img = crop_info[0] +- h, w = img.shape[:2] +- # 创建目标尺寸的白色背景 +- padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 +- # 将原图像粘贴到左上角 +- padded_img[:h, :w] = img +- batch_images.append(padded_img) +- +- # 批处理检测 +- det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) # 增加批处理大小 +- # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}") +- batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size) +- +- # 处理批处理结果 +- for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): +- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info +- +- if dt_boxes is not None and len(dt_boxes) > 0: +- # 直接应用原始OCR流程中的关键处理步骤 +- from mineru.utils.ocr_utils import ( +- merge_det_boxes, update_det_boxes, sorted_boxes +- ) + +- # 1. 排序检测框 +- if len(dt_boxes) > 0: +- dt_boxes_sorted = sorted_boxes(dt_boxes) +- else: +- dt_boxes_sorted = [] +- +- # 2. 合并相邻检测框 +- if dt_boxes_sorted: +- dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) +- else: +- dt_boxes_merged = [] +- +- # 3. 根据公式位置更新检测框(关键步骤!) +- if dt_boxes_merged and adjusted_mfdetrec_res: +- dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) +- else: +- dt_boxes_final = dt_boxes_merged +- +- # 构造OCR结果格式 +- ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] +- +- if ocr_res: +- ocr_result_list = get_ocr_result_list( +- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang +- ) +- +- ocr_res_list_dict['layout_res'].extend(ocr_result_list) ++ def _run_one_group_ocr(group_key, group_crops): ++ ++ max_h = max(ci[0].shape[0] for ci in group_crops) ++ max_w = max(ci[0].shape[1] for ci in group_crops) ++ target_h = ((max_h + stride - 1) // stride) * stride ++ target_w = ((max_w + stride - 1) // stride) * stride ++ ++ batch_images = [] ++ for ci in group_crops: ++ img = ci[0] ++ h, w = img.shape[:2] ++ padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 ++ padded_img[:h, :w] = img ++ batch_images.append(padded_img) ++ ++ det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) ++ ++ batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size) ++ ++ for i, (ci, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): ++ new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = ci ++ if dt_boxes is not None and len(dt_boxes) > 0: ++ from mineru.utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes ++ ++ if len(dt_boxes) > 0: ++ dt_boxes_sorted = sorted_boxes(dt_boxes) ++ else: ++ dt_boxes_sorted = [] ++ ++ if dt_boxes_sorted: ++ dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) ++ else: ++ dt_boxes_merged = [] ++ ++ if dt_boxes_merged and adjusted_mfdetrec_res: ++ dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) ++ else: ++ dt_boxes_final = dt_boxes_merged ++ ++ ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] ++ if ocr_res: ++ ocr_result_list = get_ocr_result_list( ++ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang ++ ) ++ ocr_res_list_dict['layout_res'].extend(ocr_result_list) ++ ++ MAX_WORKERS = 4 ++ start = time.time() ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: ++ futures = [ex.submit(_run_one_group_ocr, gk, gcs) for gk, gcs in resolution_groups.items()] ++ for f in as_completed(futures): ++ f.result() ++ end = time.time() ++ logger.info(f"ocr det run time : {end -start}") + else: + # 原始单张处理模式 + for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"): +@@ -247,7 +245,7 @@ class BatchAnalyze: + + # 表格识别 table recognition + if self.table_enable: +- for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"): ++ def _run_one_group_table(table_res_dict): + _lang = table_res_dict['lang'] + table_model = atom_model_manager.get_atom_model( + atom_model_name='table', +@@ -271,6 +269,16 @@ class BatchAnalyze: + 'table recognition processing fails, not get html return' + ) + ++ ++ MAX_WORKERS = 4 ++ start = time.time() ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: ++ futures = [ex.submit(_run_one_group_table, table_res_dict) for table_res_dict in table_res_list_all_page] ++ for f in as_completed(futures): ++ f.result() ++ end = time.time() ++ logger.info(f"table run time : {end - start}") ++ + # Create dictionaries to store items by language + need_ocr_lists_by_lang = {} # Dict of lists for each language + img_crop_lists_by_lang = {} # Dict of lists for each language +diff --git a/mineru/model/layout/doclayout_yolo.py b/mineru/model/layout/doclayout_yolo.py +index 5667a909..fc5056bb 100644 +--- a/mineru/model/layout/doclayout_yolo.py ++++ b/mineru/model/layout/doclayout_yolo.py +@@ -66,6 +66,7 @@ class DocLayoutYOLOModel: + conf=self.conf, + iou=self.iou, + verbose=False, ++ half=True + ) + for pred in predictions: + results.append(self._parse_prediction(pred)) +diff --git a/mineru/model/mfd/yolo_v8.py b/mineru/model/mfd/yolo_v8.py +index 33dac091..1fb4b50e 100644 +--- a/mineru/model/mfd/yolo_v8.py ++++ b/mineru/model/mfd/yolo_v8.py +@@ -31,7 +31,8 @@ class YOLOv8MFDModel: + conf=self.conf, + iou=self.iou, + verbose=False, +- device=self.device ++ device=self.device, ++ half=True + ) + return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu() + +diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py +index ae3879da..23e56f2a 100644 +--- a/mineru/model/mfr/unimernet/Unimernet.py ++++ b/mineru/model/mfr/unimernet/Unimernet.py +@@ -1,7 +1,7 @@ + import torch + from torch.utils.data import DataLoader, Dataset + from tqdm import tqdm +- ++import numpy as np + + class MathDataset(Dataset): + def __init__(self, image_paths, transform=None): +@@ -61,7 +61,7 @@ class UnimernetModel(object): + res["latex"] = latex + return formula_list + +- def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: ++ def _batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: + images_formula_list = [] + mf_image_list = [] + backfill_list = [] +@@ -137,3 +137,94 @@ class UnimernetModel(object): + res["latex"] = latex + + return images_formula_list ++ ++ ++ def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: ++ ++ images_formula_list = [] ++ mf_image_list = [] ++ backfill_list = [] ++ image_info = [] # Store (area, original_index, image) tuples ++ ++ # Collect images with their original indices ++ for image_index in range(len(images_mfd_res)): ++ mfd_res = images_mfd_res[image_index] ++ pil_img = images[image_index] ++ # split代替多次索引 ++ data = mfd_res.boxes.data.numpy() ++ xyxy, conf, cla = np.split(data, [4, 5], axis=-1) ++ ++ cla = cla.reshape(-1).astype(int).tolist() ++ conf = np.round(conf.reshape(-1).astype(float), 2).tolist() ++ ++ xyxy = xyxy.astype(np.int32) ++ xmin, ymin, xmax, ymax = xyxy[:, 0], xyxy[:, 1], xyxy[:, 2], xyxy[:, 3] ++ # area 直接矩阵运算 ++ areas = (xmax - xmin) * (ymax - ymin) ++ ++ num_boxes = len(conf) ++ ++ formula_list = [] ++ for i in range(num_boxes): ++ xmin_i, ymin_i, xmax_i, ymax_i = xyxy[i].tolist() ++ formula_list.append({ ++ "category_id": 13 + cla[i], ++ "poly": [xmin_i, ymin_i, xmax_i, ymin_i, ++ xmax_i, ymax_i, xmin_i, ymax_i], ++ "score": conf[i], ++ "latex": "", ++ }) ++ ++ # bbox_img 截取 ++ # bbox_img = pil_img[:, ymin_i:ymax_i, xmin_i:xmax_i] ++ bbox_img = pil_img.crop((xmin_i, ymin_i, xmax_i, ymax_i)) ++ curr_idx = len(mf_image_list) ++ image_info.append((areas[i], curr_idx, bbox_img)) ++ mf_image_list.append(bbox_img) ++ ++ images_formula_list.append(formula_list) ++ backfill_list += formula_list ++ ++ # Stable sort by area ++ image_info.sort(key=lambda x: x[0]) # sort by area ++ sorted_indices = [x[1] for x in image_info] ++ sorted_images = [x[2] for x in image_info] ++ ++ # Create mapping for results ++ index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)} ++ ++ # Create dataset with sorted images ++ dataset = MathDataset(sorted_images, transform=self.model.transform) ++ ++ # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂 ++ batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1 ++ ++ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) ++ ++ # Process batches and store results ++ mfr_res = [] ++ # for mf_img in dataloader: ++ ++ with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar: ++ for index, mf_img in enumerate(dataloader): ++ mf_img = mf_img.to(dtype=self.model.dtype) ++ mf_img = mf_img.to(self.device) ++ with torch.no_grad(): ++ output = self.model.generate({"image": mf_img}, batch_size=batch_size) ++ mfr_res.extend(output["fixed_str"]) ++ ++ # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size ++ current_batch_size = min(batch_size, len(sorted_images) - index * batch_size) ++ pbar.update(current_batch_size) ++ ++ # Restore original order ++ unsorted_results = [""] * len(mfr_res) ++ for new_idx, latex in enumerate(mfr_res): ++ original_idx = index_mapping[new_idx] ++ unsorted_results[original_idx] = latex ++ ++ # Fill results back ++ for res, latex in zip(backfill_list, unsorted_results): ++ res["latex"] = latex ++ ++ return images_formula_list +diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +index 98d1deee..3866a257 100644 +--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py ++++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +@@ -5,7 +5,9 @@ import cv2 + import albumentations as alb + from albumentations.pytorch import ToTensorV2 + from torchvision.transforms.functional import resize +- ++import torch ++import torch_npu ++import torch.nn.functional as F + + # TODO: dereference cv2 if possible + class UnimerSwinImageProcessor(BaseImageProcessor): +@@ -25,10 +27,53 @@ class UnimerSwinImageProcessor(BaseImageProcessor): + ] + ) + +- def __call__(self, item): ++ self.NORMALIZE_DIVISOR = torch.tensor(255.0, dtype=torch.float16, device="npu") ++ self.weights = torch.tensor([[[0.2989]], [[0.5870]], [[0.1140]]], dtype=torch.float16, device="npu") ++ self.mean = torch.tensor(0.7931, dtype=torch.float16, device="npu") ++ self.std = torch.tensor(0.1738, dtype=torch.float16, device="npu") ++ ++ self._mul_buf = torch.empty((3, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [3,H,W] ++ self._gray_buf = torch.empty((1, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [1,H,W] ++ ++ ++ def ___call__(self, item): + image = self.prepare_input(item) + return self.transform(image=image)['image'][:1] + ++ def pil_to_npu(self, pil_img, device="npu"): ++ img = torch.from_numpy(np.asarray(pil_img, dtype=np.float16)) ++ img = img.to(device).permute(2, 0, 1) / self.NORMALIZE_DIVISOR ++ return img ++ ++ def __call__(self, item): ++ ++ img = self.crop_margin(item) ++ img = self.pil_to_npu(img) ++ ++ _, h, w = img.shape ++ target_h, target_w = self.input_size ++ scale = min(target_h / h, target_w / w) ++ new_h, new_w = int(h*scale), int(w*scale) ++ ++ img = img.view(1, *img.shape) # [1,C,H,W] ++ img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False) ++ img = img.view(*img.shape[1:]) ++ ++ dw, dh = target_w - new_w, target_h - new_h ++ dw /= 2 ++ dh /= 2 ++ left, right = int(dw), int(dw + 0.5) ++ top, bottom = int(dh), int(dh + 0.5) ++ img = F.pad(img, (left, right, top, bottom), value=0.0) ++ ++ # RGB -> Gray ++ gray_tensor = (img * self.weights).sum(dim=0, keepdim=True) # [1, H, W] ++ ++ # Normalize ++ gray_tensor.sub_(self.mean).div_(self.std) ++ return gray_tensor ++ ++ + @staticmethod + def crop_margin(img: Image.Image) -> Image.Image: + data = np.array(img.convert("L")) +@@ -44,6 +89,32 @@ class UnimerSwinImageProcessor(BaseImageProcessor): + a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box + return img.crop((a, b, w + a, h + b)) + ++ def crop_margin_tensor(self, img): ++ """ ++ img: [C,H,W] tensor, uint8 或 float ++ """ ++ ++ gray = (img * self.weights).sum(dim=0) ++ ++ gray = gray.to(torch.uint8) ++ max_val = gray.max() ++ min_val = gray.min() ++ ++ if max_val == min_val: ++ return img ++ ++ norm_gray = (gray - min_val) / (max_val - min_val) ++ ++ mask = (norm_gray < self.threshold) ++ ++ coords = mask.nonzero(as_tuple=False) ++ if coords.shape[0] == 0: ++ return img ++ ymin, xmin = coords.min(0)[0] ++ ymax, xmax = coords.max(0)[0] ++ ++ return img[:, ymin:ymax+1, xmin:xmax+1] ++ + @staticmethod + def crop_margin_numpy(img: np.ndarray) -> np.ndarray: + """Crop margins of image using NumPy operations""" +diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +index 1b808e8b..0fe54751 100644 +--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py ++++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +@@ -465,11 +465,15 @@ class UnimerSwinSelfAttention(nn.Module): + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape +- mixed_query_layer = self.query(hidden_states) + +- key_layer = self.transpose_for_scores(self.key(hidden_states)) +- value_layer = self.transpose_for_scores(self.value(hidden_states)) +- query_layer = self.transpose_for_scores(mixed_query_layer) ++ # """融合qk为大矩阵,由于加入相对位置编码,PFA接口用不了,暂时只修改矩阵乘法""" ++ batch_size, dim, num_channels = hidden_states.shape ++ qkv = self.qkv(hidden_states) ++ q, k, v = qkv.chunk(3, dim=-1) ++ ++ query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) ++ key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) ++ value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) +diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py +index 3de483ac..23813db9 100755 +--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py ++++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py +@@ -117,6 +117,10 @@ class TextDetector(BaseOCRV20): + self.net.eval() + self.net.to(self.device) + ++ ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) ++ + def _batch_process_same_size(self, img_list): + """ + 对相同尺寸的图像进行批处理 +@@ -162,12 +166,12 @@ class TextDetector(BaseOCRV20): + return batch_results, time.time() - starttime + + # 批处理推理 +- with torch.no_grad(): +- inp = torch.from_numpy(batch_tensor) +- inp = inp.to(self.device) +- outputs = self.net(inp) +- +- # 处理输出 ++ with self._dev_lock: ++ with torch.no_grad(): ++ inp = torch.from_numpy(batch_tensor) ++ inp = inp.to(self.device) ++ outputs = self.net(inp) ++ # 处理输出 + preds = {} + if self.det_algorithm == "EAST": + preds['f_geo'] = outputs['f_geo'].cpu().numpy() +@@ -304,10 +308,11 @@ class TextDetector(BaseOCRV20): + img = img.copy() + starttime = time.time() + +- with torch.no_grad(): +- inp = torch.from_numpy(img) +- inp = inp.to(self.device) +- outputs = self.net(inp) ++ with self._dev_lock: ++ with torch.no_grad(): ++ inp = torch.from_numpy(img) ++ inp = inp.to(self.device) ++ outputs = self.net(inp) + + preds = {} + if self.det_algorithm == "EAST": +diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +index c06ca5fe..d865b201 100755 +--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py ++++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +@@ -94,6 +94,9 @@ class TextRecognizer(BaseOCRV20): + self.net.eval() + self.net.to(self.device) + ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) ++ + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR': +@@ -301,74 +304,78 @@ class TextRecognizer(BaseOCRV20): + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + elapse = 0 +- # for beg_img_no in range(0, img_num, batch_num): +- with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: +- index = 0 +- for beg_img_no in range(0, img_num, batch_num): +- end_img_no = min(img_num, beg_img_no + batch_num) +- norm_img_batch = [] +- max_wh_ratio = 0 +- for ino in range(beg_img_no, end_img_no): +- # h, w = img_list[ino].shape[0:2] +- h, w = img_list[indices[ino]].shape[0:2] +- wh_ratio = w * 1.0 / h +- max_wh_ratio = max(max_wh_ratio, wh_ratio) +- for ino in range(beg_img_no, end_img_no): +- if self.rec_algorithm == "SAR": +- norm_img, _, _, valid_ratio = self.resize_norm_img_sar( +- img_list[indices[ino]], self.rec_image_shape) +- norm_img = norm_img[np.newaxis, :] +- valid_ratio = np.expand_dims(valid_ratio, axis=0) +- valid_ratios = [] +- valid_ratios.append(valid_ratio) +- norm_img_batch.append(norm_img) +- +- elif self.rec_algorithm == "SVTR": +- norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], +- self.rec_image_shape) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- elif self.rec_algorithm == "SRN": +- norm_img = self.process_image_srn(img_list[indices[ino]], +- self.rec_image_shape, 8, +- self.max_text_length) +- encoder_word_pos_list = [] +- gsrm_word_pos_list = [] +- gsrm_slf_attn_bias1_list = [] +- gsrm_slf_attn_bias2_list = [] +- encoder_word_pos_list.append(norm_img[1]) +- gsrm_word_pos_list.append(norm_img[2]) +- gsrm_slf_attn_bias1_list.append(norm_img[3]) +- gsrm_slf_attn_bias2_list.append(norm_img[4]) +- norm_img_batch.append(norm_img[0]) +- elif self.rec_algorithm == "CAN": +- norm_img = self.norm_img_can(img_list[indices[ino]], +- max_wh_ratio) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- norm_image_mask = np.ones(norm_img.shape, dtype='float32') +- word_label = np.ones([1, 36], dtype='int64') +- norm_img_mask_batch = [] +- word_label_list = [] +- norm_img_mask_batch.append(norm_image_mask) +- word_label_list.append(word_label) +- else: +- norm_img = self.resize_norm_img(img_list[indices[ino]], +- max_wh_ratio) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- norm_img_batch = np.concatenate(norm_img_batch) +- norm_img_batch = norm_img_batch.copy() +- +- if self.rec_algorithm == "SRN": +- starttime = time.time() +- encoder_word_pos_list = np.concatenate(encoder_word_pos_list) +- gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) +- gsrm_slf_attn_bias1_list = np.concatenate( +- gsrm_slf_attn_bias1_list) +- gsrm_slf_attn_bias2_list = np.concatenate( +- gsrm_slf_attn_bias2_list) + ++ # for beg_img_no in range(0, img_num, batch_num): ++ from concurrent.futures import ThreadPoolExecutor, as_completed ++ def _rec_batch_worker(beg_img_no: int, end_img_no: int): ++ ++ ++ max_wh_ratio = 0.0 ++ norm_img_batch = [] ++ for ino in range(beg_img_no, end_img_no): ++ # h, w = img_list[ino].shape[0:2] ++ h, w = img_list[indices[ino]].shape[0:2] ++ wh_ratio = w * 1.0 / h ++ max_wh_ratio = max(max_wh_ratio, wh_ratio) ++ for ino in range(beg_img_no, end_img_no): ++ if self.rec_algorithm == "SAR": ++ norm_img, _, _, valid_ratio = self.resize_norm_img_sar( ++ img_list[indices[ino]], self.rec_image_shape) ++ norm_img = norm_img[np.newaxis, :] ++ valid_ratio = np.expand_dims(valid_ratio, axis=0) ++ valid_ratios = [] ++ valid_ratios.append(valid_ratio) ++ norm_img_batch.append(norm_img) ++ ++ elif self.rec_algorithm == "SVTR": ++ norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], ++ self.rec_image_shape) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ elif self.rec_algorithm == "SRN": ++ norm_img = self.process_image_srn(img_list[indices[ino]], ++ self.rec_image_shape, 8, ++ self.max_text_length) ++ encoder_word_pos_list = [] ++ gsrm_word_pos_list = [] ++ gsrm_slf_attn_bias1_list = [] ++ gsrm_slf_attn_bias2_list = [] ++ encoder_word_pos_list.append(norm_img[1]) ++ gsrm_word_pos_list.append(norm_img[2]) ++ gsrm_slf_attn_bias1_list.append(norm_img[3]) ++ gsrm_slf_attn_bias2_list.append(norm_img[4]) ++ norm_img_batch.append(norm_img[0]) ++ elif self.rec_algorithm == "CAN": ++ norm_img = self.norm_img_can(img_list[indices[ino]], ++ max_wh_ratio) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ norm_image_mask = np.ones(norm_img.shape, dtype='float32') ++ word_label = np.ones([1, 36], dtype='int64') ++ norm_img_mask_batch = [] ++ word_label_list = [] ++ norm_img_mask_batch.append(norm_image_mask) ++ word_label_list.append(word_label) ++ else: ++ norm_img = self.resize_norm_img(img_list[indices[ino]], ++ max_wh_ratio) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ norm_img_batch = np.concatenate(norm_img_batch) ++ norm_img_batch = norm_img_batch.copy() ++ ++ starttime = time.time() ++ ++ if self.rec_algorithm == "SRN": ++ starttime = time.time() ++ encoder_word_pos_list = np.concatenate(encoder_word_pos_list) ++ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) ++ gsrm_slf_attn_bias1_list = np.concatenate( ++ gsrm_slf_attn_bias1_list) ++ gsrm_slf_attn_bias2_list = np.concatenate( ++ gsrm_slf_attn_bias2_list) ++ ++ with self._dev_lock: + with torch.no_grad(): + inp = torch.from_numpy(norm_img_batch) + encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list) +@@ -384,58 +391,67 @@ class TextRecognizer(BaseOCRV20): + + backbone_out = self.net.backbone(inp) # backbone_feat + prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp]) +- # preds = {"predict": prob_out[2]} +- preds = {"predict": prob_out["predict"]} +- +- elif self.rec_algorithm == "SAR": +- starttime = time.time() +- # valid_ratios = np.concatenate(valid_ratios) +- # inputs = [ +- # norm_img_batch, +- # valid_ratios, +- # ] +- ++ # preds = {"predict": prob_out[2]} ++ preds = {"predict": prob_out["predict"]} ++ ++ elif self.rec_algorithm == "SAR": ++ starttime = time.time() ++ # valid_ratios = np.concatenate(valid_ratios) ++ # inputs = [ ++ # norm_img_batch, ++ # valid_ratios, ++ # ] ++ ++ with self._dev_lock: + with torch.no_grad(): + inp = torch.from_numpy(norm_img_batch) + inp = inp.to(self.device) + preds = self.net(inp) + +- elif self.rec_algorithm == "CAN": +- starttime = time.time() +- norm_img_mask_batch = np.concatenate(norm_img_mask_batch) +- word_label_list = np.concatenate(word_label_list) +- inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] ++ elif self.rec_algorithm == "CAN": ++ starttime = time.time() ++ norm_img_mask_batch = np.concatenate(norm_img_mask_batch) ++ word_label_list = np.concatenate(word_label_list) ++ inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] + +- inp = [torch.from_numpy(e_i) for e_i in inputs] +- inp = [e_i.to(self.device) for e_i in inp] ++ inp = [torch.from_numpy(e_i) for e_i in inputs] ++ inp = [e_i.to(self.device) for e_i in inp] ++ with self._dev_lock: + with torch.no_grad(): + outputs = self.net(inp) + outputs = [v.cpu().numpy() for k, v in enumerate(outputs)] + +- preds = outputs +- +- else: +- starttime = time.time() ++ preds = outputs + ++ else: ++ with self._dev_lock: + with torch.no_grad(): +- inp = torch.from_numpy(norm_img_batch) +- inp = inp.to(self.device) ++ inp = torch.from_numpy(norm_img_batch).to(self.device) + prob_out = self.net(inp) ++ preds = [v.cpu().numpy() for v in prob_out] if isinstance(prob_out, list) else prob_out.cpu().numpy() + +- if isinstance(prob_out, list): +- preds = [v.cpu().numpy() for v in prob_out] +- else: +- preds = prob_out.cpu().numpy() ++ rec_result = self.postprocess_op(preds) + +- rec_result = self.postprocess_op(preds) +- for rno in range(len(rec_result)): +- rec_res[indices[beg_img_no + rno]] = rec_result[rno] +- elapse += time.time() - starttime ++ for rno in range(len(rec_result)): ++ global_idx = indices[beg_img_no + rno] ++ rec_res[global_idx] = rec_result[rno] ++ ++ batch_elapse = time.time() - starttime ++ return len(rec_result), batch_elapse ++ ++ MAX_WORKERS = 4 ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex, \ ++ tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: ++ ++ futures = [] ++ for beg_img_no in range(0, img_num, batch_num): ++ end_img_no = min(img_num, beg_img_no + batch_num) ++ futures.append(ex.submit(_rec_batch_worker, beg_img_no, end_img_no)) + +- # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size +- current_batch_size = min(batch_num, img_num - index * batch_num) +- index += 1 +- pbar.update(current_batch_size) ++ for fut in as_completed(futures): ++ n_done, batch_elapse = fut.result() ++ elapse += batch_elapse ++ pbar.update(n_done) + + # Fix NaN values in recognition results + for i in range(len(rec_res)): +diff --git a/mineru/model/table/rapid_table.py b/mineru/model/table/rapid_table.py +index 174a8052..dd796bcc 100644 +--- a/mineru/model/table/rapid_table.py ++++ b/mineru/model/table/rapid_table.py +@@ -21,6 +21,8 @@ class RapidTableModel(object): + self.table_model = RapidTable(input_args) + self.ocr_engine = ocr_engine + ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) + + def predict(self, image): + bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) +@@ -30,44 +32,45 @@ class RapidTableModel(object): + img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0 + img_is_portrait = img_aspect_ratio > 1.2 + +- if img_is_portrait: ++ with self._dev_lock: ++ if img_is_portrait: + +- det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] +- # Check if table is rotated by analyzing text box aspect ratios +- is_rotated = False +- if det_res: +- vertical_count = 0 ++ det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] ++ # Check if table is rotated by analyzing text box aspect ratios ++ is_rotated = False ++ if det_res: ++ vertical_count = 0 + +- for box_ocr_res in det_res: +- p1, p2, p3, p4 = box_ocr_res ++ for box_ocr_res in det_res: ++ p1, p2, p3, p4 = box_ocr_res + +- # Calculate width and height +- width = p3[0] - p1[0] +- height = p3[1] - p1[1] ++ # Calculate width and height ++ width = p3[0] - p1[0] ++ height = p3[1] - p1[1] + +- aspect_ratio = width / height if height > 0 else 1.0 ++ aspect_ratio = width / height if height > 0 else 1.0 + +- # Count vertical vs horizontal text boxes +- if aspect_ratio < 0.8: # Taller than wide - vertical text +- vertical_count += 1 +- # elif aspect_ratio > 1.2: # Wider than tall - horizontal text +- # horizontal_count += 1 ++ # Count vertical vs horizontal text boxes ++ if aspect_ratio < 0.8: # Taller than wide - vertical text ++ vertical_count += 1 ++ # elif aspect_ratio > 1.2: # Wider than tall - horizontal text ++ # horizontal_count += 1 + +- # If we have more vertical text boxes than horizontal ones, +- # and vertical ones are significant, table might be rotated +- if vertical_count >= len(det_res) * 0.3: +- is_rotated = True ++ # If we have more vertical text boxes than horizontal ones, ++ # and vertical ones are significant, table might be rotated ++ if vertical_count >= len(det_res) * 0.3: ++ is_rotated = True + +- # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") ++ # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") + +- # Rotate image if necessary +- if is_rotated: +- # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") +- image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) +- bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) ++ # Rotate image if necessary ++ if is_rotated: ++ # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") ++ image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) ++ bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + +- # Continue with OCR on potentially rotated image +- ocr_result = self.ocr_engine.ocr(bgr_image)[0] ++ # Continue with OCR on potentially rotated image ++ ocr_result = self.ocr_engine.ocr(bgr_image)[0] + if ocr_result: + ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if + len(item) == 2 and isinstance(item[1], tuple)] +