From 01b28dad9144340baf848bc88e813033892ac303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Mon, 13 Oct 2025 11:09:38 +0800 Subject: [PATCH 01/20] fix whisper metric --- ACL_PyTorch/built-in/audio/whisper/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/audio/whisper/README.md b/ACL_PyTorch/built-in/audio/whisper/README.md index bba8f63259..024d54530d 100644 --- a/ACL_PyTorch/built-in/audio/whisper/README.md +++ b/ACL_PyTorch/built-in/audio/whisper/README.md @@ -101,4 +101,4 @@ warmup结束之后,开始推理librispeech_asr_dummy数据集,推理过程 | 模型 | 芯片 | 平均E2E时间 | WER | |---------|------------|----------|-------| - | whisper base | 800I A2 64G | 71.73ms | 8.21% | + | whisper base | 800I A2 64G | 65.82ms | 8.21% | -- Gitee From cee6c09a38d5423b9071fe028529855722bc0cfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Mon, 13 Oct 2025 11:29:52 +0800 Subject: [PATCH 02/20] 1 --- ACL_PyTorch/built-in/audio/whisper/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/audio/whisper/README.md b/ACL_PyTorch/built-in/audio/whisper/README.md index 024d54530d..fd6a094b9b 100644 --- a/ACL_PyTorch/built-in/audio/whisper/README.md +++ b/ACL_PyTorch/built-in/audio/whisper/README.md @@ -101,4 +101,4 @@ warmup结束之后,开始推理librispeech_asr_dummy数据集,推理过程 | 模型 | 芯片 | 平均E2E时间 | WER | |---------|------------|----------|-------| - | whisper base | 800I A2 64G | 65.82ms | 8.21% | + | whisper base | 800I A2 64G | 67.68ms | 8.21% | -- Gitee From fea13f9c0eac0ab31de5da499c8bad4f4e535d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 14 Oct 2025 09:34:19 +0800 Subject: [PATCH 03/20] fix patch bug --- ACL_PyTorch/built-in/ocr/MinerU/README.md | 10 +++++++--- ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch | 3 ++- ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch | 3 ++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md index cdf01eb0ea..d4306ded69 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/README.md +++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md @@ -154,7 +154,8 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local 使用`OmniDocBench`数据集配套评测代码测试精度。 1. 推理结果整理 -将解析结果文件夹中的markdown文件整理放置于同一目录,本例将所有markdown文件存放于OmniDocBench_dataset目录下的results_md文件夹 + + 将解析结果文件夹中的markdown文件整理放置于同一目录,本例将所有markdown文件存放于OmniDocBench_dataset目录下的results_md文件夹 ``` cp OmniDocBench_dataset/output/*/auto/*.md OmniDocBench_dataset/results_md/ ``` @@ -164,13 +165,15 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local ``` git clone https://github.com/opendatalab/OmniDocBench.git cd OmniDocBench + git reset --hard dc96d812d219960773399c02ae8f89e4706120d4 conda create -n omnidocbench python=3.10 conda activate omnidocbench pip install -r requirements.txt ``` 3. 测评配置修改 -修改`OmniDocBench`测评代码中的config文件,具体来说,我们使用端到端测评配置,修改configs/end2end.yaml文件中的ground_truth的data_path为下载的OmniDocBench.json路径,修改prediction的data_path中提供整理的推理结果的文件夹路径,如下: + + 修改`OmniDocBench`测评代码中的config文件,具体来说,我们使用端到端测评配置,修改configs/end2end.yaml文件中的ground_truth的data_path为下载的OmniDocBench.json路径,修改prediction的data_path中提供整理的推理结果的文件夹路径,如下: ``` # -----以下是需要修改的部分 ----- dataset: @@ -182,7 +185,8 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local ``` 4. 
精度测量结果 -配置好config文件后,只需要将config文件作为参数传入,运行以下代码即可进行评测: + + 配置好config文件后,只需要将config文件作为参数传入,运行以下代码即可进行评测: ``` python pdf_validation.py --config ./configs/end2end.yaml ``` diff --git a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch index 7cf22c0b32..9b526e909e 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch @@ -47,4 +47,5 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo-0.0.4_ - stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) - return torch.cat(anchor_points), torch.cat(stride_tensor) \ No newline at end of file + return torch.cat(anchor_points), torch.cat(stride_tensor) + \ No newline at end of file diff --git a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch index 4fab87d605..70feba449f 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch @@ -74,4 +74,5 @@ diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics_/ultralytics/ - stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) - return torch.cat(anchor_points), torch.cat(stride_tensor) \ No newline at end of file + return torch.cat(anchor_points), torch.cat(stride_tensor) + \ No newline at end of file -- Gitee From 9e6c195bb87f9b7da748d0617173aa3753495b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 14 Oct 2025 10:23:29 +0800 Subject: [PATCH 04/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch | 1 - ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch | 1 - 2 files changed, 2 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch index 9b526e909e..b5fd6669aa 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch @@ -48,4 +48,3 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo-0.0.4_ + # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) return torch.cat(anchor_points), torch.cat(stride_tensor) - \ No newline at end of file diff --git a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch index 70feba449f..5511fa6a9e 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch @@ -75,4 +75,3 @@ diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics_/ultralytics/ + # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) return torch.cat(anchor_points), torch.cat(stride_tensor) - \ No newline at end of file -- Gitee From 796693c60d4941f981ceb67402728c9ea7e8a611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 14 Oct 2025 14:46:33 +0800 Subject: [PATCH 05/20] 1 --- 
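Note: this commit only corrects the prediction path in the sample config from `result_md` to `results_md`, so that `prediction.data_path` matches the directory created by the `cp ... results_md/` step in the README above. A minimal pre-flight check of that directory before running the evaluation (illustrative only; the path is assumed from the README and is not part of the patch):

```python
# Hypothetical sanity check before running pdf_validation.py:
# confirm the evaluation config will actually find the exported markdown results.
from pathlib import Path

pred_dir = Path("OmniDocBench_dataset/results_md")  # path assumed from the README above
md_files = sorted(pred_dir.glob("*.md"))
assert md_files, f"no markdown files found under {pred_dir}; run the cp step first"
print(f"{len(md_files)} markdown files ready for evaluation")
```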
ACL_PyTorch/built-in/ocr/MinerU/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md index d4306ded69..af345076b8 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/README.md +++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md @@ -181,7 +181,7 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local ground_truth: data_path: ../OmniDocBench_dataset/OmniDocBench.json prediction: - data_path: ../OmniDocBench_dataset/result_md + data_path: ../OmniDocBench_dataset/results_md ``` 4. 精度测量结果 -- Gitee From adbd35983673cf7084a4c27b8921f545f28e2602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 14 Oct 2025 19:26:36 +0800 Subject: [PATCH 06/20] 1 --- ACL_PyTorch/README.md | 61 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/ACL_PyTorch/README.md b/ACL_PyTorch/README.md index 51cd1c0085..1651f4ee21 100755 --- a/ACL_PyTorch/README.md +++ b/ACL_PyTorch/README.md @@ -3,7 +3,7 @@

如何贡献

在开始贡献之前,请先阅读CONTRIBUTING。 谢谢!

-目前ACL_PyTorch仓库已有模型398个
+目前ACL_PyTorch仓库已有模型401个

注意:
在提交新模型时,请加上模型ID用于区分,为防止重复提交模型,请执行脚本get_modelID.py,该脚本会自动检索ACL_PyTorch仓库中所有与您提交模型相关的已有模型,请自行查看脚本给出的链接,如果均不同,则可以输入1或true用于获取模型ID。由于该脚本使用正则匹配,后续新模型刷新到主页需要添加README内容时,格式请参考其余模型,并且同步刷新上文模型数量。脚本执行方式如下:
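python3 get_modelID.py --model your_model_name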
@@ -4654,6 +4654,7 @@ python3 get_modelID.py --model your_model_name

 ROC_AUC
 mel_loss
 300I Pro
+800I A2
 100313
@@ -4824,7 +4825,7 @@ python3 get_modelID.py --model your_model_name

 多尺度
-100321
+100409
 whisper
@@ -4837,6 +4838,34 @@ python3 get_modelID.py --model your_model_name

 67.32(bs1)
 bs x 80 x 3000
+100410
+CosyVoice2
+代码仓提供
+0.75
+0.28
+多尺度
+100411
+whisperx
+librispeech dev clean
+0.050
+70(转录比)
+多尺度

Knowledge

@@ -5258,6 +5287,34 @@ python3 get_modelID.py --model your_model_name

 多尺度
+
+OCR
+
+ID
+Name
+Dataset
+精度
+300I Pro最优性能(对应bs)
+输入shape
+overall_EN
+overall_CH
+100408
+MinerU-ocr
+OmniDocBench
+0.1588
+0.2527
+多尺度
+

RL

-- Gitee From d0340be1a65cd893edb92d09e2b20420d975fbf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 14 Oct 2025 19:31:17 +0800 Subject: [PATCH 07/20] 1 --- ACL_PyTorch/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ACL_PyTorch/README.md b/ACL_PyTorch/README.md index 1651f4ee21..2f625f52fb 100755 --- a/ACL_PyTorch/README.md +++ b/ACL_PyTorch/README.md @@ -5295,6 +5295,7 @@ python3 get_modelID.py --model your_model_name

+
@@ -5310,6 +5311,7 @@ python3 get_modelID.py --model your_model_name

+ -- Gitee From 028a03a664c391d9f5c719875c971b12225f03b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Sat, 18 Oct 2025 18:58:03 +0800 Subject: [PATCH 08/20] add performance optimization adaptation --- .../built-in/ocr/MinerU/doclayout_yolo.patch | 135 ++- .../ocr/MinerU/mfr_encoder_mhsa.patch | 23 - ACL_PyTorch/built-in/ocr/MinerU/mineru.patch | 859 ++++++++++++++++++ .../built-in/ocr/MinerU/ultralytics.patch | 134 ++- 4 files changed, 1107 insertions(+), 44 deletions(-) delete mode 100644 ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch create mode 100644 ACL_PyTorch/built-in/ocr/MinerU/mineru.patch diff --git a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch index b5fd6669aa..291a2914ab 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch @@ -1,7 +1,123 @@ -diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo-0.0.4_fix/doclayout_yolo/engine/predictor.py +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/data/loaders.py doclayout_yolo/data/loaders.py +--- doclayout_yolo-0.0.4/doclayout_yolo/data/loaders.py 2025-02-11 15:49:31.000000000 +0800 ++++ doclayout_yolo/data/loaders.py 2025-10-19 01:27:41.984000000 +0800 +@@ -14,6 +14,7 @@ + import requests + import torch + from PIL import Image ++from torchvision.transforms import functional as TF + + from doclayout_yolo.data.utils import IMG_FORMATS, VID_FORMATS + from doclayout_yolo.utils import LOGGER, is_colab, is_kaggle, ops +@@ -411,7 +412,7 @@ + self.bs = len(self.im0) + + @staticmethod +- def _single_check(im): ++ def __single_check(im): ## origin _single_check + """Validate and format an image to numpy array.""" + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" + if isinstance(im, Image.Image): +@@ -419,6 +420,18 @@ + im = im.convert("RGB") + im = np.asarray(im)[:, :, ::-1] + im = np.ascontiguousarray(im) # contiguous ++ ++ return im ++ ++ @staticmethod ++ def _single_check(im): ++ """Validate and format an image to numpy array.""" ++ assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" ++ if isinstance(im, Image.Image): ++ if im.mode != "RGB": ++ im = im.convert("RGB") ++ im = np.asarray(im) ++ + return im + + def __len__(self): +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/model.py doclayout_yolo/engine/model.py +--- doclayout_yolo-0.0.4/doclayout_yolo/engine/model.py 2025-02-11 15:49:31.000000000 +0800 ++++ doclayout_yolo/engine/model.py 2025-10-19 01:27:41.988000000 +0800 +@@ -143,6 +143,8 @@ + else: + self._load(model, task=task) + ++ self.model.half() ++ + def __call__( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo/engine/predictor.py --- doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py 2025-02-11 15:49:31.000000000 +0800 -+++ doclayout_yolo-0.0.4_fix/doclayout_yolo/engine/predictor.py 2025-09-09 16:05:20.011737230 +0800 -@@ -152,7 +152,8 @@ ++++ doclayout_yolo/engine/predictor.py 2025-10-19 01:27:41.988000000 +0800 +@@ -47,6 +47,8 @@ + from doclayout_yolo.utils.files import increment_path + from doclayout_yolo.utils.torch_utils import select_device, smart_inference_mode + ++import torch.nn.functional as F ++ + STREAM_WARNING = """ + WARNING ⚠️ inference results will accumulate in RAM unless 
`stream=True` is passed, causing potential out-of-memory
+ errors for large sources or long-running streams and videos. See https://docs.doclayout_yolo.com/modes/predict/ for help.
+@@ -112,7 +114,7 @@
+ self._lock = threading.Lock() # for automatic thread-safe inference
+ callbacks.add_integration_callbacks(self)
+
+- def preprocess(self, im):
++ def _preprocess(self, im): ### origin preprocess
+ """
+ Prepares input image before inference.
+
+@@ -132,6 +134,46 @@
+ im /= 255 # 0 - 255 to 0.0 - 1.0
+ return im
+
++
++ def preprocess(self, images): ### adapt preprocess
++ """
++ Prepares input image before inference.
++
++ Args:
++ images (torch.Tensor | List[np.ndarray]): BCHW for tensor, [(HWC) x B] for list.
++ """
++ new_shape = (self.imgsz, self.imgsz) if isinstance(self.imgsz, int) else self.imgsz
++ tensors = []
++ for im in images:
++ im = torch.from_numpy(im).to(self.device).permute((2, 0, 1)) / 255.0
++
++ c, h, w = im.shape
++
++ r = min(new_shape[0] / h, new_shape[1] / w)
++
++ new_unpad = (int(round(w * r)), int(round(h * r)))
++
++ if (w, h) != new_unpad:
++ im = F.interpolate(im.unsqueeze(0), size=(new_unpad[1], new_unpad[0]),
++ mode="bilinear", align_corners=False).squeeze(0)
++
++ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
++ dw /= 2
++ dh /= 2
++ left, right = int(dw), int(dw + 0.5)
++ top, bottom = int(dh), int(dh + 0.5)
++ im = F.pad(im, (left, right, top, bottom), value=114/255.0)
++
++ _, H, W = im.shape
++ assert (H, W) == (new_shape[0], new_shape[1]), f"Expected image size does not match: padding image size:{(H, W)} != expected image size: {(new_shape[0], new_shape[1])}"
++
++ im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
++
++ tensors.append(im)
++
++ return torch.stack(tensors, dim=0)
++
++
+ def inference(self, im, *args, **kwargs):
+ """Runs inference on a given image using the specified model and arguments."""
+ visualize = (
+@@ -152,7 +194,8 @@
""" same_shapes = len({x.shape for x in im}) == 1 @@ -11,7 +127,7 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo return [letterbox(image=x) for x in im] def postprocess(self, preds, img, orig_imgs): -@@ -225,7 +226,8 @@ +@@ -225,7 +268,8 @@ # Warmup model if not self.done_warmup: @@ -21,10 +137,9 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo self.done_warmup = True self.seen, self.windows, self.batch = 0, [], None - -diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo-0.0.4_fix/doclayout_yolo/nn/modules/block.py +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo/nn/modules/block.py --- doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py 2025-02-11 15:49:31.000000000 +0800 -+++ doclayout_yolo-0.0.4_fix/doclayout_yolo/nn/modules/block.py 2025-09-09 16:05:20.019737230 +0800 ++++ doclayout_yolo/nn/modules/block.py 2025-10-19 01:27:41.996000000 +0800 @@ -230,7 +230,9 @@ def forward(self, x): """Forward pass through C2f layer.""" @@ -36,10 +151,9 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo return self.cv2(torch.cat(y, 1)) def forward_split(self, x): - -diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo-0.0.4_fix/doclayout_yolo/utils/tal.py +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo/utils/tal.py --- doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py 2025-02-11 15:49:31.000000000 +0800 -+++ doclayout_yolo-0.0.4_fix/doclayout_yolo/utils/tal.py 2025-09-09 16:05:20.023737230 +0800 ++++ doclayout_yolo/utils/tal.py 2025-10-19 01:27:42.000000000 +0800 @@ -328,7 +328,8 @@ sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) @@ -48,3 +162,4 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo-0.0.4_ + # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) return torch.cat(anchor_points), torch.cat(stride_tensor) + diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch b/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch deleted file mode 100644 index 1fe80a05cb..0000000000 --- a/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch +++ /dev/null @@ -1,23 +0,0 @@ ---- MinerU/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py 2025-09-02 17:58:15.032000000 +0800 -+++ copy_mfr.py 2025-09-10 13:58:36.616000000 +0800 -@@ -465,11 +465,15 @@ - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - batch_size, dim, num_channels = hidden_states.shape -- mixed_query_layer = self.query(hidden_states) - -- key_layer = self.transpose_for_scores(self.key(hidden_states)) -- value_layer = self.transpose_for_scores(self.value(hidden_states)) -- query_layer = self.transpose_for_scores(mixed_query_layer) -+ # """融合qk为大矩阵,由于加入相对位置编码,PFA接口用不了,暂时只修改矩阵乘法""" -+ batch_size, dim, num_channels = hidden_states.shape -+ qkv = self.qkv(hidden_states) -+ q, k, v = qkv.chunk(3, dim=-1) -+ -+ query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) -+ key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) -+ value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) - - # Take the dot product between "query" and "key" to 
get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch new file mode 100644 index 0000000000..909a60a4a7 --- /dev/null +++ b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch @@ -0,0 +1,859 @@ +diff --git a/mineru/backend/pipeline/batch_analyze.py b/mineru/backend/pipeline/batch_analyze.py +index c88a52a3..b0b79a80 +--- a/mineru/backend/pipeline/batch_analyze.py ++++ b/mineru/backend/pipeline/batch_analyze.py +@@ -3,6 +3,9 @@ from loguru import logger + from tqdm import tqdm + from collections import defaultdict + import numpy as np ++import time ++import torch ++import torch_npu + + from .model_init import AtomModelSingleton + from ...utils.config_reader import get_formula_enable, get_table_enable +@@ -95,6 +98,7 @@ class BatchAnalyze: + }) + + # OCR检测处理 ++ from concurrent.futures import ThreadPoolExecutor, as_completed + if self.enable_ocr_det_batch: + # 批处理模式 - 按语言和分辨率分组 + # 收集所有需要OCR检测的裁剪图像 +@@ -139,79 +143,73 @@ class BatchAnalyze: + ) + + # 按分辨率分组并同时完成padding ++ stride = 64 + resolution_groups = defaultdict(list) + for crop_info in lang_crop_list: + cropped_img = crop_info[0] + h, w = cropped_img.shape[:2] + # 使用更大的分组容差,减少分组数量 + # 将尺寸标准化到32的倍数 +- normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数 +- normalized_w = ((w + 32) // 32) * 32 ++ normalized_h = ((h + stride) // stride) * stride # 向上取整到32的倍数 ++ normalized_w = ((w + stride) // stride) * stride + group_key = (normalized_h, normalized_w) + resolution_groups[group_key].append(crop_info) + +- # 对每个分辨率组进行批处理 +- for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"): +- +- # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数) +- max_h = max(crop_info[0].shape[0] for crop_info in group_crops) +- max_w = max(crop_info[0].shape[1] for crop_info in group_crops) +- target_h = ((max_h + 32 - 1) // 32) * 32 +- target_w = ((max_w + 32 - 1) // 32) * 32 +- +- # 对所有图像进行padding到统一尺寸 +- batch_images = [] +- for crop_info in group_crops: +- img = crop_info[0] +- h, w = img.shape[:2] +- # 创建目标尺寸的白色背景 +- padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 +- # 将原图像粘贴到左上角 +- padded_img[:h, :w] = img +- batch_images.append(padded_img) +- +- # 批处理检测 +- det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) # 增加批处理大小 +- # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}") +- batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size) +- +- # 处理批处理结果 +- for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): +- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info +- +- if dt_boxes is not None and len(dt_boxes) > 0: +- # 直接应用原始OCR流程中的关键处理步骤 +- from mineru.utils.ocr_utils import ( +- merge_det_boxes, update_det_boxes, sorted_boxes +- ) + +- # 1. 排序检测框 +- if len(dt_boxes) > 0: +- dt_boxes_sorted = sorted_boxes(dt_boxes) +- else: +- dt_boxes_sorted = [] +- +- # 2. 合并相邻检测框 +- if dt_boxes_sorted: +- dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) +- else: +- dt_boxes_merged = [] +- +- # 3. 根据公式位置更新检测框(关键步骤!) 
+- if dt_boxes_merged and adjusted_mfdetrec_res: +- dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) +- else: +- dt_boxes_final = dt_boxes_merged +- +- # 构造OCR结果格式 +- ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] +- +- if ocr_res: +- ocr_result_list = get_ocr_result_list( +- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang +- ) +- +- ocr_res_list_dict['layout_res'].extend(ocr_result_list) ++ def _run_one_group_ocr(group_key, group_crops): ++ ++ max_h = max(ci[0].shape[0] for ci in group_crops) ++ max_w = max(ci[0].shape[1] for ci in group_crops) ++ target_h = ((max_h + stride - 1) // stride) * stride ++ target_w = ((max_w + stride - 1) // stride) * stride ++ ++ batch_images = [] ++ for ci in group_crops: ++ img = ci[0] ++ h, w = img.shape[:2] ++ padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 ++ padded_img[:h, :w] = img ++ batch_images.append(padded_img) ++ ++ det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) ++ ++ batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size) ++ ++ for i, (ci, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): ++ new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = ci ++ if dt_boxes is not None and len(dt_boxes) > 0: ++ from mineru.utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes ++ ++ if len(dt_boxes) > 0: ++ dt_boxes_sorted = sorted_boxes(dt_boxes) ++ else: ++ dt_boxes_sorted = [] ++ ++ if dt_boxes_sorted: ++ dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) ++ else: ++ dt_boxes_merged = [] ++ ++ if dt_boxes_merged and adjusted_mfdetrec_res: ++ dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) ++ else: ++ dt_boxes_final = dt_boxes_merged ++ ++ ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] ++ if ocr_res: ++ ocr_result_list = get_ocr_result_list( ++ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang ++ ) ++ ocr_res_list_dict['layout_res'].extend(ocr_result_list) ++ ++ MAX_WORKERS = 4 ++ start = time.time() ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: ++ futures = [ex.submit(_run_one_group_ocr, gk, gcs) for gk, gcs in resolution_groups.items()] ++ for f in as_completed(futures): ++ f.result() ++ end = time.time() ++ logger.info(f"ocr det run time : {end -start}") + else: + # 原始单张处理模式 + for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"): +@@ -247,7 +245,7 @@ class BatchAnalyze: + + # 表格识别 table recognition + if self.table_enable: +- for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"): ++ def _run_one_group_table(table_res_dict): + _lang = table_res_dict['lang'] + table_model = atom_model_manager.get_atom_model( + atom_model_name='table', +@@ -271,6 +269,16 @@ class BatchAnalyze: + 'table recognition processing fails, not get html return' + ) + ++ ++ MAX_WORKERS = 4 ++ start = time.time() ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: ++ futures = [ex.submit(_run_one_group_table, table_res_dict) for table_res_dict in table_res_list_all_page] ++ for f in as_completed(futures): ++ f.result() ++ end = time.time() ++ logger.info(f"table run time : {end - start}") ++ + # Create dictionaries to store items by language + need_ocr_lists_by_lang = {} # Dict of lists for each language + img_crop_lists_by_lang = {} # Dict of lists for each language +diff --git 
a/mineru/model/layout/doclayout_yolo.py b/mineru/model/layout/doclayout_yolo.py +index 5667a909..fc5056bb +--- a/mineru/model/layout/doclayout_yolo.py ++++ b/mineru/model/layout/doclayout_yolo.py +@@ -66,6 +66,7 @@ class DocLayoutYOLOModel: + conf=self.conf, + iou=self.iou, + verbose=False, ++ half=True + ) + for pred in predictions: + results.append(self._parse_prediction(pred)) +diff --git a/mineru/model/mfd/yolo_v8.py b/mineru/model/mfd/yolo_v8.py +index 33dac091..1fb4b50e +--- a/mineru/model/mfd/yolo_v8.py ++++ b/mineru/model/mfd/yolo_v8.py +@@ -31,7 +31,8 @@ class YOLOv8MFDModel: + conf=self.conf, + iou=self.iou, + verbose=False, +- device=self.device ++ device=self.device, ++ half=True + ) + return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu() +diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py +index ae3879da..23e56f2a +--- a/mineru/model/mfr/unimernet/Unimernet.py ++++ b/mineru/model/mfr/unimernet/Unimernet.py +@@ -1,7 +1,7 @@ + import torch + from torch.utils.data import DataLoader, Dataset + from tqdm import tqdm +- ++import numpy as np + + class MathDataset(Dataset): + def __init__(self, image_paths, transform=None): +@@ -61,7 +61,7 @@ class UnimernetModel(object): + res["latex"] = latex + return formula_list + +- def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: ++ def _batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: + images_formula_list = [] + mf_image_list = [] + backfill_list = [] +@@ -137,3 +137,94 @@ class UnimernetModel(object): + res["latex"] = latex + + return images_formula_list ++ ++ ++ def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: ++ ++ images_formula_list = [] ++ mf_image_list = [] ++ backfill_list = [] ++ image_info = [] # Store (area, original_index, image) tuples ++ ++ # Collect images with their original indices ++ for image_index in range(len(images_mfd_res)): ++ mfd_res = images_mfd_res[image_index] ++ pil_img = images[image_index] ++ # split代替多次索引 ++ data = mfd_res.boxes.data.numpy() ++ xyxy, conf, cla = np.split(data, [4, 5], axis=-1) ++ ++ cla = cla.reshape(-1).astype(int).tolist() ++ conf = np.round(conf.reshape(-1).astype(float), 2).tolist() ++ ++ xyxy = xyxy.astype(np.int32) ++ xmin, ymin, xmax, ymax = xyxy[:, 0], xyxy[:, 1], xyxy[:, 2], xyxy[:, 3] ++ # area 直接矩阵运算 ++ areas = (xmax - xmin) * (ymax - ymin) ++ ++ num_boxes = len(conf) ++ ++ formula_list = [] ++ for i in range(num_boxes): ++ xmin_i, ymin_i, xmax_i, ymax_i = xyxy[i].tolist() ++ formula_list.append({ ++ "category_id": 13 + cla[i], ++ "poly": [xmin_i, ymin_i, xmax_i, ymin_i, ++ xmax_i, ymax_i, xmin_i, ymax_i], ++ "score": conf[i], ++ "latex": "", ++ }) ++ ++ # bbox_img 截取 ++ # bbox_img = pil_img[:, ymin_i:ymax_i, xmin_i:xmax_i] ++ bbox_img = pil_img.crop((xmin_i, ymin_i, xmax_i, ymax_i)) ++ curr_idx = len(mf_image_list) ++ image_info.append((areas[i], curr_idx, bbox_img)) ++ mf_image_list.append(bbox_img) ++ ++ images_formula_list.append(formula_list) ++ backfill_list += formula_list ++ ++ # Stable sort by area ++ image_info.sort(key=lambda x: x[0]) # sort by area ++ sorted_indices = [x[1] for x in image_info] ++ sorted_images = [x[2] for x in image_info] ++ ++ # Create mapping for results ++ index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)} ++ ++ # Create dataset with sorted images ++ dataset = MathDataset(sorted_images, transform=self.model.transform) ++ ++ # 如果batch_size > 
len(sorted_images),则设置为不超过len(sorted_images)的2的幂 ++ batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1 ++ ++ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) ++ ++ # Process batches and store results ++ mfr_res = [] ++ # for mf_img in dataloader: ++ ++ with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar: ++ for index, mf_img in enumerate(dataloader): ++ mf_img = mf_img.to(dtype=self.model.dtype) ++ mf_img = mf_img.to(self.device) ++ with torch.no_grad(): ++ output = self.model.generate({"image": mf_img}, batch_size=batch_size) ++ mfr_res.extend(output["fixed_str"]) ++ ++ # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size ++ current_batch_size = min(batch_size, len(sorted_images) - index * batch_size) ++ pbar.update(current_batch_size) ++ ++ # Restore original order ++ unsorted_results = [""] * len(mfr_res) ++ for new_idx, latex in enumerate(mfr_res): ++ original_idx = index_mapping[new_idx] ++ unsorted_results[original_idx] = latex ++ ++ # Fill results back ++ for res, latex in zip(backfill_list, unsorted_results): ++ res["latex"] = latex ++ ++ return images_formula_list +diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +index 98d1deee..2c9d8328 +--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py ++++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +@@ -5,7 +5,9 @@ import cv2 + import albumentations as alb + from albumentations.pytorch import ToTensorV2 + from torchvision.transforms.functional import resize +- ++import torch ++import torch_npu ++import torch.nn.functional as F + + # TODO: dereference cv2 if possible + class UnimerSwinImageProcessor(BaseImageProcessor): +@@ -25,10 +27,53 @@ class UnimerSwinImageProcessor(BaseImageProcessor): + ] + ) + +- def __call__(self, item): ++ self.NORMALIZE_DIVISOR = torch.tensor(255.0, dtype=torch.float16, device="npu") ++ self.weights = torch.tensor([[[0.2989]], [[0.5870]], [[0.1140]]], dtype=torch.float16, device="npu") ++ self.mean = torch.tensor(0.7931, dtype=torch.float16, device="npu") ++ self.std = torch.tensor(0.1738, dtype=torch.float16, device="npu") ++ ++ self._mul_buf = torch.empty((3, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [3,H,W] ++ self._gray_buf = torch.empty((1, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [1,H,W] ++ ++ ++ def ___call__(self, item): + image = self.prepare_input(item) + return self.transform(image=image)['image'][:1] + ++ def pil_to_npu(self, pil_img, device="npu"): ++ img = torch.from_numpy(np.asarray(pil_img, dtype=np.float16)) ++ img = img.to(device).permute(2, 0, 1) / self.NORMALIZE_DIVISOR ++ return img ++ ++ def __call__(self, item): ++ ++ img = self.crop_margin(item) ++ img = self.pil_to_npu(img) ++ ++ _, h, w = img.shape ++ target_h, target_w = self.input_size ++ scale = min(target_h / h, target_w / w) ++ new_h, new_w = int(h*scale), int(w*scale) ++ ++ img = img.view(1, *img.shape) # [1,C,H,W] ++ img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False) ++ img = img.view(*img.shape[1:]) ++ ++ dw, dh = target_w - new_w, target_h - new_h ++ dw /= 2 ++ dh /= 2 ++ left, right = int(dw), int(dw + 0.5) ++ top, bottom = int(dh), int(dh + 0.5) ++ img = F.pad(img, (left, right, top, bottom), value=0.0) ++ ++ # RGB -> Gray ++ gray_tensor = (img * 
self.weights).sum(dim=0, keepdim=True) # [1, H, W] ++ ++ # Normalize ++ gray_tensor.sub_(self.mean).div_(self.std) ++ return gray_tensor ++ ++ + @staticmethod + def crop_margin(img: Image.Image) -> Image.Image: + data = np.array(img.convert("L")) +@@ -44,6 +89,34 @@ class UnimerSwinImageProcessor(BaseImageProcessor): + a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box + return img.crop((a, b, w + a, h + b)) + ++ def crop_margin_tensor(self, img): ++ """ ++ img: [C,H,W] tensor, uint8 或 float ++ """ ++ ++ gray = (img * self.weights).sum(dim=0) ++ ++ ++ ++ gray = gray.to(torch.uint8) ++ max_val = gray.max() ++ min_val = gray.min() ++ ++ if max_val == min_val: ++ return img ++ ++ norm_gray = (gray - min_val) / (max_val - min_val) ++ ++ mask = (norm_gray < self.threshold) ++ ++ coords = mask.nonzero(as_tuple=False) ++ if coords.shape[0] == 0: ++ return img ++ ymin, xmin = coords.min(0)[0] ++ ymax, xmax = coords.max(0)[0] ++ ++ return img[:, ymin:ymax+1, xmin:xmax+1] ++ + @staticmethod + def crop_margin_numpy(img: np.ndarray) -> np.ndarray: + """Crop margins of image using NumPy operations""" +diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +index 1b808e8b..0fe54751 +--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py ++++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +@@ -465,11 +465,15 @@ class UnimerSwinSelfAttention(nn.Module): + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape +- mixed_query_layer = self.query(hidden_states) + +- key_layer = self.transpose_for_scores(self.key(hidden_states)) +- value_layer = self.transpose_for_scores(self.value(hidden_states)) +- query_layer = self.transpose_for_scores(mixed_query_layer) ++ # """融合qk为大矩阵,由于加入相对位置编码,PFA接口用不了,暂时只修改矩阵乘法""" ++ batch_size, dim, num_channels = hidden_states.shape ++ qkv = self.qkv(hidden_states) ++ q, k, v = qkv.chunk(3, dim=-1) ++ ++ query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) ++ key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) ++ value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) +diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py +index 3de483ac..23813db9 100755 +--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py ++++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py +@@ -117,6 +117,10 @@ class TextDetector(BaseOCRV20): + self.net.eval() + self.net.to(self.device) + ++ ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) ++ + def _batch_process_same_size(self, img_list): + """ + 对相同尺寸的图像进行批处理 +@@ -162,12 +166,12 @@ class TextDetector(BaseOCRV20): + return batch_results, time.time() - starttime + + # 批处理推理 +- with torch.no_grad(): +- inp = torch.from_numpy(batch_tensor) +- inp = inp.to(self.device) +- outputs = self.net(inp) +- +- # 处理输出 ++ with self._dev_lock: ++ with torch.no_grad(): ++ inp = torch.from_numpy(batch_tensor) ++ inp = inp.to(self.device) ++ outputs = self.net(inp) ++ # 处理输出 + preds = {} + if self.det_algorithm == "EAST": + preds['f_geo'] = outputs['f_geo'].cpu().numpy() +@@ -304,10 +308,11 @@ class TextDetector(BaseOCRV20): + img = img.copy() + starttime = time.time() + +- with torch.no_grad(): +- inp = torch.from_numpy(img) +- inp = inp.to(self.device) +- outputs = self.net(inp) ++ with self._dev_lock: ++ with torch.no_grad(): ++ inp = torch.from_numpy(img) ++ inp = inp.to(self.device) ++ outputs = self.net(inp) + + preds = {} + if self.det_algorithm == "EAST": +diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +index c06ca5fe..d865b201 100755 +--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py ++++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +@@ -94,6 +94,9 @@ class TextRecognizer(BaseOCRV20): + self.net.eval() + self.net.to(self.device) + ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) ++ + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR': +@@ -301,74 +304,78 @@ class TextRecognizer(BaseOCRV20): + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + elapse = 0 +- # for beg_img_no in range(0, img_num, batch_num): +- with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: +- index = 0 +- for beg_img_no in range(0, img_num, batch_num): +- end_img_no = min(img_num, beg_img_no + batch_num) +- norm_img_batch = [] +- max_wh_ratio = 0 +- for ino in range(beg_img_no, end_img_no): +- # h, w = img_list[ino].shape[0:2] +- h, w = img_list[indices[ino]].shape[0:2] +- wh_ratio = w * 1.0 / h +- max_wh_ratio = max(max_wh_ratio, wh_ratio) +- for ino in range(beg_img_no, end_img_no): +- if self.rec_algorithm == "SAR": +- norm_img, _, _, valid_ratio = self.resize_norm_img_sar( +- img_list[indices[ino]], self.rec_image_shape) +- norm_img = norm_img[np.newaxis, :] +- valid_ratio = np.expand_dims(valid_ratio, axis=0) +- valid_ratios = [] +- valid_ratios.append(valid_ratio) +- norm_img_batch.append(norm_img) +- +- elif self.rec_algorithm == "SVTR": +- norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], +- self.rec_image_shape) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- elif self.rec_algorithm == "SRN": +- norm_img = self.process_image_srn(img_list[indices[ino]], +- self.rec_image_shape, 8, +- 
self.max_text_length) +- encoder_word_pos_list = [] +- gsrm_word_pos_list = [] +- gsrm_slf_attn_bias1_list = [] +- gsrm_slf_attn_bias2_list = [] +- encoder_word_pos_list.append(norm_img[1]) +- gsrm_word_pos_list.append(norm_img[2]) +- gsrm_slf_attn_bias1_list.append(norm_img[3]) +- gsrm_slf_attn_bias2_list.append(norm_img[4]) +- norm_img_batch.append(norm_img[0]) +- elif self.rec_algorithm == "CAN": +- norm_img = self.norm_img_can(img_list[indices[ino]], +- max_wh_ratio) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- norm_image_mask = np.ones(norm_img.shape, dtype='float32') +- word_label = np.ones([1, 36], dtype='int64') +- norm_img_mask_batch = [] +- word_label_list = [] +- norm_img_mask_batch.append(norm_image_mask) +- word_label_list.append(word_label) +- else: +- norm_img = self.resize_norm_img(img_list[indices[ino]], +- max_wh_ratio) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- norm_img_batch = np.concatenate(norm_img_batch) +- norm_img_batch = norm_img_batch.copy() +- +- if self.rec_algorithm == "SRN": +- starttime = time.time() +- encoder_word_pos_list = np.concatenate(encoder_word_pos_list) +- gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) +- gsrm_slf_attn_bias1_list = np.concatenate( +- gsrm_slf_attn_bias1_list) +- gsrm_slf_attn_bias2_list = np.concatenate( +- gsrm_slf_attn_bias2_list) + ++ # for beg_img_no in range(0, img_num, batch_num): ++ from concurrent.futures import ThreadPoolExecutor, as_completed ++ def _rec_batch_worker(beg_img_no: int, end_img_no: int): ++ ++ ++ max_wh_ratio = 0.0 ++ norm_img_batch = [] ++ for ino in range(beg_img_no, end_img_no): ++ # h, w = img_list[ino].shape[0:2] ++ h, w = img_list[indices[ino]].shape[0:2] ++ wh_ratio = w * 1.0 / h ++ max_wh_ratio = max(max_wh_ratio, wh_ratio) ++ for ino in range(beg_img_no, end_img_no): ++ if self.rec_algorithm == "SAR": ++ norm_img, _, _, valid_ratio = self.resize_norm_img_sar( ++ img_list[indices[ino]], self.rec_image_shape) ++ norm_img = norm_img[np.newaxis, :] ++ valid_ratio = np.expand_dims(valid_ratio, axis=0) ++ valid_ratios = [] ++ valid_ratios.append(valid_ratio) ++ norm_img_batch.append(norm_img) ++ ++ elif self.rec_algorithm == "SVTR": ++ norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], ++ self.rec_image_shape) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ elif self.rec_algorithm == "SRN": ++ norm_img = self.process_image_srn(img_list[indices[ino]], ++ self.rec_image_shape, 8, ++ self.max_text_length) ++ encoder_word_pos_list = [] ++ gsrm_word_pos_list = [] ++ gsrm_slf_attn_bias1_list = [] ++ gsrm_slf_attn_bias2_list = [] ++ encoder_word_pos_list.append(norm_img[1]) ++ gsrm_word_pos_list.append(norm_img[2]) ++ gsrm_slf_attn_bias1_list.append(norm_img[3]) ++ gsrm_slf_attn_bias2_list.append(norm_img[4]) ++ norm_img_batch.append(norm_img[0]) ++ elif self.rec_algorithm == "CAN": ++ norm_img = self.norm_img_can(img_list[indices[ino]], ++ max_wh_ratio) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ norm_image_mask = np.ones(norm_img.shape, dtype='float32') ++ word_label = np.ones([1, 36], dtype='int64') ++ norm_img_mask_batch = [] ++ word_label_list = [] ++ norm_img_mask_batch.append(norm_image_mask) ++ word_label_list.append(word_label) ++ else: ++ norm_img = self.resize_norm_img(img_list[indices[ino]], ++ max_wh_ratio) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ norm_img_batch = np.concatenate(norm_img_batch) ++ norm_img_batch = 
norm_img_batch.copy() ++ ++ starttime = time.time() ++ ++ if self.rec_algorithm == "SRN": ++ starttime = time.time() ++ encoder_word_pos_list = np.concatenate(encoder_word_pos_list) ++ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) ++ gsrm_slf_attn_bias1_list = np.concatenate( ++ gsrm_slf_attn_bias1_list) ++ gsrm_slf_attn_bias2_list = np.concatenate( ++ gsrm_slf_attn_bias2_list) ++ ++ with self._dev_lock: + with torch.no_grad(): + inp = torch.from_numpy(norm_img_batch) + encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list) +@@ -384,58 +391,67 @@ class TextRecognizer(BaseOCRV20): + + backbone_out = self.net.backbone(inp) # backbone_feat + prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp]) +- # preds = {"predict": prob_out[2]} +- preds = {"predict": prob_out["predict"]} +- +- elif self.rec_algorithm == "SAR": +- starttime = time.time() +- # valid_ratios = np.concatenate(valid_ratios) +- # inputs = [ +- # norm_img_batch, +- # valid_ratios, +- # ] +- ++ # preds = {"predict": prob_out[2]} ++ preds = {"predict": prob_out["predict"]} ++ ++ elif self.rec_algorithm == "SAR": ++ starttime = time.time() ++ # valid_ratios = np.concatenate(valid_ratios) ++ # inputs = [ ++ # norm_img_batch, ++ # valid_ratios, ++ # ] ++ ++ with self._dev_lock: + with torch.no_grad(): + inp = torch.from_numpy(norm_img_batch) + inp = inp.to(self.device) + preds = self.net(inp) + +- elif self.rec_algorithm == "CAN": +- starttime = time.time() +- norm_img_mask_batch = np.concatenate(norm_img_mask_batch) +- word_label_list = np.concatenate(word_label_list) +- inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] ++ elif self.rec_algorithm == "CAN": ++ starttime = time.time() ++ norm_img_mask_batch = np.concatenate(norm_img_mask_batch) ++ word_label_list = np.concatenate(word_label_list) ++ inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] + +- inp = [torch.from_numpy(e_i) for e_i in inputs] +- inp = [e_i.to(self.device) for e_i in inp] ++ inp = [torch.from_numpy(e_i) for e_i in inputs] ++ inp = [e_i.to(self.device) for e_i in inp] ++ with self._dev_lock: + with torch.no_grad(): + outputs = self.net(inp) + outputs = [v.cpu().numpy() for k, v in enumerate(outputs)] + +- preds = outputs +- +- else: +- starttime = time.time() ++ preds = outputs + ++ else: ++ with self._dev_lock: + with torch.no_grad(): +- inp = torch.from_numpy(norm_img_batch) +- inp = inp.to(self.device) ++ inp = torch.from_numpy(norm_img_batch).to(self.device) + prob_out = self.net(inp) ++ preds = [v.cpu().numpy() for v in prob_out] if isinstance(prob_out, list) else prob_out.cpu().numpy() + +- if isinstance(prob_out, list): +- preds = [v.cpu().numpy() for v in prob_out] +- else: +- preds = prob_out.cpu().numpy() ++ rec_result = self.postprocess_op(preds) + +- rec_result = self.postprocess_op(preds) +- for rno in range(len(rec_result)): +- rec_res[indices[beg_img_no + rno]] = rec_result[rno] +- elapse += time.time() - starttime ++ for rno in range(len(rec_result)): ++ global_idx = indices[beg_img_no + rno] ++ rec_res[global_idx] = rec_result[rno] ++ ++ batch_elapse = time.time() - starttime ++ return len(rec_result), batch_elapse ++ ++ MAX_WORKERS = 4 ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex, \ ++ tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: ++ ++ futures = [] ++ for beg_img_no in range(0, img_num, batch_num): ++ end_img_no = min(img_num, beg_img_no + batch_num) ++ 
futures.append(ex.submit(_rec_batch_worker, beg_img_no, end_img_no)) + +- # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size +- current_batch_size = min(batch_num, img_num - index * batch_num) +- index += 1 +- pbar.update(current_batch_size) ++ for fut in as_completed(futures): ++ n_done, batch_elapse = fut.result() ++ elapse += batch_elapse ++ pbar.update(n_done) + + # Fix NaN values in recognition results + for i in range(len(rec_res)): +diff --git a/mineru/model/table/rapid_table.py b/mineru/model/table/rapid_table.py +index 174a8052..dd796bcc +--- a/mineru/model/table/rapid_table.py ++++ b/mineru/model/table/rapid_table.py +@@ -21,6 +21,8 @@ class RapidTableModel(object): + self.table_model = RapidTable(input_args) + self.ocr_engine = ocr_engine + ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) + + def predict(self, image): + bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) +@@ -30,44 +32,45 @@ class RapidTableModel(object): + img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0 + img_is_portrait = img_aspect_ratio > 1.2 + +- if img_is_portrait: ++ with self._dev_lock: ++ if img_is_portrait: + +- det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] +- # Check if table is rotated by analyzing text box aspect ratios +- is_rotated = False +- if det_res: +- vertical_count = 0 ++ det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] ++ # Check if table is rotated by analyzing text box aspect ratios ++ is_rotated = False ++ if det_res: ++ vertical_count = 0 + +- for box_ocr_res in det_res: +- p1, p2, p3, p4 = box_ocr_res ++ for box_ocr_res in det_res: ++ p1, p2, p3, p4 = box_ocr_res + +- # Calculate width and height +- width = p3[0] - p1[0] +- height = p3[1] - p1[1] ++ # Calculate width and height ++ width = p3[0] - p1[0] ++ height = p3[1] - p1[1] + +- aspect_ratio = width / height if height > 0 else 1.0 ++ aspect_ratio = width / height if height > 0 else 1.0 + +- # Count vertical vs horizontal text boxes +- if aspect_ratio < 0.8: # Taller than wide - vertical text +- vertical_count += 1 +- # elif aspect_ratio > 1.2: # Wider than tall - horizontal text +- # horizontal_count += 1 ++ # Count vertical vs horizontal text boxes ++ if aspect_ratio < 0.8: # Taller than wide - vertical text ++ vertical_count += 1 ++ # elif aspect_ratio > 1.2: # Wider than tall - horizontal text ++ # horizontal_count += 1 + +- # If we have more vertical text boxes than horizontal ones, +- # and vertical ones are significant, table might be rotated +- if vertical_count >= len(det_res) * 0.3: +- is_rotated = True ++ # If we have more vertical text boxes than horizontal ones, ++ # and vertical ones are significant, table might be rotated ++ if vertical_count >= len(det_res) * 0.3: ++ is_rotated = True + +- # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") ++ # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") + +- # Rotate image if necessary +- if is_rotated: +- # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") +- image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) +- bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) ++ # Rotate image if necessary ++ if is_rotated: ++ # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") ++ image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) ++ bgr_image = cv2.cvtColor(image, 
cv2.COLOR_RGB2BGR) + +- # Continue with OCR on potentially rotated image +- ocr_result = self.ocr_engine.ocr(bgr_image)[0] ++ # Continue with OCR on potentially rotated image ++ ocr_result = self.ocr_engine.ocr(bgr_image)[0] + if ocr_result: + ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if + len(item) == 2 and isinstance(item[1], tuple)] + diff --git a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch index 5511fa6a9e..2647c7695d 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch @@ -1,7 +1,120 @@ -diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultralytics/engine/predictor.py +diff -ruN ultralytics-8.3.193/ultralytics/data/loaders.py ultralytics/data/loaders.py +--- ultralytics-8.3.193/ultralytics/data/loaders.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics/data/loaders.py 2025-10-19 01:27:48.412000000 +0800 +@@ -534,7 +534,7 @@ + self.bs = len(self.im0) + + @staticmethod +- def _single_check(im: Image.Image | np.ndarray, flag: str = "RGB") -> np.ndarray: ++ def __single_check(im: Image.Image | np.ndarray, flag: str = "RGB") -> np.ndarray: + """Validate and format an image to numpy array, ensuring RGB order and contiguous memory.""" + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" + if isinstance(im, Image.Image): +@@ -546,6 +546,19 @@ + im = im[..., None] + return im + ++ @staticmethod ++ def _single_check(im: Image.Image | np.ndarray, flag: str = "RGB") -> np.ndarray: ++ """Validate and format an image to numpy array, ensuring RGB order and contiguous memory.""" ++ assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" ++ if isinstance(im, Image.Image): ++ if im.mode != "RGB": ++ im = im.convert("RGB") ++ im = np.asarray(im) ++ elif im.ndim == 2: # grayscale in numpy form ++ im = im[..., None] ++ return im ++ ++ + def __len__(self) -> int: + """Return the length of the 'im0' attribute, representing the number of loaded images.""" + return len(self.im0) +diff -ruN ultralytics-8.3.193/ultralytics/engine/model.py ultralytics/engine/model.py +--- ultralytics-8.3.193/ultralytics/engine/model.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics/engine/model.py 2025-10-19 01:27:48.412000000 +0800 +@@ -152,6 +152,8 @@ + else: + self._load(model, task=task) + ++ self.model.half() ++ + # Delete super().training for accessing self.model.training + del self.training + +diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics/engine/predictor.py --- ultralytics-8.3.193/ultralytics/engine/predictor.py 2025-09-04 19:51:11.000000000 +0800 -+++ ultralytics_/ultralytics/engine/predictor.py 2025-09-09 14:56:14.535737230 +0800 -@@ -196,9 +196,10 @@ ++++ ultralytics/engine/predictor.py 2025-10-19 01:27:48.412000000 +0800 +@@ -43,6 +43,7 @@ + import cv2 + import numpy as np + import torch ++import torch.nn.functional as F + + from ultralytics.cfg import get_cfg, get_save_dir + from ultralytics.data import load_inference_source +@@ -149,7 +150,7 @@ + self._lock = threading.Lock() # for automatic thread-safe inference + callbacks.add_integration_callbacks(self) + +- def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor: ++ def _preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor: + """ + Prepare input image before inference. 
+ +@@ -174,6 +175,51 @@ + im /= 255 # 0 - 255 to 0.0 - 1.0 + return im + ++ def preprocess(self, images: torch.Tensor | list[np.ndarray]) -> torch.Tensor: ++ """ ++ Prepare input image before inference. ++ ++ Args: ++ images (torch.Tensor | List[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list. ++ ++ Returns: ++ (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W). ++ """ ++ ++ new_shape = (new_shape, new_shape) if isinstance(self.imgsz, int) else self.imgsz ++ tensors = [] ++ for im in images: ++ im = torch.from_numpy(im).to(self.device).permute((2, 0, 1)) / 255.0 ++ ++ c, h, w = im.shape ++ ++ r = min(new_shape[0] / h, new_shape[1] / w) ++ ++ new_unpad = (int(round(w * r)), int(round(h * r))) ++ ++ if (w, h) != new_unpad: ++ im = F.interpolate(im.unsqueeze(0), size=(new_unpad[1], new_unpad[0]), ++ mode="bilinear", align_corners=False).squeeze(0) ++ ++ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] ++ dw /= 2 ++ dh /= 2 ++ left, right = int(dw), int(dw + 0.5) ++ top, bottom = int(dh), int(dh + 0.5) ++ im = F.pad(im, (left, right, top, bottom), value=114/255.0) ++ ++ _, H, W = im.shape ++ assert (H, W) == (new_shape[0], new_shape[1]), f"Expected image size do not match: padding image size:{(H, W)} != expected image size: {(new_shape[0], new_shape[1])}" ++ ++ im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 ++ ++ tensors.append(im) ++ ++ return torch.stack(tensors, dim=0) ++ ++ ++ ++ + def inference(self, im: torch.Tensor, *args, **kwargs): + """Run inference on a given image using the specified model and arguments.""" + visualize = ( +@@ -196,9 +242,10 @@ same_shapes = len({x.shape for x in im}) == 1 letterbox = LetterBox( self.imgsz, @@ -15,7 +128,7 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultra stride=self.model.stride, ) return [letterbox(image=x) for x in im] -@@ -311,8 +312,11 @@ +@@ -311,8 +358,11 @@ # Warmup model if not self.done_warmup: @@ -28,7 +141,7 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultra ) self.done_warmup = True -@@ -400,7 +404,8 @@ +@@ -400,7 +450,8 @@ dnn=self.args.dnn, data=self.args.data, fp16=self.args.half, @@ -37,10 +150,9 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultra + fuse=False, verbose=verbose, ) - -diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics_/ultralytics/nn/modules/block.py +diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics/nn/modules/block.py --- ultralytics-8.3.193/ultralytics/nn/modules/block.py 2025-09-04 19:51:11.000000000 +0800 -+++ ultralytics_/ultralytics/nn/modules/block.py 2025-09-09 14:56:14.543737230 +0800 ++++ ultralytics/nn/modules/block.py 2025-10-19 01:27:48.424000000 +0800 @@ -237,7 +237,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply sequential pooling operations to input and return concatenated feature maps.""" @@ -63,10 +175,9 @@ diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics_/ultra return self.cv2(torch.cat(y, 1)) def forward_split(self, x: torch.Tensor) -> torch.Tensor: - -diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics_/ultralytics/utils/tal.py +diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics/utils/tal.py --- ultralytics-8.3.193/ultralytics/utils/tal.py 2025-09-04 19:51:11.000000000 +0800 -+++ ultralytics_/ultralytics/utils/tal.py 2025-09-09 14:56:14.551737230 +0800 ++++ ultralytics/utils/tal.py 2025-10-19 
01:27:48.428000000 +0800 @@ -375,7 +375,8 @@ sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) @@ -75,3 +186,4 @@ diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics_/ultralytics/ + # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) return torch.cat(anchor_points), torch.cat(stride_tensor) + -- Gitee From d74785d4a365b082b6fbdc0ee5fcdd5649168f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 21 Oct 2025 21:54:32 +0800 Subject: [PATCH 09/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/mineru.patch | 153 ++++++++++++------- 1 file changed, 97 insertions(+), 56 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch index 909a60a4a7..0e57a9c2c0 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch @@ -1,5 +1,28 @@ +diff --git a/demo/demo.py b/demo/demo.py +index 36433c45..6f28620f 100644 +--- a/demo/demo.py ++++ b/demo/demo.py +@@ -86,7 +86,7 @@ def do_parse( + image_dir = str(os.path.basename(local_image_dir)) + content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) + md_writer.write_string( +- f"{pdf_file_name}_content_list.json", ++ f"{pdf_file_name}_content.json", + json.dumps(content_list, ensure_ascii=False, indent=4), + ) + +@@ -142,7 +142,8 @@ def do_parse( + image_dir = str(os.path.basename(local_image_dir)) + content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) + md_writer.write_string( +- f"{pdf_file_name}_content_list.json", ++ # f"{pdf_file_name}_content_list.json", ++ f"{pdf_file_name}_content.json", ## 文件名太长了,linux文件系统ext4超过255字节无法保存 + json.dumps(content_list, ensure_ascii=False, indent=4), + ) + diff --git a/mineru/backend/pipeline/batch_analyze.py b/mineru/backend/pipeline/batch_analyze.py -index c88a52a3..b0b79a80 +index c88a52a3..b0b79a80 100644 --- a/mineru/backend/pipeline/batch_analyze.py +++ b/mineru/backend/pipeline/batch_analyze.py @@ -3,6 +3,9 @@ from loguru import logger @@ -9,12 +32,12 @@ index c88a52a3..b0b79a80 +import time +import torch +import torch_npu - + from .model_init import AtomModelSingleton from ...utils.config_reader import get_formula_enable, get_table_enable @@ -95,6 +98,7 @@ class BatchAnalyze: }) - + # OCR检测处理 + from concurrent.futures import ThreadPoolExecutor, as_completed if self.enable_ocr_det_batch: @@ -22,7 +45,7 @@ index c88a52a3..b0b79a80 # 收集所有需要OCR检测的裁剪图像 @@ -139,79 +143,73 @@ class BatchAnalyze: ) - + # 按分辨率分组并同时完成padding + stride = 64 resolution_groups = defaultdict(list) @@ -37,7 +60,7 @@ index c88a52a3..b0b79a80 + normalized_w = ((w + stride) // stride) * stride group_key = (normalized_h, normalized_w) resolution_groups[group_key].append(crop_info) - + - # 对每个分辨率组进行批处理 - for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"): - @@ -72,7 +95,7 @@ index c88a52a3..b0b79a80 - from mineru.utils.ocr_utils import ( - merge_det_boxes, update_det_boxes, sorted_boxes - ) - + - # 1. 
排序检测框 - if len(dt_boxes) > 0: - dt_boxes_sorted = sorted_boxes(dt_boxes) @@ -151,14 +174,14 @@ index c88a52a3..b0b79a80 + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: + futures = [ex.submit(_run_one_group_ocr, gk, gcs) for gk, gcs in resolution_groups.items()] + for f in as_completed(futures): -+ f.result() ++ f.result() + end = time.time() + logger.info(f"ocr det run time : {end -start}") else: # 原始单张处理模式 for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"): @@ -247,7 +245,7 @@ class BatchAnalyze: - + # 表格识别 table recognition if self.table_enable: - for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"): @@ -169,22 +192,41 @@ index c88a52a3..b0b79a80 @@ -271,6 +269,16 @@ class BatchAnalyze: 'table recognition processing fails, not get html return' ) - + + + MAX_WORKERS = 4 + start = time.time() + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: + futures = [ex.submit(_run_one_group_table, table_res_dict) for table_res_dict in table_res_list_all_page] + for f in as_completed(futures): -+ f.result() ++ f.result() + end = time.time() + logger.info(f"table run time : {end - start}") + # Create dictionaries to store items by language need_ocr_lists_by_lang = {} # Dict of lists for each language img_crop_lists_by_lang = {} # Dict of lists for each language +diff --git a/mineru/cli/common.py b/mineru/cli/common.py +index cd9f0803..3cdf7f21 100644 +--- a/mineru/cli/common.py ++++ b/mineru/cli/common.py +@@ -34,7 +34,13 @@ def read_fn(path): + + + def prepare_env(output_dir, pdf_file_name, parse_method): +- local_md_dir = str(os.path.join(output_dir, pdf_file_name, parse_method)) ++ print(pdf_file_name) ++ if len(pdf_file_name)>100: ++ pdf_file_name_ = pdf_file_name[:30] + pdf_file_name[-30:] ++ else: ++ pdf_file_name_ = pdf_file_name ++ ++ local_md_dir = str(os.path.join(output_dir, pdf_file_name_, parse_method)) + local_image_dir = os.path.join(str(local_md_dir), "images") + os.makedirs(local_image_dir, exist_ok=True) + os.makedirs(local_md_dir, exist_ok=True) diff --git a/mineru/model/layout/doclayout_yolo.py b/mineru/model/layout/doclayout_yolo.py -index 5667a909..fc5056bb +index 5667a909..fc5056bb 100644 --- a/mineru/model/layout/doclayout_yolo.py +++ b/mineru/model/layout/doclayout_yolo.py @@ -66,6 +66,7 @@ class DocLayoutYOLOModel: @@ -196,7 +238,7 @@ index 5667a909..fc5056bb for pred in predictions: results.append(self._parse_prediction(pred)) diff --git a/mineru/model/mfd/yolo_v8.py b/mineru/model/mfd/yolo_v8.py -index 33dac091..1fb4b50e +index 33dac091..1fb4b50e 100644 --- a/mineru/model/mfd/yolo_v8.py +++ b/mineru/model/mfd/yolo_v8.py @@ -31,7 +31,8 @@ class YOLOv8MFDModel: @@ -208,8 +250,9 @@ index 33dac091..1fb4b50e + half=True ) return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu() + diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py -index ae3879da..23e56f2a +index ae3879da..23e56f2a 100644 --- a/mineru/model/mfr/unimernet/Unimernet.py +++ b/mineru/model/mfr/unimernet/Unimernet.py @@ -1,7 +1,7 @@ @@ -218,13 +261,13 @@ index ae3879da..23e56f2a from tqdm import tqdm - +import numpy as np - + class MathDataset(Dataset): def __init__(self, image_paths, transform=None): @@ -61,7 +61,7 @@ class UnimernetModel(object): res["latex"] = latex return formula_list - + - def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: + def _batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: images_formula_list = [] 
@@ -232,7 +275,7 @@ index ae3879da..23e56f2a backfill_list = [] @@ -137,3 +137,94 @@ class UnimernetModel(object): res["latex"] = latex - + return images_formula_list + + @@ -326,7 +369,7 @@ index ae3879da..23e56f2a + + return images_formula_list diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py -index 98d1deee..2c9d8328 +index 98d1deee..3866a257 100644 --- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py @@ -5,7 +5,9 @@ import cv2 @@ -337,13 +380,13 @@ index 98d1deee..2c9d8328 +import torch +import torch_npu +import torch.nn.functional as F - + # TODO: dereference cv2 if possible class UnimerSwinImageProcessor(BaseImageProcessor): @@ -25,10 +27,53 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ] ) - + - def __call__(self, item): + self.NORMALIZE_DIVISOR = torch.tensor(255.0, dtype=torch.float16, device="npu") + self.weights = torch.tensor([[[0.2989]], [[0.5870]], [[0.1140]]], dtype=torch.float16, device="npu") @@ -357,7 +400,7 @@ index 98d1deee..2c9d8328 + def ___call__(self, item): image = self.prepare_input(item) return self.transform(image=image)['image'][:1] - + + def pil_to_npu(self, pil_img, device="npu"): + img = torch.from_numpy(np.asarray(pil_img, dtype=np.float16)) + img = img.to(device).permute(2, 0, 1) / self.NORMALIZE_DIVISOR @@ -395,10 +438,10 @@ index 98d1deee..2c9d8328 @staticmethod def crop_margin(img: Image.Image) -> Image.Image: data = np.array(img.convert("L")) -@@ -44,6 +89,34 @@ class UnimerSwinImageProcessor(BaseImageProcessor): +@@ -44,6 +89,32 @@ class UnimerSwinImageProcessor(BaseImageProcessor): a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box return img.crop((a, b, w + a, h + b)) - + + def crop_margin_tensor(self, img): + """ + img: [C,H,W] tensor, uint8 或 float @@ -406,8 +449,6 @@ index 98d1deee..2c9d8328 + + gray = (img * self.weights).sum(dim=0) + -+ -+ + gray = gray.to(torch.uint8) + max_val = gray.max() + min_val = gray.min() @@ -431,7 +472,7 @@ index 98d1deee..2c9d8328 def crop_margin_numpy(img: np.ndarray) -> np.ndarray: """Crop margins of image using NumPy operations""" diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py -index 1b808e8b..0fe54751 +index 1b808e8b..0fe54751 100644 --- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py @@ -465,11 +465,15 @@ class UnimerSwinSelfAttention(nn.Module): @@ -439,7 +480,7 @@ index 1b808e8b..0fe54751 ) -> Tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) - + - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) @@ -451,7 +492,7 @@ index 1b808e8b..0fe54751 + query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) + key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) + value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) - + # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py @@ -461,7 +502,7 @@ index 3de483ac..23813db9 100755 @@ -117,6 +117,10 @@ class TextDetector(BaseOCRV20): self.net.eval() self.net.to(self.device) - + + + import threading + self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) @@ -471,7 +512,7 @@ index 3de483ac..23813db9 100755 对相同尺寸的图像进行批处理 @@ -162,12 +166,12 @@ class TextDetector(BaseOCRV20): return batch_results, time.time() - starttime - + # 批处理推理 - with torch.no_grad(): - inp = torch.from_numpy(batch_tensor) @@ -484,14 +525,14 @@ index 3de483ac..23813db9 100755 + inp = torch.from_numpy(batch_tensor) + inp = inp.to(self.device) + outputs = self.net(inp) -+ # 处理输出 ++ # 处理输出 preds = {} if self.det_algorithm == "EAST": preds['f_geo'] = outputs['f_geo'].cpu().numpy() @@ -304,10 +308,11 @@ class TextDetector(BaseOCRV20): img = img.copy() starttime = time.time() - + - with torch.no_grad(): - inp = torch.from_numpy(img) - inp = inp.to(self.device) @@ -501,7 +542,7 @@ index 3de483ac..23813db9 100755 + inp = torch.from_numpy(img) + inp = inp.to(self.device) + outputs = self.net(inp) - + preds = {} if self.det_algorithm == "EAST": diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py @@ -511,7 +552,7 @@ index c06ca5fe..d865b201 100755 @@ -94,6 +94,9 @@ class TextRecognizer(BaseOCRV20): self.net.eval() self.net.to(self.device) - + + import threading + self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) + @@ -589,7 +630,7 @@ index c06ca5fe..d865b201 100755 - gsrm_slf_attn_bias1_list) - gsrm_slf_attn_bias2_list = np.concatenate( - gsrm_slf_attn_bias2_list) - + + # for beg_img_no in range(0, img_num, batch_num): + from concurrent.futures import ThreadPoolExecutor, as_completed + def _rec_batch_worker(beg_img_no: int, end_img_no: int): @@ -665,7 +706,7 @@ index c06ca5fe..d865b201 100755 inp = torch.from_numpy(norm_img_batch) encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list) @@ -384,58 +391,67 @@ class TextRecognizer(BaseOCRV20): - + backbone_out = self.net.backbone(inp) # backbone_feat prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp]) - # preds = {"predict": prob_out[2]} @@ -695,7 +736,7 @@ index c06ca5fe..d865b201 100755 inp = torch.from_numpy(norm_img_batch) inp = inp.to(self.device) preds = self.net(inp) - + - elif self.rec_algorithm == "CAN": - starttime = time.time() - norm_img_mask_batch = np.concatenate(norm_img_mask_batch) @@ -706,7 +747,7 @@ index c06ca5fe..d865b201 100755 + norm_img_mask_batch = np.concatenate(norm_img_mask_batch) + word_label_list = np.concatenate(word_label_list) + inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] - + - inp = [torch.from_numpy(e_i) for e_i in inputs] - inp = [e_i.to(self.device) for e_i in inp] + inp = [torch.from_numpy(e_i) for e_i in inputs] @@ -715,13 +756,13 @@ index c06ca5fe..d865b201 100755 with torch.no_grad(): outputs = self.net(inp) outputs = [v.cpu().numpy() for k, v in enumerate(outputs)] - + - preds = outputs - - else: - starttime = time.time() + preds = outputs - + + else: + with self._dev_lock: with torch.no_grad(): @@ -730,13 +771,13 @@ index c06ca5fe..d865b201 100755 + inp = torch.from_numpy(norm_img_batch).to(self.device) prob_out = self.net(inp) + preds = 
[v.cpu().numpy() for v in prob_out] if isinstance(prob_out, list) else prob_out.cpu().numpy() - + - if isinstance(prob_out, list): - preds = [v.cpu().numpy() for v in prob_out] - else: - preds = prob_out.cpu().numpy() + rec_result = self.postprocess_op(preds) - + - rec_result = self.postprocess_op(preds) - for rno in range(len(rec_result)): - rec_res[indices[beg_img_no + rno]] = rec_result[rno] @@ -756,7 +797,7 @@ index c06ca5fe..d865b201 100755 + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + futures.append(ex.submit(_rec_batch_worker, beg_img_no, end_img_no)) - + - # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size - current_batch_size = min(batch_num, img_num - index * batch_num) - index += 1 @@ -765,30 +806,30 @@ index c06ca5fe..d865b201 100755 + n_done, batch_elapse = fut.result() + elapse += batch_elapse + pbar.update(n_done) - + # Fix NaN values in recognition results for i in range(len(rec_res)): diff --git a/mineru/model/table/rapid_table.py b/mineru/model/table/rapid_table.py -index 174a8052..dd796bcc +index 174a8052..dd796bcc 100644 --- a/mineru/model/table/rapid_table.py +++ b/mineru/model/table/rapid_table.py @@ -21,6 +21,8 @@ class RapidTableModel(object): self.table_model = RapidTable(input_args) self.ocr_engine = ocr_engine - + + import threading + self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) - + def predict(self, image): bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) @@ -30,44 +32,45 @@ class RapidTableModel(object): img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0 img_is_portrait = img_aspect_ratio > 1.2 - + - if img_is_portrait: + with self._dev_lock: + if img_is_portrait: - + - det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] - # Check if table is rotated by analyzing text box aspect ratios - is_rotated = False @@ -799,22 +840,22 @@ index 174a8052..dd796bcc + is_rotated = False + if det_res: + vertical_count = 0 - + - for box_ocr_res in det_res: - p1, p2, p3, p4 = box_ocr_res + for box_ocr_res in det_res: + p1, p2, p3, p4 = box_ocr_res - + - # Calculate width and height - width = p3[0] - p1[0] - height = p3[1] - p1[1] + # Calculate width and height + width = p3[0] - p1[0] + height = p3[1] - p1[1] - + - aspect_ratio = width / height if height > 0 else 1.0 + aspect_ratio = width / height if height > 0 else 1.0 - + - # Count vertical vs horizontal text boxes - if aspect_ratio < 0.8: # Taller than wide - vertical text - vertical_count += 1 @@ -825,7 +866,7 @@ index 174a8052..dd796bcc + vertical_count += 1 + # elif aspect_ratio > 1.2: # Wider than tall - horizontal text + # horizontal_count += 1 - + - # If we have more vertical text boxes than horizontal ones, - # and vertical ones are significant, table might be rotated - if vertical_count >= len(det_res) * 0.3: @@ -834,10 +875,10 @@ index 174a8052..dd796bcc + # and vertical ones are significant, table might be rotated + if vertical_count >= len(det_res) * 0.3: + is_rotated = True - + - # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") + # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") - + - # Rotate image if necessary - if is_rotated: - # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") @@ -848,7 +889,7 @@ index 174a8052..dd796bcc + # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") + image = 
cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) + bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - + - # Continue with OCR on potentially rotated image - ocr_result = self.ocr_engine.ocr(bgr_image)[0] + # Continue with OCR on potentially rotated image -- Gitee From b4dd7c838e423bf0a9e075255295923d1d147e1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 21 Oct 2025 22:28:01 +0800 Subject: [PATCH 10/20] 1 --- .../built-in/ocr/MinerU/ultralytics.patch | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch index 2647c7695d..4baf8c7b15 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch @@ -1,5 +1,5 @@ -diff -ruN ultralytics-8.3.193/ultralytics/data/loaders.py ultralytics/data/loaders.py ---- ultralytics-8.3.193/ultralytics/data/loaders.py 2025-09-04 19:51:11.000000000 +0800 +diff -ruN ultralytics/data/loaders.py ultralytics/data/loaders.py +--- ultralytics/data/loaders.py 2025-09-04 19:51:11.000000000 +0800 +++ ultralytics/data/loaders.py 2025-10-19 01:27:48.412000000 +0800 @@ -534,7 +534,7 @@ self.bs = len(self.im0) @@ -30,8 +30,8 @@ diff -ruN ultralytics-8.3.193/ultralytics/data/loaders.py ultralytics/data/loade def __len__(self) -> int: """Return the length of the 'im0' attribute, representing the number of loaded images.""" return len(self.im0) -diff -ruN ultralytics-8.3.193/ultralytics/engine/model.py ultralytics/engine/model.py ---- ultralytics-8.3.193/ultralytics/engine/model.py 2025-09-04 19:51:11.000000000 +0800 +diff -ruN ultralytics/engine/model.py ultralytics/engine/model.py +--- ultralytics/engine/model.py 2025-09-04 19:51:11.000000000 +0800 +++ ultralytics/engine/model.py 2025-10-19 01:27:48.412000000 +0800 @@ -152,6 +152,8 @@ else: @@ -42,8 +42,8 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/model.py ultralytics/engine/mod # Delete super().training for accessing self.model.training del self.training -diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics/engine/predictor.py ---- ultralytics-8.3.193/ultralytics/engine/predictor.py 2025-09-04 19:51:11.000000000 +0800 +diff -ruN ultralytics/engine/predictor.py ultralytics/engine/predictor.py +--- ultralytics/engine/predictor.py 2025-09-04 19:51:11.000000000 +0800 +++ ultralytics/engine/predictor.py 2025-10-19 01:27:48.412000000 +0800 @@ -43,6 +43,7 @@ import cv2 @@ -150,8 +150,9 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics/engine + fuse=False, verbose=verbose, ) -diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics/nn/modules/block.py ---- ultralytics-8.3.193/ultralytics/nn/modules/block.py 2025-09-04 19:51:11.000000000 +0800 + +diff -ruN ultralytics/nn/modules/block.py ultralytics/nn/modules/block.py +--- ultralytics/nn/modules/block.py 2025-09-04 19:51:11.000000000 +0800 +++ ultralytics/nn/modules/block.py 2025-10-19 01:27:48.424000000 +0800 @@ -237,7 +237,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -175,8 +176,8 @@ diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics/nn/mod return self.cv2(torch.cat(y, 1)) def forward_split(self, x: torch.Tensor) -> torch.Tensor: -diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics/utils/tal.py ---- ultralytics-8.3.193/ultralytics/utils/tal.py 2025-09-04 19:51:11.000000000 +0800 +diff -ruN ultralytics/utils/tal.py 
ultralytics/utils/tal.py +--- ultralytics/utils/tal.py 2025-09-04 19:51:11.000000000 +0800 +++ ultralytics/utils/tal.py 2025-10-19 01:27:48.428000000 +0800 @@ -375,7 +375,8 @@ sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y -- Gitee From 0bdf7ddd1717965e334f65ea0a64a1743ef05c27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Mon, 27 Oct 2025 16:15:11 +0800 Subject: [PATCH 11/20] modify dataset validation metric --- ACL_PyTorch/built-in/ocr/MinerU/README.md | 69 ++++++++++++++++--- .../built-in/ocr/MinerU/overall_metric.py | 48 +++++++++++++ 2 files changed, 109 insertions(+), 8 deletions(-) create mode 100644 ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md index af345076b8..253b0ea2d6 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/README.md +++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md @@ -44,6 +44,8 @@ MinerU是由上海人工智能实验室OpenDataLab团队开发的开源文档解 1. 获取`Pytorch`源码 ``` + git clone https://gitee.com/ascend/ModelZoo-PyTorch.git + cd ModelZoo-PyTorch/ACL_PyTorch/built-in/ocr/MinerU git clone https://github.com/opendatalab/MinerU.git cd MinerU git reset --hard de41fa58590263e43b783fe224b6d07cae290a33 @@ -155,33 +157,76 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local 1. 推理结果整理 - 将解析结果文件夹中的markdown文件整理放置于同一目录,本例将所有markdown文件存放于OmniDocBench_dataset目录下的results_md文件夹 + 将解析结果文件夹中的markdown文件整理放置于同一目录,本例将所有markdown文件存放于OmniDocBench_dataset目录下的`end2end`文件夹 ``` - cp OmniDocBench_dataset/output/*/auto/*.md OmniDocBench_dataset/results_md/ + cp OmniDocBench_dataset/output/*/auto/*.md OmniDocBench_dataset/end2end/ ``` 2. 获取测评源码并构建环境 + + - 安装OmniDocBench基础环境 ``` git clone https://github.com/opendatalab/OmniDocBench.git cd OmniDocBench - git reset --hard dc96d812d219960773399c02ae8f89e4706120d4 + git reset --hard 523fd1d529c3e9d0088c662e983aa70fb9585c9a conda create -n omnidocbench python=3.10 conda activate omnidocbench pip install -r requirements.txt ``` + - 公式精度指标CDM需要额外安装环境 + + step.1 install nodejs + ``` + wget https://nodejs.org/dist/v16.13.1/node-v16.13.1-linux-arm64.tar.xz + tar -xf node-v16.13.1-linux-arm64.tar.xz + mv node-v16.13.1-linux-arm64/* /usr/local/nodejs/ + ln -s /usr/local/nodejs/bin/node /usr/local/bin + ln -s /usr/local/nodejs/bin/npm /usr/local/bin + node -v + ``` + + step.2 install imagemagic + ``` + git clone https://github.com/ImageMagick/ImageMagick.git ImageMagick-7.1.2 + cd ImageMagick-7.1.2 + apt-get update && apt-get install -y libpng-dev zlib1g-dev + apt-get install -y ghostscript + ./configure + make + sudo make install + sudo ldconfig /usr/local/lib + convert --version + ``` + + step.3 install latexpdf + ``` + sudo apt-get install texlive-full + ``` + + step.4 install python requriements + ``` + pip install -r requirements.txt + ``` + 3. 测评配置修改 修改`OmniDocBench`测评代码中的config文件,具体来说,我们使用端到端测评配置,修改configs/end2end.yaml文件中的ground_truth的data_path为下载的OmniDocBench.json路径,修改prediction的data_path中提供整理的推理结果的文件夹路径,如下: ``` # -----以下是需要修改的部分 ----- + display_formula: + metric: + - Edit_dist + - CDM ### 安装好CDM环境后,可以在config文件中设置并直接计算 + - CDM_plain + ... dataset: dataset_name: end2end_dataset ground_truth: data_path: ../OmniDocBench_dataset/OmniDocBench.json prediction: - data_path: ../OmniDocBench_dataset/results_md + data_path: ../OmniDocBench_dataset/end2end ``` 4. 
精度测量结果 @@ -190,10 +235,18 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local ``` python pdf_validation.py --config ./configs/end2end.yaml ``` + 评测结果将会存储在result目录下,Overall指标的计算方式为: + $$\text{Overall} = \frac{(1-\textit{Text Edit Distance}) \times 100 + \textit{Table TEDS} +\textit{Formula CDM}}{3}$$ + + 运行overall_metric.py可以得到精度结果: + ``` + cd .. + python overall_metric.py + ``` - 在`OmniDocBench`数据集上的精度为: - |模型|芯片|overall_EN|overall_CH| + 在`OmniDocBench`数据集上的精度和性能数据分别为: + |模型|芯片|overall|性能| |------|------|------|------| - |MinerU|300I DUO|0.1588|0.2527| - |MinerU|800I A2 64G|0.1580|0.2510| + |MinerU|300I DUO||------| + |MinerU|800I A2 64G|81.51|------| diff --git a/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py new file mode 100644 index 0000000000..74ef8f8f39 --- /dev/null +++ b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py @@ -0,0 +1,48 @@ +import os +import pandas as pd +import numpy as np +import json +import argparse + +parser = argparse.ArgumentParser(description='result path') +parser.add_argument('--result', type=str, default='OmniDocBench/result') +args = parser.parse_args(args) + + +ocr_types_dict = { + 'end2end': 'end2end' +} + +result_folder = args.result + +# match_name = 'no_split' +match_name = 'quick_match' + +# overall result: not distinguishing between Chinese and English, page-level average + +dict_list = [] + +for ocr_type in ocr_types_dict.values(): + result_path = os.path.join(result_folder, f'{ocr_type}_{match_name}_metric_result.json') + + with open(result_path, 'r') as f: + result = json.load(f) + + save_dict = {} + + for category_type, metric in [("text_block", "Edit_dist"), ("display_formula", "CDM"), ("table", "TEDS"), ("table", "TEDS_structure_only"), ("reading_order", "Edit_dist")]: + if metric == 'CDM' or metric == "TEDS" or metric == "TEDS_structure_only": + if result[category_type]["page"].get(metric): + save_dict[category_type+'_'+metric] = result[category_type]["page"][metric]["ALL"] * 100 # page级别的avg + else: + save_dict[category_type+'_'+metric] = 0 + else: + save_dict[category_type+'_'+metric] = result[category_type]["all"][metric].get("ALL_page_avg", np.nan) + + dict_list.append(save_dict) + +df = pd.DataFrame(dict_list, index=ocr_types_dict.keys()).round(3) +df['overall'] = ((1-df['text_block_Edit_dist'])*100 + df['display_formula_CDM'] + df['table_TEDS'])/3 +# df.to_csv('./overall.csv') + +print(df) -- Gitee From b22f99ef4a2bd2b7c9bca8ad13ed1eba2a0e5927 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Mon, 27 Oct 2025 16:19:09 +0800 Subject: [PATCH 12/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md index 253b0ea2d6..20610724af 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/README.md +++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md @@ -75,11 +75,12 @@ MinerU是由上海人工智能实验室OpenDataLab团队开发的开源文档解 ``` source_path=/usr/local/lib/python3.11/site-packages cd ${source_path}/ultralytics - patch -p2 < ${workdir}/ultralytics.patch + patch -p1 < ${workdir}/ultralytics.patch cd ${source_path}/doclayout_yolo - patch -p2 < ${workdir}/doclayout_yolo.patch - cd ${workdir} - patch -p0 < mfr_encoder_mhsa.patch + patch -p1 < ${workdir}/doclayout_yolo.patch + cd ${workdir}/MinerU + git apply ../mineru.patch + cd .. 
``` ## 获取权重 -- Gitee From ba4d57f12be6ef902c10665f720d1a0982a35241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 28 Oct 2025 15:28:14 +0800 Subject: [PATCH 13/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md index 20610724af..cf6a432c11 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/README.md +++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md @@ -246,8 +246,8 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local ``` 在`OmniDocBench`数据集上的精度和性能数据分别为: - |模型|芯片|overall|性能| + |模型|芯片|overall|性能(s)| |------|------|------|------| - |MinerU|300I DUO||------| - |MinerU|800I A2 64G|81.51|------| + |MinerU|300I DUO|81.68| 3.37 | + |MinerU|800I A2 64G|81.51| 1.85 | -- Gitee From d0914a1257bc41aacb2b1744d931313121adb74e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 28 Oct 2025 15:32:45 +0800 Subject: [PATCH 14/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md index cf6a432c11..cc0e765fb1 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/README.md +++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md @@ -208,7 +208,7 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local step.4 install python requriements ``` - pip install -r requirements.txt + pip install -r metrics/cdm/requirements.txt ``` 3. 测评配置修改 -- Gitee From 7718b058b9acea2916e0343260390a60024195de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 28 Oct 2025 17:13:04 +0800 Subject: [PATCH 15/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py index 74ef8f8f39..127306e56c 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py +++ b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py @@ -6,7 +6,7 @@ import argparse parser = argparse.ArgumentParser(description='result path') parser.add_argument('--result', type=str, default='OmniDocBench/result') -args = parser.parse_args(args) +args = parser.parse_args() ocr_types_dict = { -- Gitee From 35b5a8d7356605a2c0dd765743c1c18825b77527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 28 Oct 2025 17:29:12 +0800 Subject: [PATCH 16/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py index 127306e56c..5d89b41693 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py +++ b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py @@ -1,9 +1,10 @@ import os -import pandas as pd -import numpy as np import json import argparse +import pandas as pd +import numpy as np + parser = argparse.ArgumentParser(description='result path') parser.add_argument('--result', type=str, default='OmniDocBench/result') args = parser.parse_args() @@ -15,7 +16,6 @@ ocr_types_dict = { result_folder = args.result -# match_name = 'no_split' match_name = 'quick_match' # overall result: not distinguishing between Chinese and English, page-level average @@ -33,16 +33,15 @@ for ocr_type in 
ocr_types_dict.values(): for category_type, metric in [("text_block", "Edit_dist"), ("display_formula", "CDM"), ("table", "TEDS"), ("table", "TEDS_structure_only"), ("reading_order", "Edit_dist")]: if metric == 'CDM' or metric == "TEDS" or metric == "TEDS_structure_only": if result[category_type]["page"].get(metric): - save_dict[category_type+'_'+metric] = result[category_type]["page"][metric]["ALL"] * 100 # page级别的avg + save_dict[category_type + '_' + metric] = result[category_type]["page"][metric]["ALL"] * 100 # page级别的avg else: - save_dict[category_type+'_'+metric] = 0 + save_dict[category_type + '_' + metric] = 0 else: - save_dict[category_type+'_'+metric] = result[category_type]["all"][metric].get("ALL_page_avg", np.nan) + save_dict[category_type + '_' + metric] = result[category_type]["all"][metric].get("ALL_page_avg", np.nan) dict_list.append(save_dict) df = pd.DataFrame(dict_list, index=ocr_types_dict.keys()).round(3) -df['overall'] = ((1-df['text_block_Edit_dist'])*100 + df['display_formula_CDM'] + df['table_TEDS'])/3 -# df.to_csv('./overall.csv') +df['overall'] = ((1 - df['text_block_Edit_dist']) * 100 + df['display_formula_CDM'] + df['table_TEDS']) / 3 print(df) -- Gitee From 9a97ed74909a55cad59974fcddaae41ec6c17db1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 28 Oct 2025 18:41:54 +0800 Subject: [PATCH 17/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/mineru.patch | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch index 0e57a9c2c0..24e709cd61 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch @@ -206,25 +206,6 @@ index c88a52a3..b0b79a80 100644 # Create dictionaries to store items by language need_ocr_lists_by_lang = {} # Dict of lists for each language img_crop_lists_by_lang = {} # Dict of lists for each language -diff --git a/mineru/cli/common.py b/mineru/cli/common.py -index cd9f0803..3cdf7f21 100644 ---- a/mineru/cli/common.py -+++ b/mineru/cli/common.py -@@ -34,7 +34,13 @@ def read_fn(path): - - - def prepare_env(output_dir, pdf_file_name, parse_method): -- local_md_dir = str(os.path.join(output_dir, pdf_file_name, parse_method)) -+ print(pdf_file_name) -+ if len(pdf_file_name)>100: -+ pdf_file_name_ = pdf_file_name[:30] + pdf_file_name[-30:] -+ else: -+ pdf_file_name_ = pdf_file_name -+ -+ local_md_dir = str(os.path.join(output_dir, pdf_file_name_, parse_method)) - local_image_dir = os.path.join(str(local_md_dir), "images") - os.makedirs(local_image_dir, exist_ok=True) - os.makedirs(local_md_dir, exist_ok=True) diff --git a/mineru/model/layout/doclayout_yolo.py b/mineru/model/layout/doclayout_yolo.py index 5667a909..fc5056bb 100644 --- a/mineru/model/layout/doclayout_yolo.py -- Gitee From 2de1ff88b542ded015e975d5d9bc330465e28d05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Tue, 28 Oct 2025 19:04:33 +0800 Subject: [PATCH 18/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md index cc0e765fb1..00f78f10c4 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/README.md +++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md @@ -73,7 +73,8 @@ MinerU是由上海人工智能实验室OpenDataLab团队开发的开源文档解 3. 
修改第三方库 进入第三方库安装路径,默认为`source_path = /usr/local/lib/python3.11/site-packages`,通过工作目录`workdir`(自定义)中的`ultralytics.patch`和`doclayout_yolo.patch`进行修改 ``` - source_path=/usr/local/lib/python3.11/site-packages + workdir=$(pwd) + source_path=$(pip show ultralytics | grep Location | awk '{print $2}') cd ${source_path}/ultralytics patch -p1 < ${workdir}/ultralytics.patch cd ${source_path}/doclayout_yolo -- Gitee From c51215f40bb941924dbf25e680ec77f2d8a734c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Wed, 29 Oct 2025 09:20:51 +0800 Subject: [PATCH 19/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/mineru.patch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch index 24e709cd61..c031d1413a 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch @@ -56,7 +56,7 @@ index c88a52a3..b0b79a80 100644 # 将尺寸标准化到32的倍数 - normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数 - normalized_w = ((w + 32) // 32) * 32 -+ normalized_h = ((h + stride) // stride) * stride # 向上取整到32的倍数 ++ normalized_h = ((h + stride) // stride) * stride # 向上取整到stride的倍数 + normalized_w = ((w + stride) // stride) * stride group_key = (normalized_h, normalized_w) resolution_groups[group_key].append(crop_info) -- Gitee From 1a2bd6a4229518cb30720c44d6b9c75876ca6fb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E4=BA=A6=E8=88=9F?= Date: Wed, 29 Oct 2025 10:39:20 +0800 Subject: [PATCH 20/20] 1 --- ACL_PyTorch/built-in/ocr/MinerU/mineru.patch | 881 ------------------- 1 file changed, 881 deletions(-) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch index c031d1413a..e69de29bb2 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch @@ -1,881 +0,0 @@ -diff --git a/demo/demo.py b/demo/demo.py -index 36433c45..6f28620f 100644 ---- a/demo/demo.py -+++ b/demo/demo.py -@@ -86,7 +86,7 @@ def do_parse( - image_dir = str(os.path.basename(local_image_dir)) - content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) - md_writer.write_string( -- f"{pdf_file_name}_content_list.json", -+ f"{pdf_file_name}_content.json", - json.dumps(content_list, ensure_ascii=False, indent=4), - ) - -@@ -142,7 +142,8 @@ def do_parse( - image_dir = str(os.path.basename(local_image_dir)) - content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) - md_writer.write_string( -- f"{pdf_file_name}_content_list.json", -+ # f"{pdf_file_name}_content_list.json", -+ f"{pdf_file_name}_content.json", ## 文件名太长了,linux文件系统ext4超过255字节无法保存 - json.dumps(content_list, ensure_ascii=False, indent=4), - ) - -diff --git a/mineru/backend/pipeline/batch_analyze.py b/mineru/backend/pipeline/batch_analyze.py -index c88a52a3..b0b79a80 100644 ---- a/mineru/backend/pipeline/batch_analyze.py -+++ b/mineru/backend/pipeline/batch_analyze.py -@@ -3,6 +3,9 @@ from loguru import logger - from tqdm import tqdm - from collections import defaultdict - import numpy as np -+import time -+import torch -+import torch_npu - - from .model_init import AtomModelSingleton - from ...utils.config_reader import get_formula_enable, get_table_enable -@@ -95,6 +98,7 @@ class BatchAnalyze: - }) - - # OCR检测处理 -+ from concurrent.futures import ThreadPoolExecutor, as_completed - if self.enable_ocr_det_batch: - # 批处理模式 - 按语言和分辨率分组 - # 收集所有需要OCR检测的裁剪图像 -@@ -139,79 +143,73 @@ class BatchAnalyze: - ) - - # 
按分辨率分组并同时完成padding -+ stride = 64 - resolution_groups = defaultdict(list) - for crop_info in lang_crop_list: - cropped_img = crop_info[0] - h, w = cropped_img.shape[:2] - # 使用更大的分组容差,减少分组数量 - # 将尺寸标准化到32的倍数 -- normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数 -- normalized_w = ((w + 32) // 32) * 32 -+ normalized_h = ((h + stride) // stride) * stride # 向上取整到stride的倍数 -+ normalized_w = ((w + stride) // stride) * stride - group_key = (normalized_h, normalized_w) - resolution_groups[group_key].append(crop_info) - -- # 对每个分辨率组进行批处理 -- for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"): -- -- # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数) -- max_h = max(crop_info[0].shape[0] for crop_info in group_crops) -- max_w = max(crop_info[0].shape[1] for crop_info in group_crops) -- target_h = ((max_h + 32 - 1) // 32) * 32 -- target_w = ((max_w + 32 - 1) // 32) * 32 -- -- # 对所有图像进行padding到统一尺寸 -- batch_images = [] -- for crop_info in group_crops: -- img = crop_info[0] -- h, w = img.shape[:2] -- # 创建目标尺寸的白色背景 -- padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 -- # 将原图像粘贴到左上角 -- padded_img[:h, :w] = img -- batch_images.append(padded_img) -- -- # 批处理检测 -- det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) # 增加批处理大小 -- # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}") -- batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size) -- -- # 处理批处理结果 -- for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): -- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info -- -- if dt_boxes is not None and len(dt_boxes) > 0: -- # 直接应用原始OCR流程中的关键处理步骤 -- from mineru.utils.ocr_utils import ( -- merge_det_boxes, update_det_boxes, sorted_boxes -- ) - -- # 1. 排序检测框 -- if len(dt_boxes) > 0: -- dt_boxes_sorted = sorted_boxes(dt_boxes) -- else: -- dt_boxes_sorted = [] -- -- # 2. 合并相邻检测框 -- if dt_boxes_sorted: -- dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) -- else: -- dt_boxes_merged = [] -- -- # 3. 根据公式位置更新检测框(关键步骤!) 
-- if dt_boxes_merged and adjusted_mfdetrec_res: -- dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) -- else: -- dt_boxes_final = dt_boxes_merged -- -- # 构造OCR结果格式 -- ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] -- -- if ocr_res: -- ocr_result_list = get_ocr_result_list( -- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang -- ) -- -- ocr_res_list_dict['layout_res'].extend(ocr_result_list) -+ def _run_one_group_ocr(group_key, group_crops): -+ -+ max_h = max(ci[0].shape[0] for ci in group_crops) -+ max_w = max(ci[0].shape[1] for ci in group_crops) -+ target_h = ((max_h + stride - 1) // stride) * stride -+ target_w = ((max_w + stride - 1) // stride) * stride -+ -+ batch_images = [] -+ for ci in group_crops: -+ img = ci[0] -+ h, w = img.shape[:2] -+ padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 -+ padded_img[:h, :w] = img -+ batch_images.append(padded_img) -+ -+ det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) -+ -+ batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size) -+ -+ for i, (ci, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): -+ new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = ci -+ if dt_boxes is not None and len(dt_boxes) > 0: -+ from mineru.utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes -+ -+ if len(dt_boxes) > 0: -+ dt_boxes_sorted = sorted_boxes(dt_boxes) -+ else: -+ dt_boxes_sorted = [] -+ -+ if dt_boxes_sorted: -+ dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) -+ else: -+ dt_boxes_merged = [] -+ -+ if dt_boxes_merged and adjusted_mfdetrec_res: -+ dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) -+ else: -+ dt_boxes_final = dt_boxes_merged -+ -+ ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] -+ if ocr_res: -+ ocr_result_list = get_ocr_result_list( -+ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang -+ ) -+ ocr_res_list_dict['layout_res'].extend(ocr_result_list) -+ -+ MAX_WORKERS = 4 -+ start = time.time() -+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: -+ futures = [ex.submit(_run_one_group_ocr, gk, gcs) for gk, gcs in resolution_groups.items()] -+ for f in as_completed(futures): -+ f.result() -+ end = time.time() -+ logger.info(f"ocr det run time : {end -start}") - else: - # 原始单张处理模式 - for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"): -@@ -247,7 +245,7 @@ class BatchAnalyze: - - # 表格识别 table recognition - if self.table_enable: -- for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"): -+ def _run_one_group_table(table_res_dict): - _lang = table_res_dict['lang'] - table_model = atom_model_manager.get_atom_model( - atom_model_name='table', -@@ -271,6 +269,16 @@ class BatchAnalyze: - 'table recognition processing fails, not get html return' - ) - -+ -+ MAX_WORKERS = 4 -+ start = time.time() -+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: -+ futures = [ex.submit(_run_one_group_table, table_res_dict) for table_res_dict in table_res_list_all_page] -+ for f in as_completed(futures): -+ f.result() -+ end = time.time() -+ logger.info(f"table run time : {end - start}") -+ - # Create dictionaries to store items by language - need_ocr_lists_by_lang = {} # Dict of lists for each language - img_crop_lists_by_lang = {} # Dict of lists for each language -diff --git 
a/mineru/model/layout/doclayout_yolo.py b/mineru/model/layout/doclayout_yolo.py -index 5667a909..fc5056bb 100644 ---- a/mineru/model/layout/doclayout_yolo.py -+++ b/mineru/model/layout/doclayout_yolo.py -@@ -66,6 +66,7 @@ class DocLayoutYOLOModel: - conf=self.conf, - iou=self.iou, - verbose=False, -+ half=True - ) - for pred in predictions: - results.append(self._parse_prediction(pred)) -diff --git a/mineru/model/mfd/yolo_v8.py b/mineru/model/mfd/yolo_v8.py -index 33dac091..1fb4b50e 100644 ---- a/mineru/model/mfd/yolo_v8.py -+++ b/mineru/model/mfd/yolo_v8.py -@@ -31,7 +31,8 @@ class YOLOv8MFDModel: - conf=self.conf, - iou=self.iou, - verbose=False, -- device=self.device -+ device=self.device, -+ half=True - ) - return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu() - -diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py -index ae3879da..23e56f2a 100644 ---- a/mineru/model/mfr/unimernet/Unimernet.py -+++ b/mineru/model/mfr/unimernet/Unimernet.py -@@ -1,7 +1,7 @@ - import torch - from torch.utils.data import DataLoader, Dataset - from tqdm import tqdm -- -+import numpy as np - - class MathDataset(Dataset): - def __init__(self, image_paths, transform=None): -@@ -61,7 +61,7 @@ class UnimernetModel(object): - res["latex"] = latex - return formula_list - -- def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: -+ def _batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: - images_formula_list = [] - mf_image_list = [] - backfill_list = [] -@@ -137,3 +137,94 @@ class UnimernetModel(object): - res["latex"] = latex - - return images_formula_list -+ -+ -+ def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: -+ -+ images_formula_list = [] -+ mf_image_list = [] -+ backfill_list = [] -+ image_info = [] # Store (area, original_index, image) tuples -+ -+ # Collect images with their original indices -+ for image_index in range(len(images_mfd_res)): -+ mfd_res = images_mfd_res[image_index] -+ pil_img = images[image_index] -+ # split代替多次索引 -+ data = mfd_res.boxes.data.numpy() -+ xyxy, conf, cla = np.split(data, [4, 5], axis=-1) -+ -+ cla = cla.reshape(-1).astype(int).tolist() -+ conf = np.round(conf.reshape(-1).astype(float), 2).tolist() -+ -+ xyxy = xyxy.astype(np.int32) -+ xmin, ymin, xmax, ymax = xyxy[:, 0], xyxy[:, 1], xyxy[:, 2], xyxy[:, 3] -+ # area 直接矩阵运算 -+ areas = (xmax - xmin) * (ymax - ymin) -+ -+ num_boxes = len(conf) -+ -+ formula_list = [] -+ for i in range(num_boxes): -+ xmin_i, ymin_i, xmax_i, ymax_i = xyxy[i].tolist() -+ formula_list.append({ -+ "category_id": 13 + cla[i], -+ "poly": [xmin_i, ymin_i, xmax_i, ymin_i, -+ xmax_i, ymax_i, xmin_i, ymax_i], -+ "score": conf[i], -+ "latex": "", -+ }) -+ -+ # bbox_img 截取 -+ # bbox_img = pil_img[:, ymin_i:ymax_i, xmin_i:xmax_i] -+ bbox_img = pil_img.crop((xmin_i, ymin_i, xmax_i, ymax_i)) -+ curr_idx = len(mf_image_list) -+ image_info.append((areas[i], curr_idx, bbox_img)) -+ mf_image_list.append(bbox_img) -+ -+ images_formula_list.append(formula_list) -+ backfill_list += formula_list -+ -+ # Stable sort by area -+ image_info.sort(key=lambda x: x[0]) # sort by area -+ sorted_indices = [x[1] for x in image_info] -+ sorted_images = [x[2] for x in image_info] -+ -+ # Create mapping for results -+ index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)} -+ -+ # Create dataset with sorted images -+ dataset = MathDataset(sorted_images, transform=self.model.transform) 
-+ -+ # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂 -+ batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1 -+ -+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) -+ -+ # Process batches and store results -+ mfr_res = [] -+ # for mf_img in dataloader: -+ -+ with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar: -+ for index, mf_img in enumerate(dataloader): -+ mf_img = mf_img.to(dtype=self.model.dtype) -+ mf_img = mf_img.to(self.device) -+ with torch.no_grad(): -+ output = self.model.generate({"image": mf_img}, batch_size=batch_size) -+ mfr_res.extend(output["fixed_str"]) -+ -+ # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size -+ current_batch_size = min(batch_size, len(sorted_images) - index * batch_size) -+ pbar.update(current_batch_size) -+ -+ # Restore original order -+ unsorted_results = [""] * len(mfr_res) -+ for new_idx, latex in enumerate(mfr_res): -+ original_idx = index_mapping[new_idx] -+ unsorted_results[original_idx] = latex -+ -+ # Fill results back -+ for res, latex in zip(backfill_list, unsorted_results): -+ res["latex"] = latex -+ -+ return images_formula_list -diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py -index 98d1deee..3866a257 100644 ---- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py -+++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py -@@ -5,7 +5,9 @@ import cv2 - import albumentations as alb - from albumentations.pytorch import ToTensorV2 - from torchvision.transforms.functional import resize -- -+import torch -+import torch_npu -+import torch.nn.functional as F - - # TODO: dereference cv2 if possible - class UnimerSwinImageProcessor(BaseImageProcessor): -@@ -25,10 +27,53 @@ class UnimerSwinImageProcessor(BaseImageProcessor): - ] - ) - -- def __call__(self, item): -+ self.NORMALIZE_DIVISOR = torch.tensor(255.0, dtype=torch.float16, device="npu") -+ self.weights = torch.tensor([[[0.2989]], [[0.5870]], [[0.1140]]], dtype=torch.float16, device="npu") -+ self.mean = torch.tensor(0.7931, dtype=torch.float16, device="npu") -+ self.std = torch.tensor(0.1738, dtype=torch.float16, device="npu") -+ -+ self._mul_buf = torch.empty((3, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [3,H,W] -+ self._gray_buf = torch.empty((1, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [1,H,W] -+ -+ -+ def ___call__(self, item): - image = self.prepare_input(item) - return self.transform(image=image)['image'][:1] - -+ def pil_to_npu(self, pil_img, device="npu"): -+ img = torch.from_numpy(np.asarray(pil_img, dtype=np.float16)) -+ img = img.to(device).permute(2, 0, 1) / self.NORMALIZE_DIVISOR -+ return img -+ -+ def __call__(self, item): -+ -+ img = self.crop_margin(item) -+ img = self.pil_to_npu(img) -+ -+ _, h, w = img.shape -+ target_h, target_w = self.input_size -+ scale = min(target_h / h, target_w / w) -+ new_h, new_w = int(h*scale), int(w*scale) -+ -+ img = img.view(1, *img.shape) # [1,C,H,W] -+ img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False) -+ img = img.view(*img.shape[1:]) -+ -+ dw, dh = target_w - new_w, target_h - new_h -+ dw /= 2 -+ dh /= 2 -+ left, right = int(dw), int(dw + 0.5) -+ top, bottom = int(dh), int(dh + 0.5) -+ img = F.pad(img, (left, right, top, bottom), value=0.0) -+ -+ # RGB -> Gray -+ 
gray_tensor = (img * self.weights).sum(dim=0, keepdim=True) # [1, H, W] -+ -+ # Normalize -+ gray_tensor.sub_(self.mean).div_(self.std) -+ return gray_tensor -+ -+ - @staticmethod - def crop_margin(img: Image.Image) -> Image.Image: - data = np.array(img.convert("L")) -@@ -44,6 +89,32 @@ class UnimerSwinImageProcessor(BaseImageProcessor): - a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box - return img.crop((a, b, w + a, h + b)) - -+ def crop_margin_tensor(self, img): -+ """ -+ img: [C,H,W] tensor, uint8 或 float -+ """ -+ -+ gray = (img * self.weights).sum(dim=0) -+ -+ gray = gray.to(torch.uint8) -+ max_val = gray.max() -+ min_val = gray.min() -+ -+ if max_val == min_val: -+ return img -+ -+ norm_gray = (gray - min_val) / (max_val - min_val) -+ -+ mask = (norm_gray < self.threshold) -+ -+ coords = mask.nonzero(as_tuple=False) -+ if coords.shape[0] == 0: -+ return img -+ ymin, xmin = coords.min(0)[0] -+ ymax, xmax = coords.max(0)[0] -+ -+ return img[:, ymin:ymax+1, xmin:xmax+1] -+ - @staticmethod - def crop_margin_numpy(img: np.ndarray) -> np.ndarray: - """Crop margins of image using NumPy operations""" -diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py -index 1b808e8b..0fe54751 100644 ---- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py -+++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py -@@ -465,11 +465,15 @@ class UnimerSwinSelfAttention(nn.Module): - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - batch_size, dim, num_channels = hidden_states.shape -- mixed_query_layer = self.query(hidden_states) - -- key_layer = self.transpose_for_scores(self.key(hidden_states)) -- value_layer = self.transpose_for_scores(self.value(hidden_states)) -- query_layer = self.transpose_for_scores(mixed_query_layer) -+ # """融合qk为大矩阵,由于加入相对位置编码,PFA接口用不了,暂时只修改矩阵乘法""" -+ batch_size, dim, num_channels = hidden_states.shape -+ qkv = self.qkv(hidden_states) -+ q, k, v = qkv.chunk(3, dim=-1) -+ -+ query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) -+ key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) -+ value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) -diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py -index 3de483ac..23813db9 100755 ---- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py -+++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py -@@ -117,6 +117,10 @@ class TextDetector(BaseOCRV20): - self.net.eval() - self.net.to(self.device) - -+ -+ import threading -+ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) -+ - def _batch_process_same_size(self, img_list): - """ - 对相同尺寸的图像进行批处理 -@@ -162,12 +166,12 @@ class TextDetector(BaseOCRV20): - return batch_results, time.time() - starttime - - # 批处理推理 -- with torch.no_grad(): -- inp = torch.from_numpy(batch_tensor) -- inp = inp.to(self.device) -- outputs = self.net(inp) -- -- # 处理输出 -+ with self._dev_lock: -+ with torch.no_grad(): -+ inp = torch.from_numpy(batch_tensor) -+ inp = inp.to(self.device) -+ outputs = self.net(inp) -+ # 处理输出 - preds = {} - if self.det_algorithm == "EAST": - preds['f_geo'] = outputs['f_geo'].cpu().numpy() -@@ -304,10 +308,11 @@ class TextDetector(BaseOCRV20): - img = img.copy() - starttime = time.time() - -- with torch.no_grad(): -- inp = torch.from_numpy(img) -- inp = inp.to(self.device) -- outputs = self.net(inp) -+ with self._dev_lock: -+ with torch.no_grad(): -+ inp = torch.from_numpy(img) -+ inp = inp.to(self.device) -+ outputs = self.net(inp) - - preds = {} - if self.det_algorithm == "EAST": -diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py -index c06ca5fe..d865b201 100755 ---- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py -+++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py -@@ -94,6 +94,9 @@ class TextRecognizer(BaseOCRV20): - self.net.eval() - self.net.to(self.device) - -+ import threading -+ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) -+ - def resize_norm_img(self, img, max_wh_ratio): - imgC, imgH, imgW = self.rec_image_shape - if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR': -@@ -301,74 +304,78 @@ class TextRecognizer(BaseOCRV20): - rec_res = [['', 0.0]] * img_num - batch_num = self.rec_batch_num - elapse = 0 -- # for beg_img_no in range(0, img_num, batch_num): -- with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: -- index = 0 -- for beg_img_no in range(0, img_num, batch_num): -- end_img_no = min(img_num, beg_img_no + batch_num) -- norm_img_batch = [] -- max_wh_ratio = 0 -- for ino in range(beg_img_no, end_img_no): -- # h, w = img_list[ino].shape[0:2] -- h, w = img_list[indices[ino]].shape[0:2] -- wh_ratio = w * 1.0 / h -- max_wh_ratio = max(max_wh_ratio, wh_ratio) -- for ino in range(beg_img_no, end_img_no): -- if self.rec_algorithm == "SAR": -- norm_img, _, _, valid_ratio = self.resize_norm_img_sar( -- img_list[indices[ino]], self.rec_image_shape) -- norm_img = norm_img[np.newaxis, :] -- valid_ratio = np.expand_dims(valid_ratio, axis=0) -- valid_ratios = [] -- valid_ratios.append(valid_ratio) -- norm_img_batch.append(norm_img) -- -- elif self.rec_algorithm == "SVTR": -- norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], -- self.rec_image_shape) -- norm_img = norm_img[np.newaxis, :] -- norm_img_batch.append(norm_img) -- elif self.rec_algorithm == "SRN": -- norm_img = self.process_image_srn(img_list[indices[ino]], -- self.rec_image_shape, 8, -- 
self.max_text_length) -- encoder_word_pos_list = [] -- gsrm_word_pos_list = [] -- gsrm_slf_attn_bias1_list = [] -- gsrm_slf_attn_bias2_list = [] -- encoder_word_pos_list.append(norm_img[1]) -- gsrm_word_pos_list.append(norm_img[2]) -- gsrm_slf_attn_bias1_list.append(norm_img[3]) -- gsrm_slf_attn_bias2_list.append(norm_img[4]) -- norm_img_batch.append(norm_img[0]) -- elif self.rec_algorithm == "CAN": -- norm_img = self.norm_img_can(img_list[indices[ino]], -- max_wh_ratio) -- norm_img = norm_img[np.newaxis, :] -- norm_img_batch.append(norm_img) -- norm_image_mask = np.ones(norm_img.shape, dtype='float32') -- word_label = np.ones([1, 36], dtype='int64') -- norm_img_mask_batch = [] -- word_label_list = [] -- norm_img_mask_batch.append(norm_image_mask) -- word_label_list.append(word_label) -- else: -- norm_img = self.resize_norm_img(img_list[indices[ino]], -- max_wh_ratio) -- norm_img = norm_img[np.newaxis, :] -- norm_img_batch.append(norm_img) -- norm_img_batch = np.concatenate(norm_img_batch) -- norm_img_batch = norm_img_batch.copy() -- -- if self.rec_algorithm == "SRN": -- starttime = time.time() -- encoder_word_pos_list = np.concatenate(encoder_word_pos_list) -- gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) -- gsrm_slf_attn_bias1_list = np.concatenate( -- gsrm_slf_attn_bias1_list) -- gsrm_slf_attn_bias2_list = np.concatenate( -- gsrm_slf_attn_bias2_list) - -+ # for beg_img_no in range(0, img_num, batch_num): -+ from concurrent.futures import ThreadPoolExecutor, as_completed -+ def _rec_batch_worker(beg_img_no: int, end_img_no: int): -+ -+ -+ max_wh_ratio = 0.0 -+ norm_img_batch = [] -+ for ino in range(beg_img_no, end_img_no): -+ # h, w = img_list[ino].shape[0:2] -+ h, w = img_list[indices[ino]].shape[0:2] -+ wh_ratio = w * 1.0 / h -+ max_wh_ratio = max(max_wh_ratio, wh_ratio) -+ for ino in range(beg_img_no, end_img_no): -+ if self.rec_algorithm == "SAR": -+ norm_img, _, _, valid_ratio = self.resize_norm_img_sar( -+ img_list[indices[ino]], self.rec_image_shape) -+ norm_img = norm_img[np.newaxis, :] -+ valid_ratio = np.expand_dims(valid_ratio, axis=0) -+ valid_ratios = [] -+ valid_ratios.append(valid_ratio) -+ norm_img_batch.append(norm_img) -+ -+ elif self.rec_algorithm == "SVTR": -+ norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], -+ self.rec_image_shape) -+ norm_img = norm_img[np.newaxis, :] -+ norm_img_batch.append(norm_img) -+ elif self.rec_algorithm == "SRN": -+ norm_img = self.process_image_srn(img_list[indices[ino]], -+ self.rec_image_shape, 8, -+ self.max_text_length) -+ encoder_word_pos_list = [] -+ gsrm_word_pos_list = [] -+ gsrm_slf_attn_bias1_list = [] -+ gsrm_slf_attn_bias2_list = [] -+ encoder_word_pos_list.append(norm_img[1]) -+ gsrm_word_pos_list.append(norm_img[2]) -+ gsrm_slf_attn_bias1_list.append(norm_img[3]) -+ gsrm_slf_attn_bias2_list.append(norm_img[4]) -+ norm_img_batch.append(norm_img[0]) -+ elif self.rec_algorithm == "CAN": -+ norm_img = self.norm_img_can(img_list[indices[ino]], -+ max_wh_ratio) -+ norm_img = norm_img[np.newaxis, :] -+ norm_img_batch.append(norm_img) -+ norm_image_mask = np.ones(norm_img.shape, dtype='float32') -+ word_label = np.ones([1, 36], dtype='int64') -+ norm_img_mask_batch = [] -+ word_label_list = [] -+ norm_img_mask_batch.append(norm_image_mask) -+ word_label_list.append(word_label) -+ else: -+ norm_img = self.resize_norm_img(img_list[indices[ino]], -+ max_wh_ratio) -+ norm_img = norm_img[np.newaxis, :] -+ norm_img_batch.append(norm_img) -+ norm_img_batch = np.concatenate(norm_img_batch) -+ norm_img_batch = 
norm_img_batch.copy() -+ -+ starttime = time.time() -+ -+ if self.rec_algorithm == "SRN": -+ starttime = time.time() -+ encoder_word_pos_list = np.concatenate(encoder_word_pos_list) -+ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) -+ gsrm_slf_attn_bias1_list = np.concatenate( -+ gsrm_slf_attn_bias1_list) -+ gsrm_slf_attn_bias2_list = np.concatenate( -+ gsrm_slf_attn_bias2_list) -+ -+ with self._dev_lock: - with torch.no_grad(): - inp = torch.from_numpy(norm_img_batch) - encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list) -@@ -384,58 +391,67 @@ class TextRecognizer(BaseOCRV20): - - backbone_out = self.net.backbone(inp) # backbone_feat - prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp]) -- # preds = {"predict": prob_out[2]} -- preds = {"predict": prob_out["predict"]} -- -- elif self.rec_algorithm == "SAR": -- starttime = time.time() -- # valid_ratios = np.concatenate(valid_ratios) -- # inputs = [ -- # norm_img_batch, -- # valid_ratios, -- # ] -- -+ # preds = {"predict": prob_out[2]} -+ preds = {"predict": prob_out["predict"]} -+ -+ elif self.rec_algorithm == "SAR": -+ starttime = time.time() -+ # valid_ratios = np.concatenate(valid_ratios) -+ # inputs = [ -+ # norm_img_batch, -+ # valid_ratios, -+ # ] -+ -+ with self._dev_lock: - with torch.no_grad(): - inp = torch.from_numpy(norm_img_batch) - inp = inp.to(self.device) - preds = self.net(inp) - -- elif self.rec_algorithm == "CAN": -- starttime = time.time() -- norm_img_mask_batch = np.concatenate(norm_img_mask_batch) -- word_label_list = np.concatenate(word_label_list) -- inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] -+ elif self.rec_algorithm == "CAN": -+ starttime = time.time() -+ norm_img_mask_batch = np.concatenate(norm_img_mask_batch) -+ word_label_list = np.concatenate(word_label_list) -+ inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] - -- inp = [torch.from_numpy(e_i) for e_i in inputs] -- inp = [e_i.to(self.device) for e_i in inp] -+ inp = [torch.from_numpy(e_i) for e_i in inputs] -+ inp = [e_i.to(self.device) for e_i in inp] -+ with self._dev_lock: - with torch.no_grad(): - outputs = self.net(inp) - outputs = [v.cpu().numpy() for k, v in enumerate(outputs)] - -- preds = outputs -- -- else: -- starttime = time.time() -+ preds = outputs - -+ else: -+ with self._dev_lock: - with torch.no_grad(): -- inp = torch.from_numpy(norm_img_batch) -- inp = inp.to(self.device) -+ inp = torch.from_numpy(norm_img_batch).to(self.device) - prob_out = self.net(inp) -+ preds = [v.cpu().numpy() for v in prob_out] if isinstance(prob_out, list) else prob_out.cpu().numpy() - -- if isinstance(prob_out, list): -- preds = [v.cpu().numpy() for v in prob_out] -- else: -- preds = prob_out.cpu().numpy() -+ rec_result = self.postprocess_op(preds) - -- rec_result = self.postprocess_op(preds) -- for rno in range(len(rec_result)): -- rec_res[indices[beg_img_no + rno]] = rec_result[rno] -- elapse += time.time() - starttime -+ for rno in range(len(rec_result)): -+ global_idx = indices[beg_img_no + rno] -+ rec_res[global_idx] = rec_result[rno] -+ -+ batch_elapse = time.time() - starttime -+ return len(rec_result), batch_elapse -+ -+ MAX_WORKERS = 4 -+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex, \ -+ tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: -+ -+ futures = [] -+ for beg_img_no in range(0, img_num, batch_num): -+ end_img_no = min(img_num, beg_img_no + batch_num) -+ 
futures.append(ex.submit(_rec_batch_worker, beg_img_no, end_img_no)) - -- # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size -- current_batch_size = min(batch_num, img_num - index * batch_num) -- index += 1 -- pbar.update(current_batch_size) -+ for fut in as_completed(futures): -+ n_done, batch_elapse = fut.result() -+ elapse += batch_elapse -+ pbar.update(n_done) - - # Fix NaN values in recognition results - for i in range(len(rec_res)): -diff --git a/mineru/model/table/rapid_table.py b/mineru/model/table/rapid_table.py -index 174a8052..dd796bcc 100644 ---- a/mineru/model/table/rapid_table.py -+++ b/mineru/model/table/rapid_table.py -@@ -21,6 +21,8 @@ class RapidTableModel(object): - self.table_model = RapidTable(input_args) - self.ocr_engine = ocr_engine - -+ import threading -+ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) - - def predict(self, image): - bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) -@@ -30,44 +32,45 @@ class RapidTableModel(object): - img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0 - img_is_portrait = img_aspect_ratio > 1.2 - -- if img_is_portrait: -+ with self._dev_lock: -+ if img_is_portrait: - -- det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] -- # Check if table is rotated by analyzing text box aspect ratios -- is_rotated = False -- if det_res: -- vertical_count = 0 -+ det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] -+ # Check if table is rotated by analyzing text box aspect ratios -+ is_rotated = False -+ if det_res: -+ vertical_count = 0 - -- for box_ocr_res in det_res: -- p1, p2, p3, p4 = box_ocr_res -+ for box_ocr_res in det_res: -+ p1, p2, p3, p4 = box_ocr_res - -- # Calculate width and height -- width = p3[0] - p1[0] -- height = p3[1] - p1[1] -+ # Calculate width and height -+ width = p3[0] - p1[0] -+ height = p3[1] - p1[1] - -- aspect_ratio = width / height if height > 0 else 1.0 -+ aspect_ratio = width / height if height > 0 else 1.0 - -- # Count vertical vs horizontal text boxes -- if aspect_ratio < 0.8: # Taller than wide - vertical text -- vertical_count += 1 -- # elif aspect_ratio > 1.2: # Wider than tall - horizontal text -- # horizontal_count += 1 -+ # Count vertical vs horizontal text boxes -+ if aspect_ratio < 0.8: # Taller than wide - vertical text -+ vertical_count += 1 -+ # elif aspect_ratio > 1.2: # Wider than tall - horizontal text -+ # horizontal_count += 1 - -- # If we have more vertical text boxes than horizontal ones, -- # and vertical ones are significant, table might be rotated -- if vertical_count >= len(det_res) * 0.3: -- is_rotated = True -+ # If we have more vertical text boxes than horizontal ones, -+ # and vertical ones are significant, table might be rotated -+ if vertical_count >= len(det_res) * 0.3: -+ is_rotated = True - -- # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") -+ # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") - -- # Rotate image if necessary -- if is_rotated: -- # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") -- image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) -- bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) -+ # Rotate image if necessary -+ if is_rotated: -+ # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") -+ image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) -+ bgr_image = cv2.cvtColor(image, 
cv2.COLOR_RGB2BGR) - -- # Continue with OCR on potentially rotated image -- ocr_result = self.ocr_engine.ocr(bgr_image)[0] -+ # Continue with OCR on potentially rotated image -+ ocr_result = self.ocr_engine.ocr(bgr_image)[0] - if ocr_result: - ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if - len(item) == 2 and isinstance(item[1], tuple)] - -- Gitee
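The hunks above all apply the same pattern: each model wrapper gets a per-instance `threading.Lock` (`_dev_lock`), CPU-side preprocessing is moved into `ThreadPoolExecutor` workers, and only the device forward pass runs under the lock, with results written back by global index because batches complete out of order. Below is a minimal, self-contained sketch of that pattern for reference. It is not MinerU code: `LockedRecognizer`, `_run_batch`, and the demo `Linear` network are hypothetical stand-ins for the patched `TextDetector`/`TextRecognizer` classes.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import torch


class LockedRecognizer:
    """Hypothetical sketch: parallel preprocessing, serialized device inference."""

    def __init__(self, net: torch.nn.Module, device: str = "cpu"):
        self.net = net.eval().to(device)
        self.device = device
        # One lock per wrapper instance, as added to TextDetector,
        # TextRecognizer and RapidTableModel in the patch above.
        self._dev_lock = threading.Lock()

    def _run_batch(self, batch: np.ndarray):
        start = time.time()
        inp = torch.from_numpy(batch)   # CPU-side prep stays outside the lock
        with self._dev_lock:            # only the forward pass is serialized
            with torch.no_grad():
                out = self.net(inp.to(self.device))
        return out.cpu().numpy(), time.time() - start

    def predict(self, batches, max_workers: int = 4):
        results = [None] * len(batches)
        elapse = 0.0
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            futures = {ex.submit(self._run_batch, b): i
                       for i, b in enumerate(batches)}
            for fut in as_completed(futures):
                out, batch_elapse = fut.result()
                # Completion order is arbitrary, so write by index -- the same
                # reason the patch assigns rec_res[indices[beg_img_no + rno]].
                results[futures[fut]] = out
                elapse += batch_elapse  # summed per-batch time, not wall-clock
        return results, elapse


if __name__ == "__main__":
    demo = LockedRecognizer(torch.nn.Linear(4, 2))
    batches = [np.random.rand(8, 4).astype("float32") for _ in range(5)]
    outs, elapse = demo.predict(batches)
    print(len(outs), f"{elapse:.4f}s summed over batches")
```

Note that, as in the patch, the accumulated `elapse` sums per-batch durations across concurrent workers, so it can exceed wall-clock time; it measures device work, not end-to-end latency.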
| Dataset | Accuracy | Best perf on 300I Pro (bs) | Best perf on 800I A2 (bs) | Input shape |
|--------------|----------|----------------------------|----------------------------|-------------|
| OmniDocBench | 0.1588 | 0.2527 | | multi-scale |