diff --git a/cv/ocr/dbnet/pytorch/.gitignore b/cv/ocr/dbnet/pytorch/.gitignore
new file mode 100755
index 0000000000000000000000000000000000000000..628f7f2699eb1265cba9873ad70f89b60f6a7eb2
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/.gitignore
@@ -0,0 +1,49 @@
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# PyTorch checkpoint
+*.pth
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+OCR_Detect/
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+
diff --git a/cv/ocr/dbnet/pytorch/CITATION.cff b/cv/ocr/dbnet/pytorch/CITATION.cff
new file mode 100755
index 0000000000000000000000000000000000000000..7d1d93a7c68daf442bc6540b197b401e7a38b91c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/CITATION.cff
@@ -0,0 +1,9 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite it as below."
+title: "OpenMMLab Text Detection, Recognition and Understanding Toolbox"
+authors:
+ - name: "MMOCR Contributors"
+version: 0.3.0
+date-released: 2020-08-15
+repository-code: "https://github.com/open-mmlab/mmocr"
+license: Apache-2.0
diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO b/cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO
new file mode 100755
index 0000000000000000000000000000000000000000..9280d7ec8910f9759c9d8490c4df803cfb4e7583
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO
@@ -0,0 +1,380 @@
+Metadata-Version: 2.1
+Name: dbnet_det
+Version: 2.25.0
+Summary: OpenMMLab Detection Toolbox and Benchmark
+Home-page: https://github.com/open-mmlab/dbnet_detection
+Author: dbnet_detection Contributors
+Author-email: openmmlab@gmail.com
+License: Apache License 2.0
+Keywords: computer vision,object detection
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Description-Content-Type: text/markdown
+Provides-Extra: all
+Provides-Extra: tests
+Provides-Extra: build
+Provides-Extra: optional
+
+
+
+
+
+
+
+[PyPI](https://pypi.org/project/dbnet_det) |
+[Documentation](https://dbnet_detection.readthedocs.io/en/latest/) |
+[CI](https://github.com/open-mmlab/dbnet_detection/actions) |
+[Coverage](https://codecov.io/gh/open-mmlab/dbnet_detection) |
+[License](https://github.com/open-mmlab/dbnet_detection/blob/master/LICENSE) |
+[Issues](https://github.com/open-mmlab/dbnet_detection/issues)
+
+[📘Documentation](https://dbnet_detection.readthedocs.io/en/stable/) |
+[🛠️Installation](https://dbnet_detection.readthedocs.io/en/stable/get_started.html) |
+[👀Model Zoo](https://dbnet_detection.readthedocs.io/en/stable/model_zoo.html) |
+[🆕Update News](https://dbnet_detection.readthedocs.io/en/stable/changelog.html) |
+[🚀Ongoing Projects](https://github.com/open-mmlab/dbnet_detection/projects) |
+[🤔Reporting Issues](https://github.com/open-mmlab/dbnet_detection/issues/new/choose)
+
+
+
+
+
+English | [简体中文](README_zh-CN.md)
+
+
+
+## Introduction
+
+dbnet_detection is an open source object detection toolbox based on PyTorch. It is
+a part of the [OpenMMLab](https://openmmlab.com/) project.
+
+The master branch works with **PyTorch 1.5+**.
+
+
+
+
+Major features
+
+- **Modular Design**
+
+ We decompose the detection framework into different components and one can easily construct a customized object detection framework by combining different modules.
+
+- **Support of multiple frameworks out of box**
+
+ The toolbox directly supports popular and contemporary detection frameworks, *e.g.* Faster RCNN, Mask RCNN, RetinaNet, etc.
+
+- **High efficiency**
+
+ All basic bbox and mask operations run on GPUs. The training speed is faster than or comparable to other codebases, including [Detectron2](https://github.com/facebookresearch/detectron2), [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) and [SimpleDet](https://github.com/TuSimple/simpledet).
+
+- **State of the art**
+
+ The toolbox stems from the codebase developed by the *dbnet_det* team, who won [COCO Detection Challenge](http://cocodataset.org/#detection-leaderboard) in 2018, and we keep pushing it forward.
+
+
+
+Apart from dbnet_detection, we have also released [dbnet_cv](https://github.com/open-mmlab/dbnet_cv), a library for computer vision research on which this toolbox heavily depends.
+
+## What's New
+
+**2.25.0** was released on 1/6/2022:
+
+- Support dedicated `dbnet_detWandbHook` hook
+- Support [ConvNeXt](configs/convnext), [DDOD](configs/ddod), [SOLOv2](configs/solov2)
+- Support [Mask2Former](configs/mask2former) for instance segmentation
+- Rename [config files of Mask2Former](configs/mask2former)
+
+Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
+
+For compatibility changes between different versions of dbnet_detection, please refer to [compatibility.md](docs/en/compatibility.md).
+
+## Installation
+
+Please refer to [Installation](docs/en/get_started.md/#Installation) for installation instructions.
+
+## Getting Started
+
+Please see [get_started.md](docs/en/get_started.md) for the basic usage of dbnet_detection. We provide a [Colab tutorial](demo/dbnet_det_Tutorial.ipynb), an [instance segmentation Colab tutorial](demo/dbnet_det_InstanceSeg_Tutorial.ipynb), and other tutorials for:
+
+- [with existing dataset](docs/en/1_exist_data_model.md)
+- [with new dataset](docs/en/2_new_data_model.md)
+- [with existing dataset_new_model](docs/en/3_exist_data_new_model.md)
+- [learn about configs](docs/en/tutorials/config.md)
+- [customize_datasets](docs/en/tutorials/customize_dataset.md)
+- [customize data pipelines](docs/en/tutorials/data_pipeline.md)
+- [customize_models](docs/en/tutorials/customize_models.md)
+- [customize runtime settings](docs/en/tutorials/customize_runtime.md)
+- [customize_losses](docs/en/tutorials/customize_losses.md)
+- [finetuning models](docs/en/tutorials/finetune.md)
+- [export a model to ONNX](docs/en/tutorials/pytorch2onnx.md)
+- [export ONNX to TRT](docs/en/tutorials/onnx2tensorrt.md)
+- [weight initialization](docs/en/tutorials/init_cfg.md)
+- [how to xxx](docs/en/tutorials/how_to.md)
+
+## Overview of Benchmark and Model Zoo
+
+Results and models are available in the [model zoo](docs/en/model_zoo.md).
+
+**Architectures**: Object Detection | Instance Segmentation | Panoptic Segmentation | Other (Contrastive Learning, Distillation)
+
+**Components**: Backbones | Necks | Loss | Common
+
+Some other methods are also supported in [projects using dbnet_detection](./docs/en/projects.md).
+
+## FAQ
+
+Please refer to [FAQ](docs/en/faq.md) for frequently asked questions.
+
+## Contributing
+
+We appreciate all contributions to improve dbnet_detection. Ongoing projects can be found in our [GitHub Projects](https://github.com/open-mmlab/dbnet_detection/projects). We welcome community users to participate in these projects. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
+
+## Acknowledgement
+
+dbnet_detection is an open source project contributed to by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as the users who give valuable feedback.
+We hope the toolbox and benchmark serve the growing research community by providing a flexible toolkit for reimplementing existing methods and developing new detectors.
+
+## Citation
+
+If you use this toolbox or benchmark in your research, please cite this project.
+
+```
+@article{dbnet_detection,
+ title = {{dbnet_detection}: Open MMLab Detection Toolbox and Benchmark},
+ author = {Chen, Kai and Wang, Jiaqi and Pang, Jiangmiao and Cao, Yuhang and
+ Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and
+ Liu, Ziwei and Xu, Jiarui and Zhang, Zheng and Cheng, Dazhi and
+ Zhu, Chenchen and Cheng, Tianheng and Zhao, Qijie and Li, Buyu and
+ Lu, Xin and Zhu, Rui and Wu, Yue and Dai, Jifeng and Wang, Jingdong
+ and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua},
+ journal= {arXiv preprint arXiv:1906.07155},
+ year={2019}
+}
+```
+
+## License
+
+This project is released under the [Apache 2.0 license](LICENSE).
+
+## Projects in OpenMMLab
+
+- [DBNET_CV](https://github.com/open-mmlab/dbnet_cv): OpenMMLab foundational library for computer vision.
+- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
+- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
+- [dbnet_detection](https://github.com/open-mmlab/dbnet_detection): OpenMMLab detection toolbox and benchmark.
+- [dbnet_detection3D](https://github.com/open-mmlab/dbnet_detection3d): OpenMMLab's next-generation platform for general 3D object detection.
+- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
+- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
+- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
+- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
+- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
+- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
+- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
+- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
+- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
+- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
+- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
+- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
+- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt
new file mode 100755
index 0000000000000000000000000000000000000000..4f42864d415c099c85025f0159e89eb707bdcd5e
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt
@@ -0,0 +1,62 @@
+README.md
+setup.py
+dbnet_det/__init__.py
+dbnet_det/version.py
+dbnet_det.egg-info/PKG-INFO
+dbnet_det.egg-info/SOURCES.txt
+dbnet_det.egg-info/dependency_links.txt
+dbnet_det.egg-info/not-zip-safe
+dbnet_det.egg-info/requires.txt
+dbnet_det.egg-info/top_level.txt
+dbnet_det/core/__init__.py
+dbnet_det/core/bbox/__init__.py
+dbnet_det/core/bbox/transforms.py
+dbnet_det/core/evaluation/__init__.py
+dbnet_det/core/evaluation/bbox_overlaps.py
+dbnet_det/core/evaluation/class_names.py
+dbnet_det/core/evaluation/eval_hooks.py
+dbnet_det/core/evaluation/mean_ap.py
+dbnet_det/core/evaluation/panoptic_utils.py
+dbnet_det/core/evaluation/recall.py
+dbnet_det/core/mask/__init__.py
+dbnet_det/core/mask/structures.py
+dbnet_det/core/mask/utils.py
+dbnet_det/core/utils/__init__.py
+dbnet_det/core/utils/dist_utils.py
+dbnet_det/core/utils/misc.py
+dbnet_det/core/visualization/__init__.py
+dbnet_det/core/visualization/image.py
+dbnet_det/core/visualization/palette.py
+dbnet_det/datasets/__init__.py
+dbnet_det/datasets/builder.py
+dbnet_det/datasets/coco.py
+dbnet_det/datasets/custom.py
+dbnet_det/datasets/dataset_wrappers.py
+dbnet_det/datasets/utils.py
+dbnet_det/datasets/api_wrappers/__init__.py
+dbnet_det/datasets/api_wrappers/coco_api.py
+dbnet_det/datasets/api_wrappers/panoptic_evaluation.py
+dbnet_det/datasets/pipelines/__init__.py
+dbnet_det/datasets/pipelines/compose.py
+dbnet_det/datasets/pipelines/formatting.py
+dbnet_det/datasets/pipelines/loading.py
+dbnet_det/datasets/pipelines/test_time_aug.py
+dbnet_det/datasets/pipelines/transforms.py
+dbnet_det/datasets/samplers/__init__.py
+dbnet_det/datasets/samplers/class_aware_sampler.py
+dbnet_det/datasets/samplers/distributed_sampler.py
+dbnet_det/datasets/samplers/group_sampler.py
+dbnet_det/datasets/samplers/infinite_sampler.py
+dbnet_det/models/__init__.py
+dbnet_det/models/builder.py
+dbnet_det/models/backbones/__init__.py
+dbnet_det/models/backbones/resnet.py
+dbnet_det/models/detectors/__init__.py
+dbnet_det/models/detectors/base.py
+dbnet_det/models/detectors/single_stage.py
+dbnet_det/models/utils/__init__.py
+dbnet_det/models/utils/res_layer.py
+dbnet_det/utils/__init__.py
+dbnet_det/utils/logger.py
+dbnet_det/utils/profiling.py
+dbnet_det/utils/util_distribution.py
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/dependency_links.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/dependency_links.txt
new file mode 100755
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/EGG-INFO/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe b/cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe
new file mode 100755
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe
@@ -0,0 +1 @@
+
diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt
new file mode 100755
index 0000000000000000000000000000000000000000..22c49defa6c6641da95b0ac68789965cdac53292
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt
@@ -0,0 +1,59 @@
+matplotlib
+numpy
+pycocotools
+six
+terminaltables
+
+[all]
+cython
+numpy
+cityscapesscripts
+imagecorruptions
+scipy
+sklearn
+timm
+matplotlib
+pycocotools
+six
+terminaltables
+asynctest
+codecov
+flake8
+interrogate
+isort==4.3.21
+kwarray
+mmtrack
+onnx==1.7.0
+onnxruntime>=1.8.0
+protobuf<=3.20.1
+pytest
+ubelt
+xdoctest>=0.10.0
+yapf
+
+[build]
+cython
+numpy
+
+[optional]
+cityscapesscripts
+imagecorruptions
+scipy
+sklearn
+timm
+
+[tests]
+asynctest
+codecov
+flake8
+interrogate
+isort==4.3.21
+kwarray
+mmtrack
+onnx==1.7.0
+onnxruntime>=1.8.0
+protobuf<=3.20.1
+pytest
+ubelt
+xdoctest>=0.10.0
+yapf
diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt
new file mode 100755
index 0000000000000000000000000000000000000000..2d2f43c40746f3804a679a386877c3d40e319f48
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt
@@ -0,0 +1 @@
+dbnet_det
diff --git a/cv/ocr/dbnet/pytorch/LICENSE b/cv/ocr/dbnet/pytorch/LICENSE
new file mode 100755
index 0000000000000000000000000000000000000000..3076a4378396deea4db311adbe1fbfd8b8b05920
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/LICENSE
@@ -0,0 +1,203 @@
+Copyright (c) MMOCR Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021 MMOCR Authors. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/cv/ocr/dbnet/pytorch/README.md b/cv/ocr/dbnet/pytorch/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..d92a6cfcab2e892ede674e6526ed1daab5f844ea
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/README.md
@@ -0,0 +1,69 @@
+# DBNet
+## Model description
+Recently, segmentation-based methods are quite popular in scene text detection, as the segmentation results can more accurately describe scene text of various shapes such as curve text. However, the post-processing of binarization is essential for segmentation-based detection, which converts probability maps produced by a segmentation method into bounding boxes/regions of text. In this paper, we propose a module named Differentiable Binarization (DB), which can perform the binarization process in a segmentation network. Optimized along with a DB module, a segmentation network can adaptively set the thresholds for binarization, which not only simplifies the post-processing but also enhances the performance of text detection. Based on a simple segmentation network, we validate the performance improvements of DB on five benchmark datasets, which consistently achieves state-of-the-art results, in terms of both detection accuracy and speed. In particular, with a light-weight backbone, the performance improvements by DB are significant so that we can look for an ideal tradeoff between detection accuracy and efficiency.
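+
+The key idea is the differentiable (approximate) binarization step: the hard threshold is replaced by a steep sigmoid, so the threshold map can be learned jointly with the probability map. Below is a minimal PyTorch sketch of that step only; the amplification factor `k = 50` follows the paper, while the function name and tensor shapes are illustrative and not part of this repository's API.
+
+```python
+import torch
+
+
+def differentiable_binarization(prob_map: torch.Tensor,
+                                thresh_map: torch.Tensor,
+                                k: float = 50.0) -> torch.Tensor:
+    """Approximate binary map B = 1 / (1 + exp(-k * (P - T)))."""
+    return torch.sigmoid(k * (prob_map - thresh_map))
+
+
+# Illustrative usage with random maps in [0, 1] (shapes are placeholders).
+prob = torch.rand(1, 640, 640)
+thresh = torch.rand(1, 640, 640)
+binary = differentiable_binarization(prob, thresh)
+```
+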
+## Preparing datasets
+
+```shell
+$ mkdir data
+$ cd data
+```
+
+### ICDAR 2015
+
+Download the following archives from the [ICDAR 2015 downloads page](https://rrc.cvc.uab.es/?ch=4&com=downloads):
+`ch4_training_images.zip`, `ch4_test_images.zip`, `ch4_training_localization_transcription_gt.zip`, and `Challenge4_Test_Task1_GT.zip`.
+
+```shell
+mkdir icdar2015 && cd icdar2015
+mkdir imgs && mkdir annotations
+
+# after extracting the four downloaded archives:
+mv ch4_training_images imgs/training
+mv ch4_test_images imgs/test
+
+mv ch4_training_localization_transcription_gt annotations/training
+mv Challenge4_Test_Task1_GT annotations/test
+```
+Download [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) and [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) and place them in the `icdar2015` directory, so that the layout looks like this:
+
+```shell
+
+icdar2015/
+├── imgs
+│ ├── test
+│ └── training
+├── instances_test.json
+└── instances_training.json
+
+```
+### Build Extension
+
+```shell
+$ DBNET_CV_WITH_OPS=1 python3 setup.py build && cp build/lib.linux*/dbnet_cv/_ext.cpython* dbnet_cv
+```
+### Install packages
+
+```shell
+$ pip3 install -r requirements.txt
+```
+
+### Training on a single card
+```shell
+$ python3 train.py configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py
+```
+
+### Training on multiple cards
+```shell
+$ bash dist_train.sh configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py 8
+```
+
+## Results on BI-V100
+
+| Approach | GPUs       | Train mem | Train FPS |
+| :------: | :--------: | :-------: | :-------: |
+| dbnet    | BI-V100 x8 | 5426      | 54.375    |
+
+| hmean-iou: Recall | hmean-iou: Precision | hmean-iou: Hmean |
+| :---------------: | :------------------: | :--------------: |
+| 0.7111 | 0.8062 | 0.7557 |
+
+## Reference
+https://github.com/open-mmlab/mmocr
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/default_runtime.py b/cv/ocr/dbnet/pytorch/configs/_base_/default_runtime.py
new file mode 100755
index 0000000000000000000000000000000000000000..de7f9650ce73ba7ca633652b50df021b67498362
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/default_runtime.py
@@ -0,0 +1,17 @@
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type='TextLoggerHook')
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+# disable opencv multithreading to avoid system being overloaded
+opencv_num_threads = 0
+# set multi-process start method as `fork` to speed up the training
+mp_start_method = 'fork'
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_datasets/icdar2015.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_datasets/icdar2015.py
new file mode 100755
index 0000000000000000000000000000000000000000..f711c06dce76d53b8737288c8de318e6f90ce585
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_datasets/icdar2015.py
@@ -0,0 +1,18 @@
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015'
+
+train = dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/instances_training.json',
+ img_prefix=f'{data_root}/imgs',
+ pipeline=None)
+
+test = dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/instances_test.json',
+ img_prefix=f'{data_root}/imgs',
+ pipeline=None)
+
+train_list = [train]
+
+test_list = [test]
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_mobilenetv3_fpnc.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_mobilenetv3_fpnc.py
new file mode 100755
index 0000000000000000000000000000000000000000..e6bff598fce56decf6d5e92596d67165035b09f8
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_mobilenetv3_fpnc.py
@@ -0,0 +1,32 @@
+model = dict(
+ type='DBNet',
+ backbone=dict(
+ type='dbnet_det.MobileNetV3',
+ arch='large',
+ # num_stages=3,
+ out_indices=(3, 6, 12, 16),
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=dict(type='Pretrained', checkpoint='https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth')
+ ),
+ # backbone=dict(
+ # type='dbnet_det.ResNet',
+ # depth=18,
+ # num_stages=4,
+ # out_indices=(0, 1, 2, 3),
+ # frozen_stages=-1,
+ # norm_cfg=dict(type='BN', requires_grad=True),
+ # # init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
+ # norm_eval=False,
+ # style='caffe'),
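+    # Note: FPNC's in_channels below must match the channel dimensions of the
+    # backbone feature maps selected by out_indices above.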
+ neck=dict(
+ type='FPNC', in_channels=[24, 40, 112, 960], lateral_channels=256),
+ bbox_head=dict(
+ type='DBHead',
+ in_channels=256,
+ loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=False),
+ postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
+ train_cfg=None,
+ test_cfg=None)
+
+
+
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r18_fpnc.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r18_fpnc.py
new file mode 100755
index 0000000000000000000000000000000000000000..b26391f4d4b2eb34beae81cce56c2816979bc730
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r18_fpnc.py
@@ -0,0 +1,31 @@
+model = dict(
+ type='DBNet',
+ backbone=dict(
+ type='dbnet_det.ResNet',
+ depth=18,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ # init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
+ norm_eval=False,
+ style='caffe'),
+ neck=dict(
+ type='FPNC', in_channels=[64, 128, 256, 512], lateral_channels=256),
+ bbox_head=dict(
+ type='DBHead',
+ in_channels=256,
+ loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=False),
+ postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
+ train_cfg=None,
+ test_cfg=None)
+
+
+
+# backbone=dict(
+# type='MobileNetV3',
+# arch='small',
+# out_indices=(0, 1, 12),
+# norm_cfg=dict(type='BN', requires_grad=True),
+# init_cfg=dict(type='Pretrained', checkpoint='open-mmlab://contrib/mobilenet_v3_small')
+# ),
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py
new file mode 100755
index 0000000000000000000000000000000000000000..32b09f507a8aaabf16f86d9569a22090578084e8
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py
@@ -0,0 +1,23 @@
+model = dict(
+ type='DBNet',
+ backbone=dict(
+ type='dbnet_det.ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ style='pytorch',
+ dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
+ # init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
+ stage_with_dcn=(False, True, True, True)),
+ neck=dict(
+ type='FPNC', in_channels=[256, 512, 1024, 2048], lateral_channels=256),
+ bbox_head=dict(
+ type='DBHead',
+ in_channels=256,
+ loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
+ postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
+ train_cfg=None,
+ test_cfg=None)
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_pipelines/dbnet_pipeline.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_pipelines/dbnet_pipeline.py
new file mode 100755
index 0000000000000000000000000000000000000000..aeb944871ac6d12bf5fc9a694076942afa63a7d8
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_pipelines/dbnet_pipeline.py
@@ -0,0 +1,88 @@
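+# ImageNet mean/std in RGB order; the r50dcnv2 pipelines below override only the mean.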
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline_r18 = [
+ dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ImgAug',
+ args=[['Fliplr', 0.5],
+ dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
+ dict(type='EastRandomCrop', target_size=(640, 640)),
+ dict(type='DBNetTargets', shrink_ratio=0.4),
+ dict(type='Pad', size_divisor=32),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'],
+ visualize=dict(flag=False, boundary_key='gt_shrink')),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'])
+]
+
+test_pipeline_1333_736 = [
+ dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1280, 736), # used by Resize
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+
+# for dbnet_r50dcnv2_fpnc
+img_norm_cfg_r50dcnv2 = dict(
+ mean=[122.67891434, 116.66876762, 104.00698793],
+ std=[58.395, 57.12, 57.375],
+ to_rgb=True)
+
+train_pipeline_r50dcnv2 = [
+ dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg_r50dcnv2),
+ dict(
+ type='ImgAug',
+ args=[['Fliplr', 0.5],
+ dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
+ dict(type='EastRandomCrop', target_size=(640, 640)),
+ dict(type='DBNetTargets', shrink_ratio=0.4),
+ dict(type='Pad', size_divisor=32),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'],
+ visualize=dict(flag=False, boundary_key='gt_shrink')),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'])
+]
+
+test_pipeline_4068_1024 = [
+ dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(4068, 1024), # used by Resize
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg_r50dcnv2),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_adam_1200e.py b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_adam_1200e.py
new file mode 100755
index 0000000000000000000000000000000000000000..6be9df0078e9261f0223eca9e2d5dafb5edc69ac
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_adam_1200e.py
@@ -0,0 +1,8 @@
+# optimizer
+optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.05)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='poly', power=0.9)
+# running settings
+runner = dict(type='EpochBasedRunner', max_epochs=1200)
+checkpoint_config = dict(interval=50)
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_sgd_1200e.py b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_sgd_1200e.py
new file mode 100755
index 0000000000000000000000000000000000000000..bc7fbf69b42b11ea9b8ae4d14216d2fcf20e717c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_sgd_1200e.py
@@ -0,0 +1,8 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.007, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
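+# The 'poly' policy (mmcv-style PolyLrUpdaterHook) decays the lr roughly as
+#   lr = (base_lr - min_lr) * (1 - progress / max_progress) ** power + min_lr
+# (stated here for reference; see dbnet_cv for the exact implementation)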
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
+# running settings
+runner = dict(type='EpochBasedRunner', max_epochs=1200)
+checkpoint_config = dict(interval=100)
diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/README.md b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..d2007c72ec2b45e70d30c6edea128b7e0be2baca
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/README.md
@@ -0,0 +1,33 @@
+# DBNet
+
+> [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947)
+
+
+
+## Abstract
+
+Recently, segmentation-based methods are quite popular in scene text detection, as the segmentation results can more accurately describe scene text of various shapes such as curve text. However, the post-processing of binarization is essential for segmentation-based detection, which converts probability maps produced by a segmentation method into bounding boxes/regions of text. In this paper, we propose a module named Differentiable Binarization (DB), which can perform the binarization process in a segmentation network. Optimized along with a DB module, a segmentation network can adaptively set the thresholds for binarization, which not only simplifies the post-processing but also enhances the performance of text detection. Based on a simple segmentation network, we validate the performance improvements of DB on five benchmark datasets, which consistently achieves state-of-the-art results, in terms of both detection accuracy and speed. In particular, with a light-weight backbone, the performance improvements by DB are significant so that we can look for an ideal tradeoff between detection accuracy and efficiency. Specifically, with a backbone of ResNet-18, our detector achieves an F-measure of 82.8, running at 62 FPS, on the MSRA-TD500 dataset.
+
+
+
+
+
+## Results and models
+
+### ICDAR2015
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :---------------------------------------: | :-------------------------------------------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------: |
+| [DBNet_r18](/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 1200 | 736 | 0.731 | 0.871 | 0.795 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json) |
+| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.814 | 0.868 | 0.840 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.log.json) |
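+
+As a quick sanity check, a trained config/checkpoint pair can be loaded with the inference helpers shipped in this repository (`dbnet.apis.init_detector` and `dbnet.apis.model_inference`). This is only a sketch: the checkpoint path and test image below are placeholders, and the structure of `result` follows the loaded model's post-processor rather than being fixed here.
+
+```python
+from dbnet.apis import init_detector, model_inference
+
+# Placeholder paths: point these at a real checkpoint and image.
+config = 'configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py'
+checkpoint = 'work_dirs/dbnet_r18_fpnc_1200e_icdar2015/latest.pth'
+
+model = init_detector(config, checkpoint=checkpoint, device='cuda:0')
+result = model_inference(model, 'demo.jpg')
+print(result)
+```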
+
+## Citation
+
+```bibtex
+@article{Liao_Wan_Yao_Chen_Bai_2020,
+ title={Real-Time Scene Text Detection with Differentiable Binarization},
+ journal={Proceedings of the AAAI Conference on Artificial Intelligence},
+ author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang},
+ year={2020},
+ pages={11474-11481}}
+```
diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py
new file mode 100755
index 0000000000000000000000000000000000000000..e8bcd2bda11db4babff44fd9a2844dd9cfa78bce
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py
@@ -0,0 +1,33 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/schedules/schedule_sgd_1200e.py',
+ '../../_base_/det_models/dbnet_mobilenetv3_fpnc.py',
+ '../../_base_/det_datasets/icdar2015.py',
+ '../../_base_/det_pipelines/dbnet_pipeline.py'
+]
+
+train_list = {{_base_.train_list}}
+test_list = {{_base_.test_list}}
+
+train_pipeline_r18 = {{_base_.train_pipeline_r18}}
+test_pipeline_1333_736 = {{_base_.test_pipeline_1333_736}}
+
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
+ train=dict(
+ type='UniformConcatDataset',
+ datasets=train_list,
+ pipeline=train_pipeline_r18),
+ val=dict(
+ type='UniformConcatDataset',
+ datasets=test_list,
+ pipeline=test_pipeline_1333_736),
+ test=dict(
+ type='UniformConcatDataset',
+ datasets=test_list,
+ pipeline=test_pipeline_1333_736))
+fp16 = dict(loss_scale='dynamic')
+evaluation = dict(interval=50, metric='hmean-iou')
diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
new file mode 100755
index 0000000000000000000000000000000000000000..9ea1bcc18a63fbdc1bd118932d605eaf9badb008
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
@@ -0,0 +1,33 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/schedules/schedule_sgd_1200e.py',
+ '../../_base_/det_models/dbnet_r18_fpnc.py',
+ '../../_base_/det_datasets/icdar2015.py',
+ '../../_base_/det_pipelines/dbnet_pipeline.py'
+]
+
+train_list = {{_base_.train_list}}
+test_list = {{_base_.test_list}}
+
+train_pipeline_r18 = {{_base_.train_pipeline_r18}}
+test_pipeline_1333_736 = {{_base_.test_pipeline_1333_736}}
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
+ train=dict(
+ type='UniformConcatDataset',
+ datasets=train_list,
+ pipeline=train_pipeline_r18),
+ val=dict(
+ type='UniformConcatDataset',
+ datasets=test_list,
+ pipeline=test_pipeline_1333_736),
+ test=dict(
+ type='UniformConcatDataset',
+ datasets=test_list,
+ pipeline=test_pipeline_1333_736))
+fp16 = dict(loss_scale='dynamic')
+evaluation = dict(interval=100, metric='hmean-iou')
diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
new file mode 100755
index 0000000000000000000000000000000000000000..06e2545b830fa710c2d641892b56bc0e69b13aa8
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
@@ -0,0 +1,37 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/schedules/schedule_sgd_1200e.py',
+ '../../_base_/det_models/dbnet_r50dcnv2_fpnc.py',
+ '../../_base_/det_datasets/icdar2015.py',
+ '../../_base_/det_pipelines/dbnet_pipeline.py'
+]
+
+train_list = {{_base_.train_list}}
+test_list = {{_base_.test_list}}
+
+train_pipeline_r50dcnv2 = {{_base_.train_pipeline_r50dcnv2}}
+test_pipeline_4068_1024 = {{_base_.test_pipeline_4068_1024}}
+
+# load_from = 'checkpoints/textdet/dbnet/res50dcnv2_synthtext.pth'
+# fp16 = dict(loss_scale='dynamic')
+
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
+ train=dict(
+ type='UniformConcatDataset',
+ datasets=train_list,
+ pipeline=train_pipeline_r50dcnv2),
+ val=dict(
+ type='UniformConcatDataset',
+ datasets=test_list,
+ pipeline=test_pipeline_4068_1024),
+ test=dict(
+ type='UniformConcatDataset',
+ datasets=test_list,
+ pipeline=test_pipeline_4068_1024))
+fp16 = dict(loss_scale='dynamic')
+evaluation = dict(interval=1, metric='hmean-iou')
+# fp16 = dict(loss_scale='dynamic')
diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/metafile.yml b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/metafile.yml
new file mode 100755
index 0000000000000000000000000000000000000000..c6abdbca61d760a0e6d275e5188312ef86fd055e
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/metafile.yml
@@ -0,0 +1,40 @@
+Collections:
+- Name: DBNet
+ Metadata:
+ Training Data: ICDAR2015
+ Training Techniques:
+ - SGD with Momentum
+ - Weight Decay
+ Training Resources: 1x GeForce GTX 1080 Ti
+ Architecture:
+ - ResNet
+ - FPNC
+ Paper:
+ URL: https://arxiv.org/pdf/1911.08947.pdf
+ Title: 'Real-time Scene Text Detection with Differentiable Binarization'
+ README: configs/textdet/dbnet/README.md
+
+Models:
+ - Name: dbnet_r18_fpnc_1200e_icdar2015
+ In Collection: DBNet
+ Config: configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
+ Metadata:
+ Training Data: ICDAR2015
+ Results:
+ - Task: Text Detection
+ Dataset: ICDAR2015
+ Metrics:
+ hmean-iou: 0.795
+ Weights: https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth
+
+ - Name: dbnet_r50dcnv2_fpnc_1200e_icdar2015
+ In Collection: DBNet
+ Config: configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
+ Metadata:
+ Training Data: ICDAR2015
+ Results:
+ - Task: Text Detection
+ Dataset: ICDAR2015
+ Metrics:
+ hmean-iou: 0.840
+ Weights: https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth
diff --git a/cv/ocr/dbnet/pytorch/dbnet/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..ed46a6711329441b0621d165970667bcfb52e673
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/__init__.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import dbnet_cv
+import dbnet_det
+from packaging.version import parse
+
+from .version import __version__, short_version
+
+
+def digit_version(version_str: str, length: int = 4):
+ """Convert a version string into a tuple of integers.
+
+ This method is usually used for comparing two versions. For pre-release
+ versions: alpha < beta < rc.
+ Args:
+ version_str (str): The version string.
+ length (int): The maximum number of version levels. Default: 4.
+ Returns:
+ tuple[int]: The version info in digits (integers).
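+    Example (illustrative):
+        >>> digit_version('1.3.8')
+        (1, 3, 8, 0, 0, 0)
+        >>> digit_version('2.0.0rc1')
+        (2, 0, 0, 0, -1, 1)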
+ """
+ version = parse(version_str)
+ assert version.release, f'failed to parse version {version_str}'
+ release = list(version.release)
+ release = release[:length]
+ if len(release) < length:
+ release = release + [0] * (length - len(release))
+ if version.is_prerelease:
+ mapping = {'a': -3, 'b': -2, 'rc': -1}
+ val = -4
+ # version.pre can be None
+ if version.pre:
+ if version.pre[0] not in mapping:
+ warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+ 'version checking may go wrong')
+ else:
+ val = mapping[version.pre[0]]
+ release.extend([val, version.pre[-1]])
+ else:
+ release.extend([val, 0])
+
+ elif version.is_postrelease:
+ release.extend([1, version.post])
+ else:
+ release.extend([0, 0])
+ return tuple(release)
+
+
+dbnet_cv_minimum_version = '1.3.8'
+dbnet_cv_maximum_version = '1.7.0'
+dbnet_cv_version = digit_version(dbnet_cv.__version__)
+
+assert (dbnet_cv_version >= digit_version(dbnet_cv_minimum_version)
+ and dbnet_cv_version <= digit_version(dbnet_cv_maximum_version)), \
+ f'DBNET_CV {dbnet_cv.__version__} is incompatible with MMOCR {__version__}. ' \
+ f'Please use DBNET_CV >= {dbnet_cv_minimum_version}, ' \
+ f'<= {dbnet_cv_maximum_version} instead.'
+
+# dbnet_det_minimum_version = '2.21.0'
+# dbnet_det_maximum_version = '3.0.0'
+# dbnet_det_version = digit_version(dbnet_det.__version__)
+
+# assert (dbnet_det_version >= digit_version(dbnet_det_minimum_version)
+# and dbnet_det_version <= digit_version(dbnet_det_maximum_version)), \
+# f'dbnet_detection {dbnet_det.__version__} is incompatible ' \
+# f'with MMOCR {__version__}. ' \
+# f'Please use dbnet_detection >= {dbnet_det_minimum_version}, ' \
+# f'<= {dbnet_det_maximum_version} instead.'
+
+__all__ = ['__version__', 'short_version', 'digit_version']
diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..998771bea1bdbd03141d6bbab5c38dc91a966d8a
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .inference import init_detector, model_inference
+from .test import single_gpu_test
+from .train import init_random_seed, train_detector
+# from .utils import (disable_text_recog_aug_test, replace_image_to_tensor,
+# tensor2grayimgs)
+
+__all__ = [
+    'model_inference', 'train_detector', 'init_detector', 'init_random_seed',
+    'single_gpu_test'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/inference.py b/cv/ocr/dbnet/pytorch/dbnet/apis/inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..33be6297951534579eeba713e6568b0e32a2bdc5
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/apis/inference.py
@@ -0,0 +1,238 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import dbnet_cv
+import numpy as np
+import torch
+from dbnet_cv.ops import RoIPool
+from dbnet_cv.parallel import collate, scatter
+from dbnet_cv.runner import load_checkpoint
+from dbnet_det.core import get_classes
+from dbnet_det.datasets import replace_ImageToTensor
+from dbnet_det.datasets.pipelines import Compose
+
+from dbnet.models import build_detector
+from dbnet.utils import is_2dlist
+from .utils import disable_text_recog_aug_test
+
+
+def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
+ """Initialize a detector from config file.
+
+ Args:
+ config (str or :obj:`dbnet_cv.Config`): Config file path or the config
+ object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        device (str): The device to put the model on. Default: 'cuda:0'.
+        cfg_options (dict): Options to override some settings in the used
+            config.
+
+ Returns:
+ nn.Module: The constructed detector.
+ """
+ if isinstance(config, str):
+ config = dbnet_cv.Config.fromfile(config)
+ elif not isinstance(config, dbnet_cv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(config)}')
+ if cfg_options is not None:
+ config.merge_from_dict(cfg_options)
+ if config.model.get('pretrained'):
+ config.model.pretrained = None
+ config.model.train_cfg = None
+ model = build_detector(config.model, test_cfg=config.get('test_cfg'))
+ if checkpoint is not None:
+ checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+ if 'CLASSES' in checkpoint.get('meta', {}):
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ else:
+ warnings.simplefilter('once')
+ warnings.warn('Class names are not saved in the checkpoint\'s '
+ 'meta data, use COCO classes by default.')
+ model.CLASSES = get_classes('coco')
+ model.cfg = config # save the config in the model for convenience
+ model.to(device)
+ model.eval()
+ return model
+
+
+def model_inference(model,
+ imgs,
+ ann=None,
+ batch_mode=False,
+ return_data=False):
+ """Inference image(s) with the detector.
+
+ Args:
+        model (nn.Module): The loaded detector.
+        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
+            Either image files or loaded images.
+        ann (dict): Annotation info for key information extraction.
+        batch_mode (bool): If True, use batch mode for inference.
+        return_data (bool): If True, also return the post-processed data.
+ Returns:
+ result (dict): Predicted results.
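+
+    Example (illustrative; the paths are placeholders):
+        >>> model = init_detector('configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py')
+        >>> result = model_inference(model, 'demo.jpg')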
+ """
+
+ if isinstance(imgs, (list, tuple)):
+ is_batch = True
+ if len(imgs) == 0:
+ raise Exception('empty imgs provided, please check and try again')
+ if not isinstance(imgs[0], (np.ndarray, str)):
+ raise AssertionError('imgs must be strings or numpy arrays')
+
+ elif isinstance(imgs, (np.ndarray, str)):
+ imgs = [imgs]
+ is_batch = False
+ else:
+ raise AssertionError('imgs must be strings or numpy arrays')
+
+ is_ndarray = isinstance(imgs[0], np.ndarray)
+
+ cfg = model.cfg
+
+ if batch_mode:
+ cfg = disable_text_recog_aug_test(cfg, set_types=['test'])
+
+ device = next(model.parameters()).device # model device
+
+ if cfg.data.test.get('pipeline', None) is None:
+ if is_2dlist(cfg.data.test.datasets):
+ cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
+ else:
+ cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
+ if is_2dlist(cfg.data.test.pipeline):
+ cfg.data.test.pipeline = cfg.data.test.pipeline[0]
+
+ if is_ndarray:
+ cfg = cfg.copy()
+ # set loading pipeline type
+ cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
+
+ cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+ test_pipeline = Compose(cfg.data.test.pipeline)
+
+ datas = []
+ for img in imgs:
+ # prepare data
+ if is_ndarray:
+ # directly add img
+ data = dict(
+ img=img,
+ ann_info=ann,
+ img_info=dict(width=img.shape[1], height=img.shape[0]),
+ bbox_fields=[])
+ else:
+ # add information into dict
+ data = dict(
+ img_info=dict(filename=img),
+ img_prefix=None,
+ ann_info=ann,
+ bbox_fields=[])
+ if ann is not None:
+ data.update(dict(**ann))
+
+ # build the data pipeline
+ data = test_pipeline(data)
+ # get tensor from list to stack for batch mode (text detection)
+ if batch_mode:
+ if cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug':
+ for key, value in data.items():
+ data[key] = value[0]
+ datas.append(data)
+
+ if isinstance(datas[0]['img'], list) and len(datas) > 1:
+ raise Exception('aug test does not support '
+ f'inference with batch size '
+ f'{len(datas)}')
+
+ data = collate(datas, samples_per_gpu=len(imgs))
+
+ # process img_metas
+ if isinstance(data['img_metas'], list):
+ data['img_metas'] = [
+ img_metas.data[0] for img_metas in data['img_metas']
+ ]
+ else:
+ data['img_metas'] = data['img_metas'].data
+
+ if isinstance(data['img'], list):
+ data['img'] = [img.data for img in data['img']]
+ if isinstance(data['img'][0], list):
+ data['img'] = [img[0] for img in data['img']]
+ else:
+ data['img'] = data['img'].data
+
+ # for KIE models
+ if ann is not None:
+ data['relations'] = data['relations'].data[0]
+ data['gt_bboxes'] = data['gt_bboxes'].data[0]
+ data['texts'] = data['texts'].data[0]
+ data['img'] = data['img'][0]
+ data['img_metas'] = data['img_metas'][0]
+
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ for m in model.modules():
+ assert not isinstance(
+ m, RoIPool
+ ), 'CPU inference with RoIPool is not supported currently.'
+
+ # forward the model
+ with torch.no_grad():
+ results = model(return_loss=False, rescale=True, **data)
+
+ if not is_batch:
+ if not return_data:
+ return results[0]
+ return results[0], datas[0]
+ else:
+ if not return_data:
+ return results
+ return results, datas
+
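+# A rough usage sketch for `model_inference`, kept as comments: the image path
+# and `some_image_array` are placeholders, and `model` is assumed to be an
+# already initialized text detector whose `model.cfg` carries a standard test
+# pipeline.
+#
+#   result = model_inference(model, 'demo/dummy_text_det.jpg')
+#   # For DBNet-style detectors the result is typically a dict whose
+#   # 'boundary_result' entry lists polygons with a trailing score.
+#
+#   # An ndarray works the same way; the loading step switches to
+#   # 'LoadImageFromNdarray' automatically, and `return_data=True` also
+#   # returns the pipeline outputs used to build the batch.
+#   result, data = model_inference(model, some_image_array, return_data=True)
+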
+
+def text_model_inference(model, input_sentence):
+ """Inference text(s) with the entity recognizer.
+
+ Args:
+ model (nn.Module): The loaded recognizer.
+ input_sentence (str): A text entered by the user.
+
+ Returns:
+ result (dict): Predicted results.
+ """
+
+ assert isinstance(input_sentence, str)
+
+ cfg = model.cfg
+ if cfg.data.test.get('pipeline', None) is None:
+ if is_2dlist(cfg.data.test.datasets):
+ cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
+ else:
+ cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
+ if is_2dlist(cfg.data.test.pipeline):
+ cfg.data.test.pipeline = cfg.data.test.pipeline[0]
+ test_pipeline = Compose(cfg.data.test.pipeline)
+ data = {'text': input_sentence, 'label': {}}
+
+ # build the data pipeline
+ data = test_pipeline(data)
+ if isinstance(data['img_metas'], dict):
+ img_metas = data['img_metas']
+ else:
+ img_metas = data['img_metas'].data
+
+ assert isinstance(img_metas, dict)
+ img_metas = {
+ 'input_ids': img_metas['input_ids'].unsqueeze(0),
+ 'attention_masks': img_metas['attention_masks'].unsqueeze(0),
+ 'token_type_ids': img_metas['token_type_ids'].unsqueeze(0),
+ 'labels': img_metas['labels'].unsqueeze(0)
+ }
+ # forward the model
+ with torch.no_grad():
+ result = model(None, img_metas, return_loss=False)
+ return result
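+
+
+# Rough usage sketch (comments only): `ner_model` is assumed to be an already
+# initialized entity recognizer whose test pipeline emits the
+# 'input_ids' / 'attention_masks' / 'token_type_ids' / 'labels' keys used above.
+#
+#   result = text_model_inference(ner_model, 'John lives in New York')
+#   # `result` is the model's prediction dict for the input sentence.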
diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/test.py b/cv/ocr/dbnet/pytorch/dbnet/apis/test.py
new file mode 100755
index 0000000000000000000000000000000000000000..fdfa747033df974e938be5b9a884a4d4fec79959
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/apis/test.py
@@ -0,0 +1,157 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import dbnet_cv
+import numpy as np
+import torch
+from dbnet_cv.image import tensor2imgs
+from dbnet_cv.parallel import DataContainer
+from dbnet_det.core import encode_mask_results
+
+from .utils import tensor2grayimgs
+
+
+def retrieve_img_tensor_and_meta(data):
+ """Retrieval img_tensor, img_metas and img_norm_cfg.
+
+ Args:
+ data (dict): One batch data from data_loader.
+
+ Returns:
+ tuple: Returns (img_tensor, img_metas, img_norm_cfg).
+
+ - | img_tensor (Tensor): Input image tensor with shape
+ :math:`(N, C, H, W)`.
+ - | img_metas (list[dict]): The metadata of images.
+ - | img_norm_cfg (dict): Config for image normalization.
+ """
+
+ if isinstance(data['img'], torch.Tensor):
+ # for textrecog with batch_size > 1
+ # and not use 'DefaultFormatBundle' in pipeline
+ img_tensor = data['img']
+ img_metas = data['img_metas'].data[0]
+ elif isinstance(data['img'], list):
+ if isinstance(data['img'][0], torch.Tensor):
+ # for textrecog with aug_test and batch_size = 1
+ img_tensor = data['img'][0]
+ elif isinstance(data['img'][0], DataContainer):
+ # for textdet with 'MultiScaleFlipAug'
+ # and 'DefaultFormatBundle' in pipeline
+ img_tensor = data['img'][0].data[0]
+ img_metas = data['img_metas'][0].data[0]
+ elif isinstance(data['img'], DataContainer):
+ # for textrecog with 'DefaultFormatBundle' in pipeline
+ img_tensor = data['img'].data[0]
+ img_metas = data['img_metas'].data[0]
+
+ must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
+ for key in must_keys:
+ if key not in img_metas[0]:
+ raise KeyError(
+ f'Please add {key} to the "meta_keys" in the pipeline')
+
+ img_norm_cfg = img_metas[0]['img_norm_cfg']
+ if max(img_norm_cfg['mean']) <= 1:
+ img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
+ img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]
+
+ return img_tensor, img_metas, img_norm_cfg
+
+
+def single_gpu_test(model,
+ data_loader,
+ show=False,
+ out_dir=None,
+ is_kie=False,
+ show_score_thr=0.3):
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = dbnet_cv.ProgressBar(len(dataset))
+ for data in data_loader:
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+
+ batch_size = len(result)
+ if show or out_dir:
+ if is_kie:
+ img_tensor = data['img'].data[0]
+ if img_tensor.shape[0] != 1:
+ raise KeyError('Visualizing KIE outputs in batches is '
+ 'currently not supported.')
+ gt_bboxes = data['gt_bboxes'].data[0]
+ img_metas = data['img_metas'].data[0]
+ must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
+ for key in must_keys:
+ if key not in img_metas[0]:
+ raise KeyError(
+ f'Please add {key} to the "meta_keys" in config.')
+ # for no visual model
+ if np.prod(img_tensor.shape) == 0:
+ imgs = []
+ for img_meta in img_metas:
+ try:
+ img = dbnet_cv.imread(img_meta['filename'])
+ except Exception as e:
+ print(f'Load image with error: {e}, '
+ 'use empty image instead.')
+ img = np.ones(
+ img_meta['img_shape'], dtype=np.uint8)
+ imgs.append(img)
+ else:
+ imgs = tensor2imgs(img_tensor,
+ **img_metas[0]['img_norm_cfg'])
+ for i, img in enumerate(imgs):
+ h, w, _ = img_metas[i]['img_shape']
+ img_show = img[:h, :w, :]
+ if out_dir:
+ out_file = osp.join(out_dir,
+ img_metas[i]['ori_filename'])
+ else:
+ out_file = None
+
+ model.module.show_result(
+ img_show,
+ result[i],
+ gt_bboxes[i],
+ show=show,
+ out_file=out_file)
+ else:
+ img_tensor, img_metas, img_norm_cfg = \
+ retrieve_img_tensor_and_meta(data)
+
+ if img_tensor.size(1) == 1:
+ imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
+ else:
+ imgs = tensor2imgs(img_tensor, **img_norm_cfg)
+ assert len(imgs) == len(img_metas)
+
+ for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
+ img_shape, ori_shape = img_meta['img_shape'], img_meta[
+ 'ori_shape']
+ img_show = img[:img_shape[0], :img_shape[1]]
+ img_show = dbnet_cv.imresize(img_show,
+ (ori_shape[1], ori_shape[0]))
+
+ if out_dir:
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
+ else:
+ out_file = None
+
+ model.module.show_result(
+ img_show,
+ result[j],
+ show=show,
+ out_file=out_file,
+ score_thr=show_score_thr)
+
+ # encode mask results
+ if isinstance(result[0], tuple):
+ result = [(bbox_results, encode_mask_results(mask_results))
+ for bbox_results, mask_results in result]
+ results.extend(result)
+
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
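+
+
+# Sketch of a typical single-GPU evaluation loop (comments only; the config
+# fields and output directory are placeholders, and the builders below are the
+# usual OpenMMLab-style helpers from dbnet_det):
+#
+#   from dbnet_cv.parallel import MMDataParallel
+#   from dbnet_det.datasets import build_dataloader, build_dataset
+#
+#   dataset = build_dataset(cfg.data.test, dict(test_mode=True))
+#   data_loader = build_dataloader(
+#       dataset, samples_per_gpu=1, workers_per_gpu=2, dist=False,
+#       shuffle=False)
+#   model = MMDataParallel(model, device_ids=[0])
+#   results = single_gpu_test(model, data_loader, show=False,
+#                             out_dir='work_dirs/vis')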
diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/train.py b/cv/ocr/dbnet/pytorch/dbnet/apis/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..e1008be26014eb0b96f18904a104416c162cd492
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/apis/train.py
@@ -0,0 +1,185 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import dbnet_cv
+import numpy as np
+import torch
+import torch.distributed as dist
+from dbnet_cv.parallel import MMDataParallel, MMDistributedDataParallel
+from dbnet_cv.runner import (DistSamplerSeedHook, EpochBasedRunner,
+ Fp16OptimizerHook, OptimizerHook, build_optimizer,
+ build_runner, get_dist_info)
+from dbnet_det.core import DistEvalHook, EvalHook
+from dbnet_det.datasets import build_dataloader, build_dataset
+
+from dbnet import digit_version
+from dbnet.apis.utils import (disable_text_recog_aug_test,
+ replace_image_to_tensor)
+from dbnet.utils import get_root_logger
+
+
+def train_detector(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ # step 1: give default values and override (if exist) from cfg.data
+ default_loader_cfg = {
+ **dict(
+ num_gpus=len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.get('seed'),
+ drop_last=False,
+ persistent_workers=False),
+ **({} if torch.__version__ != 'parrots' else dict(
+ prefetch_num=2,
+ pin_memory=False,
+ )),
+ }
+ # update the overall dataloader (train, val and test) settings
+ default_loader_cfg.update({
+ k: v
+ for k, v in cfg.data.items() if k not in [
+ 'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+ 'test_dataloader'
+ ]
+ })
+
+ # step 2: cfg.data.train_dataloader has highest priority
+ train_loader_cfg = dict(default_loader_cfg,
+ **cfg.data.get('train_dataloader', {}))
+
+ data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ if not torch.cuda.is_available():
+ assert digit_version(dbnet_cv.__version__) >= digit_version('1.4.4'), \
+ 'Please use DBNET_CV >= 1.4.4 for CPU training!'
+ model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+
+ # build runner
+ optimizer = build_optimizer(model, cfg.optimizer)
+
+ if 'runner' not in cfg:
+ cfg.runner = {
+ 'type': 'EpochBasedRunner',
+ 'max_epochs': cfg.total_epochs
+ }
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+ else:
+ if 'total_epochs' in cfg:
+ assert cfg.total_epochs == cfg.runner.max_epochs
+
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+ # an ugly workaround to make .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # fp16 setting
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ optimizer_config = Fp16OptimizerHook(
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+ elif distributed and 'type' not in cfg.optimizer_config:
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
+ else:
+ optimizer_config = cfg.optimizer_config
+
+ # register hooks
+ runner.register_training_hooks(
+ cfg.lr_config,
+ optimizer_config,
+ cfg.checkpoint_config,
+ cfg.log_config,
+ cfg.get('momentum_config', None),
+ custom_hooks_config=cfg.get('custom_hooks', None))
+ if distributed:
+ if isinstance(runner, EpochBasedRunner):
+ runner.register_hook(DistSamplerSeedHook())
+
+ # register eval hooks
+ if validate:
+ val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
+ 'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
+ if val_samples_per_gpu > 1:
+ # Support batch_size > 1 in test for text recognition
+ # by disabling MultiRotateAugOCR since it is useless in most cases
+ cfg = disable_text_recog_aug_test(cfg)
+ cfg = replace_image_to_tensor(cfg)
+
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+
+ val_loader_cfg = {
+ **default_loader_cfg,
+ **dict(shuffle=False, drop_last=False),
+ **cfg.data.get('val_dataloader', {}),
+ **dict(samples_per_gpu=val_samples_per_gpu)
+ }
+
+ val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
+
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
+
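+# Rough end-to-end sketch (comments only): `cfg` is assumed to be a complete
+# training config loaded with dbnet_cv.Config.fromfile (defining model, data,
+# optimizer, runner, lr_config, work_dir, gpu_ids, ...), `model` is assumed to
+# be already built, and build_dataset is the usual dbnet_det helper.
+#
+#   from dbnet_det.datasets import build_dataset
+#
+#   datasets = [build_dataset(cfg.data.train)]
+#   train_detector(model, datasets, cfg, distributed=False, validate=True)
+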
+
+def init_random_seed(seed=None, device='cuda'):
+ """Initialize random seed. If the seed is None, it will be replaced by a
+ random number, and then broadcasted to all processes.
+
+ Args:
+ seed (int, Optional): The seed.
+ device (str): The device where the seed will be put on.
+
+ Returns:
+ int: Seed to be used.
+ """
+ if seed is not None:
+ return seed
+
+ # Make sure all ranks share the same random seed to prevent
+ # some potential bugs. Please refer to
+ # https://github.com/open-mmlab/dbnet_detection/issues/6339
+ rank, world_size = get_dist_info()
+ seed = np.random.randint(2**31)
+ if world_size == 1:
+ return seed
+
+ if rank == 0:
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+ else:
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
+ dist.broadcast(random_num, src=0)
+ return random_num.item()
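+
+
+# Usage sketch (comments only): in a non-distributed run get_dist_info()
+# reports a world size of 1, so the freshly drawn seed is returned directly;
+# with torch.distributed initialized, the rank-0 seed is broadcast so that all
+# processes share the same value.
+#
+#   seed = init_random_seed(None, device='cuda')  # or pass an explicit int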
diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/utils.py b/cv/ocr/dbnet/pytorch/dbnet/apis/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..b724713a4f84ca4dbfd929da47ef356194508560
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/apis/utils.py
@@ -0,0 +1,123 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import dbnet_cv
+import numpy as np
+import torch
+from dbnet_det.datasets import replace_ImageToTensor
+
+from dbnet.utils import is_2dlist, is_type_list
+
+
+def update_pipeline(cfg, idx=None):
+ if idx is None:
+ if cfg.pipeline is not None:
+ cfg.pipeline = replace_ImageToTensor(cfg.pipeline)
+ else:
+ cfg.pipeline[idx] = replace_ImageToTensor(cfg.pipeline[idx])
+
+
+def replace_image_to_tensor(cfg, set_types=None):
+ """Replace 'ImageToTensor' to 'DefaultFormatBundle'."""
+ assert set_types is None or isinstance(set_types, list)
+ if set_types is None:
+ set_types = ['val', 'test']
+
+ cfg = copy.deepcopy(cfg)
+ for set_type in set_types:
+ assert set_type in ['val', 'test']
+ uniform_pipeline = cfg.data[set_type].get('pipeline', None)
+ if is_type_list(uniform_pipeline, dict):
+ update_pipeline(cfg.data[set_type])
+ elif is_2dlist(uniform_pipeline):
+ for idx, _ in enumerate(uniform_pipeline):
+ update_pipeline(cfg.data[set_type], idx)
+
+ for dataset in cfg.data[set_type].get('datasets', []):
+ if isinstance(dataset, list):
+ for each_dataset in dataset:
+ update_pipeline(each_dataset)
+ else:
+ update_pipeline(dataset)
+
+ return cfg
+
+
+def update_pipeline_recog(cfg, idx=None):
+ warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \
+ 'inference since samples_per_gpu > 1.'
+ if idx is None:
+ if cfg.get('pipeline',
+ None) and cfg.pipeline[1].type == 'MultiRotateAugOCR':
+ warnings.warn(warning_msg)
+ cfg.pipeline = [cfg.pipeline[0], *cfg.pipeline[1].transforms]
+ else:
+ if cfg[idx][1].type == 'MultiRotateAugOCR':
+ warnings.warn(warning_msg)
+ cfg[idx] = [cfg[idx][0], *cfg[idx][1].transforms]
+
+
+def disable_text_recog_aug_test(cfg, set_types=None):
+ """Remove aug_test from test pipeline for text recognition.
+ Args:
+ cfg (dbnet_cv.Config): Input config.
+ set_types (list[str]): Type of dataset source. Should be
+ None or sublist of ['test', 'val'].
+ """
+ assert set_types is None or isinstance(set_types, list)
+ if set_types is None:
+ set_types = ['val', 'test']
+
+ cfg = copy.deepcopy(cfg)
+ warnings.simplefilter('once')
+ for set_type in set_types:
+ assert set_type in ['val', 'test']
+ dataset_type = cfg.data[set_type].type
+ if dataset_type not in [
+ 'ConcatDataset', 'UniformConcatDataset', 'OCRDataset',
+ 'OCRSegDataset'
+ ]:
+ continue
+
+ uniform_pipeline = cfg.data[set_type].get('pipeline', None)
+ if is_type_list(uniform_pipeline, dict):
+ update_pipeline_recog(cfg.data[set_type])
+ elif is_2dlist(uniform_pipeline):
+ for idx, _ in enumerate(uniform_pipeline):
+ update_pipeline_recog(cfg.data[set_type].pipeline, idx)
+
+ for dataset in cfg.data[set_type].get('datasets', []):
+ if isinstance(dataset, list):
+ for each_dataset in dataset:
+ update_pipeline_recog(each_dataset)
+ else:
+ update_pipeline_recog(dataset)
+
+ return cfg
+
+
+def tensor2grayimgs(tensor, mean=(127, ), std=(127, ), **kwargs):
+ """Convert tensor to 1-channel gray images.
+ Args:
+ tensor (torch.Tensor): Tensor that contains multiple images, shape (
+ N, C, H, W).
+ mean (tuple[float], optional): Mean of images. Defaults to (127).
+ std (tuple[float], optional): Standard deviation of images.
+ Defaults to (127).
+ Returns:
+ list[np.ndarray]: A list that contains multiple images.
+ """
+
+ assert torch.is_tensor(tensor) and tensor.ndim == 4
+ assert tensor.size(1) == len(mean) == len(std) == 1
+
+ num_imgs = tensor.size(0)
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ imgs = []
+ for img_id in range(num_imgs):
+ img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+ img = dbnet_cv.imdenormalize(img, mean, std, to_bgr=False).astype(np.uint8)
+ imgs.append(np.ascontiguousarray(img))
+ return imgs
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/core/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..beae1ba42f375f7c3af16ac3b448160defaac41c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from . import evaluation
+from .evaluation import * # NOQA
+from .mask import extract_boundary, points2boundary, seg2boundary
+from .visualize import (det_recog_show_result, imshow_edge, imshow_node,
+ imshow_pred_boundary, imshow_text_char_boundary,
+ imshow_text_label, overlay_mask_img, show_feature,
+ show_img_boundary, show_pred_gt)
+
+__all__ = [
+ 'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
+ 'show_feature', 'show_img_boundary', 'show_pred_gt',
+ 'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
+ 'imshow_node', 'det_recog_show_result', 'imshow_edge'
+]
+__all__ += evaluation.__all__
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..3c23516f7f3fe4c5ed800ececce40be42e558692
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hmean import eval_hmean
+from .hmean_ic13 import eval_hmean_ic13
+from .hmean_iou import eval_hmean_iou
+# from .kie_metric import compute_f1_score
+# from .ner_metric import eval_ner_f1
+# from .ocr_metric import eval_ocr_metric
+
+__all__ = [
+ 'eval_hmean', 'eval_hmean_ic13', 'eval_hmean_iou',
+ # 'eval_ocr_metric', 'compute_f1_score', 'eval_ner_f1'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean.py
new file mode 100755
index 0000000000000000000000000000000000000000..e97977606f0937c029057623d2a2145dd12b3a86
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean.py
@@ -0,0 +1,172 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from operator import itemgetter
+
+import dbnet_cv
+import numpy as np
+from dbnet_cv.utils import print_log
+
+import dbnet.utils as utils
+from dbnet.core.evaluation import hmean_ic13, hmean_iou
+from dbnet.core.evaluation.utils import (filter_2dlist_result,
+ select_top_boundary)
+from dbnet.core.mask import extract_boundary
+
+
+def output_ranklist(img_results, img_infos, out_file):
+ """Output the worst results for debugging.
+
+ Args:
+ img_results (list[dict]): Image result list.
+ img_infos (list[dict]): Image information list.
+ out_file (str): The output file path.
+
+ Returns:
+ sorted_results (list[dict]): Image results sorted by hmean.
+ """
+ assert utils.is_type_list(img_results, dict)
+ assert utils.is_type_list(img_infos, dict)
+ assert isinstance(out_file, str)
+ assert out_file.endswith('json')
+
+ sorted_results = []
+ for idx, result in enumerate(img_results):
+ name = img_infos[idx]['file_name']
+ img_result = result
+ img_result['file_name'] = name
+ sorted_results.append(img_result)
+ sorted_results = sorted(
+ sorted_results, key=itemgetter('hmean'), reverse=False)
+
+ dbnet_cv.dump(sorted_results, file=out_file)
+
+ return sorted_results
+
+
+def get_gt_masks(ann_infos):
+ """Get ground truth masks and ignored masks.
+
+ Args:
+ ann_infos (list[dict]): Each dict contains the annotation
+ info of one image, with the following keys:
+ masks, masks_ignore.
+ Returns:
+ gt_masks (list[list[list[int]]]): Ground truth masks.
+ gt_masks_ignore (list[list[list[int]]]): Ignored masks.
+ """
+ assert utils.is_type_list(ann_infos, dict)
+
+ gt_masks = []
+ gt_masks_ignore = []
+ for ann_info in ann_infos:
+ masks = ann_info['masks']
+ mask_gt = []
+ for mask in masks:
+ assert len(mask[0]) >= 8 and len(mask[0]) % 2 == 0
+ mask_gt.append(mask[0])
+ gt_masks.append(mask_gt)
+
+ masks_ignore = ann_info['masks_ignore']
+ mask_gt_ignore = []
+ for mask_ignore in masks_ignore:
+ assert len(mask_ignore[0]) >= 8 and len(mask_ignore[0]) % 2 == 0
+ mask_gt_ignore.append(mask_ignore[0])
+ gt_masks_ignore.append(mask_gt_ignore)
+
+ return gt_masks, gt_masks_ignore
+
+
+def eval_hmean(results,
+ img_infos,
+ ann_infos,
+ metrics={'hmean-iou'},
+ score_thr=None,
+ min_score_thr=0.3,
+ max_score_thr=0.9,
+ step=0.1,
+ rank_list=None,
+ logger=None,
+ **kwargs):
+ """Evaluation in hmean metric. It conducts grid search over a range of
+ boundary score thresholds and reports the best result.
+
+ Args:
+ results (list[dict]): Each dict corresponds to one image,
+ containing the following keys: boundary_result
+ img_infos (list[dict]): Each dict corresponds to one image,
+ containing the following keys: filename, height, width
+ ann_infos (list[dict]): Each dict corresponds to one image,
+ containing the following keys: masks, masks_ignore
+ score_thr (float): Deprecated. Please use min_score_thr instead.
+ min_score_thr (float): Minimum score threshold of prediction map.
+ max_score_thr (float): Maximum score threshold of prediction map.
+ step (float): The spacing between score thresholds.
+ metrics (set{str}): Hmean metric set, should be a non-empty subset
+ of {'hmean-iou', 'hmean-ic13'}.
+ rank_list (str): Json file path for saving the per-image results
+ ranked by hmean (used for debugging). Skipped if None.
+ logger (logging.Logger): Logger used for printing intermediate
+ results. Prints to stdout if None.
+
+ Returns:
+ dict[str: float]: The best evaluation results found over the
+ searched thresholds.
+ """
+ assert utils.is_type_list(results, dict)
+ assert utils.is_type_list(img_infos, dict)
+ assert utils.is_type_list(ann_infos, dict)
+
+ if score_thr:
+ warnings.warn('score_thr is deprecated. Please use min_score_thr '
+ 'instead.')
+ min_score_thr = score_thr
+
+ assert 0 <= min_score_thr <= max_score_thr <= 1
+ assert 0 <= step <= 1
+ assert len(results) == len(img_infos) == len(ann_infos)
+ assert isinstance(metrics, set)
+
+ min_score_thr = float(min_score_thr)
+ max_score_thr = float(max_score_thr)
+ step = float(step)
+
+ gts, gts_ignore = get_gt_masks(ann_infos)
+
+ preds = []
+ pred_scores = []
+ for result in results:
+ _, texts, scores = extract_boundary(result)
+ if len(texts) > 0:
+ assert utils.valid_boundary(texts[0], False)
+ valid_texts, valid_text_scores = filter_2dlist_result(
+ texts, scores, min_score_thr)
+ preds.append(valid_texts)
+ pred_scores.append(valid_text_scores)
+
+ eval_results = {}
+
+ for metric in metrics:
+ msg = f'Evaluating {metric}...'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+ best_result = dict(hmean=-1)
+ for thr in np.arange(min_score_thr, min(max_score_thr + step, 1.0),
+ step):
+ top_preds = select_top_boundary(preds, pred_scores, thr)
+ if metric == 'hmean-iou':
+ result, img_result = hmean_iou.eval_hmean_iou(
+ top_preds, gts, gts_ignore)
+ elif metric == 'hmean-ic13':
+ result, img_result = hmean_ic13.eval_hmean_ic13(
+ top_preds, gts, gts_ignore)
+ else:
+ raise NotImplementedError
+ if rank_list is not None:
+ output_ranklist(img_result, img_infos, rank_list)
+
+ print_log(
+ 'thr {0:.2f}, recall: {1[recall]:.3f}, '
+ 'precision: {1[precision]:.3f}, '
+ 'hmean: {1[hmean]:.3f}'.format(thr, result),
+ logger=logger)
+ if result['hmean'] > best_result['hmean']:
+ best_result = result
+ eval_results[metric + ':recall'] = best_result['recall']
+ eval_results[metric + ':precision'] = best_result['precision']
+ eval_results[metric + ':hmean'] = best_result['hmean']
+ return eval_results
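+
+
+# Toy sanity check (comments only): a single image whose one predicted
+# quadrangle exactly matches the only ground truth should reach hmean 1.0
+# under the default 'hmean-iou' metric.
+#
+#   results = [dict(boundary_result=[[0, 0, 10, 0, 10, 10, 0, 10, 0.95]])]
+#   img_infos = [dict(file_name='img_0.jpg', height=32, width=32)]
+#   ann_infos = [dict(masks=[[[0, 0, 10, 0, 10, 10, 0, 10]]], masks_ignore=[])]
+#   eval_hmean(results, img_infos, ann_infos)
+#   # -> {'hmean-iou:recall': 1.0, 'hmean-iou:precision': 1.0,
+#   #     'hmean-iou:hmean': 1.0}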
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_ic13.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_ic13.py
new file mode 100755
index 0000000000000000000000000000000000000000..a7139cab387946eb011507541ec4a7476c7dd54b
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_ic13.py
@@ -0,0 +1,217 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import dbnet.utils as utils
+from . import utils as eval_utils
+
+
+def compute_recall_precision(gt_polys, pred_polys):
+ """Compute the recall and the precision matrices between gt and predicted
+ polygons.
+
+ Args:
+ gt_polys (list[Polygon]): List of gt polygons.
+ pred_polys (list[Polygon]): List of predicted polygons.
+
+ Returns:
+ recall (ndarray): Recall matrix of size gt_num x det_num.
+ precision (ndarray): Precision matrix of size gt_num x det_num.
+ """
+ assert isinstance(gt_polys, list)
+ assert isinstance(pred_polys, list)
+
+ gt_num = len(gt_polys)
+ det_num = len(pred_polys)
+ sz = [gt_num, det_num]
+
+ recall = np.zeros(sz)
+ precision = np.zeros(sz)
+ # compute area recall and precision for each (gt, det) pair
+ # in one img
+ for gt_id in range(gt_num):
+ for pred_id in range(det_num):
+ gt = gt_polys[gt_id]
+ det = pred_polys[pred_id]
+
+ inter_area = eval_utils.poly_intersection(det, gt)
+ gt_area = gt.area
+ det_area = det.area
+ if gt_area != 0:
+ recall[gt_id, pred_id] = inter_area / gt_area
+ if det_area != 0:
+ precision[gt_id, pred_id] = inter_area / det_area
+
+ return recall, precision
+
+
+def eval_hmean_ic13(det_boxes,
+ gt_boxes,
+ gt_ignored_boxes,
+ precision_thr=0.4,
+ recall_thr=0.8,
+ center_dist_thr=1.0,
+ one2one_score=1.,
+ one2many_score=0.8,
+ many2one_score=1.):
+ """Evaluate hmean of text detection using the icdar2013 standard.
+
+ Args:
+ det_boxes (list[list[list[float]]]): List of arrays of shape (n, 2k).
+ Each element is the det_boxes for one img. k>=4.
+ gt_boxes (list[list[list[float]]]): List of arrays of shape (m, 2k).
+ Each element is the gt_boxes for one img. k>=4.
+ gt_ignored_boxes (list[list[list[float]]]): List of arrays of shape
+ (l, 2k). Each element is the ignored gt_boxes for one img. k>=4.
+ precision_thr (float): Precision threshold of the iou of one
+ (gt_box, det_box) pair.
+ recall_thr (float): Recall threshold of the iou of one
+ (gt_box, det_box) pair.
+ center_dist_thr (float): Distance threshold of one (gt_box, det_box)
+ center point pair.
+ one2one_score (float): Reward when one gt matches one det_box.
+ one2many_score (float): Reward when one gt matches many det_boxes.
+ many2one_score (float): Reward when many gts match one det_box.
+
+ Returns:
+ hmean (tuple[dict]): Tuple of dicts which encodes the hmean for
+ the dataset and all images.
+ """
+ assert utils.is_3dlist(det_boxes)
+ assert utils.is_3dlist(gt_boxes)
+ assert utils.is_3dlist(gt_ignored_boxes)
+
+ assert 0 <= precision_thr <= 1
+ assert 0 <= recall_thr <= 1
+ assert center_dist_thr > 0
+ assert 0 <= one2one_score <= 1
+ assert 0 <= one2many_score <= 1
+ assert 0 <= many2one_score <= 1
+
+ img_num = len(det_boxes)
+ assert img_num == len(gt_boxes)
+ assert img_num == len(gt_ignored_boxes)
+
+ dataset_gt_num = 0
+ dataset_pred_num = 0
+ dataset_hit_recall = 0.0
+ dataset_hit_prec = 0.0
+
+ img_results = []
+
+ for i in range(img_num):
+ gt = gt_boxes[i]
+ gt_ignored = gt_ignored_boxes[i]
+ pred = det_boxes[i]
+
+ gt_num = len(gt)
+ ignored_num = len(gt_ignored)
+ pred_num = len(pred)
+
+ accum_recall = 0.
+ accum_precision = 0.
+
+ gt_points = gt + gt_ignored
+ gt_polys = [eval_utils.points2polygon(p) for p in gt_points]
+ gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
+ gt_num = len(gt_polys)
+
+ pred_polys, pred_points, pred_ignored_index = eval_utils.ignore_pred(
+ pred, gt_ignored_index, gt_polys, precision_thr)
+
+ if pred_num > 0 and gt_num > 0:
+
+ gt_hit = np.zeros(gt_num, np.int8).tolist()
+ pred_hit = np.zeros(pred_num, np.int8).tolist()
+
+ # compute area recall and precision for each (gt, pred) pair
+ # in one img.
+ recall_mat, precision_mat = compute_recall_precision(
+ gt_polys, pred_polys)
+
+ # match one gt to one pred box.
+ for gt_id in range(gt_num):
+ for pred_id in range(pred_num):
+ if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0
+ or gt_id in gt_ignored_index
+ or pred_id in pred_ignored_index):
+ continue
+ match = eval_utils.one2one_match_ic13(
+ gt_id, pred_id, recall_mat, precision_mat, recall_thr,
+ precision_thr)
+
+ if match:
+ gt_point = np.array(gt_points[gt_id])
+ det_point = np.array(pred_points[pred_id])
+
+ norm_dist = eval_utils.box_center_distance(
+ det_point, gt_point)
+ norm_dist /= eval_utils.box_diag(
+ det_point) + eval_utils.box_diag(gt_point)
+ norm_dist *= 2.0
+
+ if norm_dist < center_dist_thr:
+ gt_hit[gt_id] = 1
+ pred_hit[pred_id] = 1
+ accum_recall += one2one_score
+ accum_precision += one2one_score
+
+ # match one gt to many det boxes.
+ for gt_id in range(gt_num):
+ if gt_id in gt_ignored_index:
+ continue
+ match, match_det_set = eval_utils.one2many_match_ic13(
+ gt_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_hit, pred_hit, pred_ignored_index)
+
+ if match:
+ gt_hit[gt_id] = 1
+ accum_recall += one2many_score
+ accum_precision += one2many_score * len(match_det_set)
+ for pred_id in match_det_set:
+ pred_hit[pred_id] = 1
+
+ # match many gts to one det box. A det box is matched when every
+ # candidate gt reaches the recall threshold and their summed
+ # precision reaches the precision threshold.
+ for pred_id in range(pred_num):
+ if pred_id in pred_ignored_index:
+ continue
+
+ match, match_gt_set = eval_utils.many2one_match_ic13(
+ pred_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_hit, pred_hit, gt_ignored_index)
+
+ if match:
+ pred_hit[pred_id] = 1
+ accum_recall += many2one_score * len(match_gt_set)
+ accum_precision += many2one_score
+ for gt_id in match_gt_set:
+ gt_hit[gt_id] = 1
+
+ gt_care_number = gt_num - ignored_num
+ pred_care_number = pred_num - len(pred_ignored_index)
+
+ r, p, h = eval_utils.compute_hmean(accum_recall, accum_precision,
+ gt_care_number, pred_care_number)
+
+ img_results.append({'recall': r, 'precision': p, 'hmean': h})
+
+ dataset_gt_num += gt_care_number
+ dataset_pred_num += pred_care_number
+ dataset_hit_recall += accum_recall
+ dataset_hit_prec += accum_precision
+
+ total_r, total_p, total_h = eval_utils.compute_hmean(
+ dataset_hit_recall, dataset_hit_prec, dataset_gt_num, dataset_pred_num)
+
+ dataset_results = {
+ 'num_gts': dataset_gt_num,
+ 'num_dets': dataset_pred_num,
+ 'num_recall': dataset_hit_recall,
+ 'num_precision': dataset_hit_prec,
+ 'recall': total_r,
+ 'precision': total_p,
+ 'hmean': total_h
+ }
+
+ return dataset_results, img_results
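+
+
+# Toy sanity check (comments only): one detection identical to the single
+# ground-truth box forms a one-to-one match, so the dataset hmean is 1.0.
+#
+#   box = [0, 0, 10, 0, 10, 10, 0, 10]
+#   dataset_res, img_res = eval_hmean_ic13([[box]], [[box]], [[]])
+#   # dataset_res['hmean'] == 1.0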
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_iou.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_iou.py
new file mode 100755
index 0000000000000000000000000000000000000000..ade2aaa975ba8e4584329fdf066583ea76bec443
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_iou.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import dbnet.utils as utils
+from . import utils as eval_utils
+
+
+def eval_hmean_iou(pred_boxes,
+ gt_boxes,
+ gt_ignored_boxes,
+ iou_thr=0.5,
+ precision_thr=0.5):
+ """Evaluate hmean of text detection using IOU standard.
+
+ Args:
+ pred_boxes (list[list[list[float]]]): Text boxes for an img list. Each
+ box has 2k (>=8) values.
+ gt_boxes (list[list[list[float]]]): Ground truth text boxes for an img
+ list. Each box has 2k (>=8) values.
+ gt_ignored_boxes (list[list[list[float]]]): Ignored ground truth text
+ boxes for an img list. Each box has 2k (>=8) values.
+ iou_thr (float): Iou threshold when one (gt_box, det_box) pair is
+ matched.
+ precision_thr (float): Precision threshold when one (gt_box, det_box)
+ pair is matched.
+
+ Returns:
+ hmean (tuple[dict]): Tuple of dicts indicates the hmean for the dataset
+ and all images.
+ """
+ assert utils.is_3dlist(pred_boxes)
+ assert utils.is_3dlist(gt_boxes)
+ assert utils.is_3dlist(gt_ignored_boxes)
+ assert 0 <= iou_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ img_num = len(pred_boxes)
+ assert img_num == len(gt_boxes)
+ assert img_num == len(gt_ignored_boxes)
+
+ dataset_gt_num = 0
+ dataset_pred_num = 0
+ dataset_hit_num = 0
+
+ img_results = []
+
+ for i in range(img_num):
+ gt = gt_boxes[i]
+ gt_ignored = gt_ignored_boxes[i]
+ pred = pred_boxes[i]
+
+ gt_num = len(gt)
+ gt_ignored_num = len(gt_ignored)
+ pred_num = len(pred)
+
+ hit_num = 0
+
+ # get gt polygons.
+ gt_all = gt + gt_ignored
+ gt_polys = [eval_utils.points2polygon(p) for p in gt_all]
+ gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
+ gt_num = len(gt_polys)
+ pred_polys, _, pred_ignored_index = eval_utils.ignore_pred(
+ pred, gt_ignored_index, gt_polys, precision_thr)
+
+ # match.
+ if gt_num > 0 and pred_num > 0:
+ sz = [gt_num, pred_num]
+ iou_mat = np.zeros(sz)
+
+ gt_hit = np.zeros(gt_num, np.int8)
+ pred_hit = np.zeros(pred_num, np.int8)
+
+ for gt_id in range(gt_num):
+ for pred_id in range(pred_num):
+ gt_pol = gt_polys[gt_id]
+ det_pol = pred_polys[pred_id]
+
+ iou_mat[gt_id,
+ pred_id] = eval_utils.poly_iou(det_pol, gt_pol)
+
+ for gt_id in range(gt_num):
+ for pred_id in range(pred_num):
+ if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0
+ or gt_id in gt_ignored_index
+ or pred_id in pred_ignored_index):
+ continue
+ if iou_mat[gt_id, pred_id] > iou_thr:
+ gt_hit[gt_id] = 1
+ pred_hit[pred_id] = 1
+ hit_num += 1
+
+ gt_care_number = gt_num - gt_ignored_num
+ pred_care_number = pred_num - len(pred_ignored_index)
+
+ r, p, h = eval_utils.compute_hmean(hit_num, hit_num, gt_care_number,
+ pred_care_number)
+
+ img_results.append({'recall': r, 'precision': p, 'hmean': h})
+
+ dataset_hit_num += hit_num
+ dataset_gt_num += gt_care_number
+ dataset_pred_num += pred_care_number
+
+ dataset_r, dataset_p, dataset_h = eval_utils.compute_hmean(
+ dataset_hit_num, dataset_hit_num, dataset_gt_num, dataset_pred_num)
+
+ dataset_results = {
+ 'num_gts': dataset_gt_num,
+ 'num_dets': dataset_pred_num,
+ 'num_match': dataset_hit_num,
+ 'recall': dataset_r,
+ 'precision': dataset_p,
+ 'hmean': dataset_h
+ }
+
+ return dataset_results, img_results
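+
+
+# Toy sanity check (comments only): an identical prediction and ground truth
+# have IoU 1.0 > iou_thr, hence recall = precision = hmean = 1.0.
+#
+#   box = [0, 0, 10, 0, 10, 10, 0, 10]
+#   dataset_res, img_res = eval_hmean_iou([[box]], [[box]], [[]])
+#   # dataset_res == {'num_gts': 1, 'num_dets': 1, 'num_match': 1,
+#   #                 'recall': 1.0, 'precision': 1.0, 'hmean': 1.0}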
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/utils.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..4c0db33b78c1256b8476eb6090cc09f698ac797f
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/utils.py
@@ -0,0 +1,547 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from shapely.geometry import Polygon as plg
+
+import dbnet.utils as utils
+
+
+def ignore_pred(pred_boxes, gt_ignored_index, gt_polys, precision_thr):
+ """Ignore the predicted box if it hits any ignored ground truth.
+
+ Args:
+ pred_boxes (list[ndarray or list]): The predicted boxes of one image.
+ gt_ignored_index (list[int]): The ignored ground truth index list.
+ gt_polys (list[Polygon]): The polygon list of one image.
+ precision_thr (float): The precision threshold.
+
+ Returns:
+ pred_polys (list[Polygon]): The predicted polygon list.
+ pred_points (list[list]): The predicted box list represented
+ by point sequences.
+ pred_ignored_index (list[int]): The ignored text index list.
+ """
+
+ assert isinstance(pred_boxes, list)
+ assert isinstance(gt_ignored_index, list)
+ assert isinstance(gt_polys, list)
+ assert 0 <= precision_thr <= 1
+
+ pred_polys = []
+ pred_points = []
+ pred_ignored_index = []
+
+ gt_ignored_num = len(gt_ignored_index)
+ # get detection polygons
+ for box_id, box in enumerate(pred_boxes):
+ poly = points2polygon(box)
+ pred_polys.append(poly)
+ pred_points.append(box)
+
+ if gt_ignored_num < 1:
+ continue
+
+ # ignore the current detection box
+ # if its overlap with any ignored gt > precision_thr
+ for ignored_box_id in gt_ignored_index:
+ ignored_box = gt_polys[ignored_box_id]
+ inter_area = poly_intersection(poly, ignored_box)
+ area = poly.area
+ precision = 0 if area == 0 else inter_area / area
+ if precision > precision_thr:
+ pred_ignored_index.append(box_id)
+ break
+
+ return pred_polys, pred_points, pred_ignored_index
+
+
+def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num):
+ """Compute hmean given hit number, ground truth number and prediction
+ number.
+
+ Args:
+ accum_hit_recall (int|float): Accumulated hits for computing recall.
+ accum_hit_prec (int|float): Accumulated hits for computing precision.
+ gt_num (int): Ground truth number.
+ pred_num (int): Prediction number.
+
+ Returns:
+ recall (float): The recall value.
+ precision (float): The precision value.
+ hmean (float): The hmean value.
+ """
+
+ assert isinstance(accum_hit_recall, (float, int))
+ assert isinstance(accum_hit_prec, (float, int))
+
+ assert isinstance(gt_num, int)
+ assert isinstance(pred_num, int)
+ assert accum_hit_recall >= 0.0
+ assert accum_hit_prec >= 0.0
+ assert gt_num >= 0.0
+ assert pred_num >= 0.0
+
+ if gt_num == 0:
+ recall = 1.0
+ precision = 0.0 if pred_num > 0 else 1.0
+ else:
+ recall = float(accum_hit_recall) / gt_num
+ precision = 0.0 if pred_num == 0 else float(accum_hit_prec) / pred_num
+
+ denom = recall + precision
+
+ hmean = 0.0 if denom == 0 else (2.0 * precision * recall / denom)
+
+ return recall, precision, hmean
+
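+# Worked example (comments only): 8 hits against 10 gt boxes and 12 predicted
+# boxes give recall 0.8, precision 8/12 ~= 0.667 and hmean ~= 0.727.
+#
+#   r, p, h = compute_hmean(8, 8, 10, 12)
+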
+
+def box2polygon(box):
+ """Convert box to polygon.
+
+ Args:
+ box (ndarray or list): A ndarray or a list of shape (4)
+ that indicates 2 points.
+
+ Returns:
+ polygon (Polygon): A polygon object.
+ """
+ if isinstance(box, list):
+ box = np.array(box)
+
+ assert isinstance(box, np.ndarray)
+ assert box.size == 4
+ boundary = np.array(
+ [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]])
+
+ point_mat = boundary.reshape([-1, 2])
+ return plg(point_mat)
+
+
+def points2polygon(points):
+ """Convert k points to 1 polygon.
+
+ Args:
+ points (ndarray or list): A ndarray or a list of shape (2k)
+ that indicates k points.
+
+ Returns:
+ polygon (Polygon): A polygon object.
+ """
+ if isinstance(points, list):
+ points = np.array(points)
+
+ assert isinstance(points, np.ndarray)
+ assert (points.size % 2 == 0) and (points.size >= 8)
+
+ point_mat = points.reshape([-1, 2])
+ return plg(point_mat)
+
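+# Quick example (comments only): a unit square given as a flat
+# (x1, y1, ..., x4, y4) sequence becomes a shapely Polygon with area 1.0.
+#
+#   square = points2polygon([0, 0, 1, 0, 1, 1, 0, 1])
+#   # square.area == 1.0
+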
+
+def poly_make_valid(poly):
+ """Convert a potentially invalid polygon to a valid one by eliminating
+ self-crossing or self-touching parts.
+
+ Args:
+ poly (Polygon): A polygon needed to be converted.
+
+ Returns:
+ A valid polygon.
+ """
+ return poly if poly.is_valid else poly.buffer(0)
+
+
+def poly_intersection(poly_det, poly_gt, invalid_ret=None, return_poly=False):
+ """Calculate the intersection area between two polygon.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+ invalid_ret (None|float|int): The return value when the invalid polygon
+ exists. If it is not specified, the function allows the computation
+ to proceed with invalid polygons by cleaning their
+ self-touching or self-crossing parts.
+ return_poly (bool): Whether to return the polygon of the intersection
+ area.
+
+ Returns:
+ intersection_area (float): The intersection area between two polygons.
+ poly_obj (Polygon, optional): The Polygon object of the intersection
+ area. Set as `None` if the input is invalid.
+ """
+ assert isinstance(poly_det, plg)
+ assert isinstance(poly_gt, plg)
+ assert invalid_ret is None or isinstance(invalid_ret, float) or \
+ isinstance(invalid_ret, int)
+
+ if invalid_ret is None:
+ poly_det = poly_make_valid(poly_det)
+ poly_gt = poly_make_valid(poly_gt)
+
+ poly_obj = None
+ area = invalid_ret
+ if poly_det.is_valid and poly_gt.is_valid:
+ poly_obj = poly_det.intersection(poly_gt)
+ area = poly_obj.area
+ return (area, poly_obj) if return_poly else area
+
+
+def poly_union(poly_det, poly_gt, invalid_ret=None, return_poly=False):
+ """Calculate the union area between two polygon.
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+ invalid_ret (None|float|int): The return value when the invalid polygon
+ exists. If it is not specified, the function allows the computation
+ to proceed with invalid polygons by cleaning their
+ self-touching or self-crossing parts.
+ return_poly (bool): Whether to return the polygon of the intersection
+ area.
+
+ Returns:
+ union_area (float): The union area between two polygons.
+ poly_obj (Polygon|MultiPolygon, optional): The Polygon or MultiPolygon
+ object of the union of the inputs. The type of object depends on
+ whether they intersect or not. Set as `None` if the input is
+ invalid.
+ """
+ assert isinstance(poly_det, plg)
+ assert isinstance(poly_gt, plg)
+ assert invalid_ret is None or isinstance(invalid_ret, float) or \
+ isinstance(invalid_ret, int)
+
+ if invalid_ret is None:
+ poly_det = poly_make_valid(poly_det)
+ poly_gt = poly_make_valid(poly_gt)
+
+ poly_obj = None
+ area = invalid_ret
+ if poly_det.is_valid and poly_gt.is_valid:
+ poly_obj = poly_det.union(poly_gt)
+ area = poly_obj.area
+ return (area, poly_obj) if return_poly else area
+
+
+def boundary_iou(src, target, zero_division=0):
+ """Calculate the IOU between two boundaries.
+
+ Args:
+ src (list): Source boundary.
+ target (list): Target boundary.
+ zero_division (int|float): The return value when invalid
+ boundary exists.
+
+ Returns:
+ iou (float): The iou between two boundaries.
+ """
+ assert utils.valid_boundary(src, False)
+ assert utils.valid_boundary(target, False)
+ src_poly = points2polygon(src)
+ target_poly = points2polygon(target)
+
+ return poly_iou(src_poly, target_poly, zero_division=zero_division)
+
+
+def poly_iou(poly_det, poly_gt, zero_division=0):
+ """Calculate the IOU between two polygons.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+ zero_division (int|float): The return value when invalid
+ polygon exists.
+
+ Returns:
+ iou (float): The IOU between two polygons.
+ """
+ assert isinstance(poly_det, plg)
+ assert isinstance(poly_gt, plg)
+ area_inters = poly_intersection(poly_det, poly_gt)
+ area_union = poly_union(poly_det, poly_gt)
+ return area_inters / area_union if area_union != 0 else zero_division
+
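+# Quick example (comments only): two unit squares overlapping by half have an
+# intersection of 0.5 and a union of 1.5, so their IoU is 1/3.
+#
+#   a = points2polygon([0, 0, 1, 0, 1, 1, 0, 1])
+#   b = points2polygon([0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1])
+#   # poly_iou(a, b) == 0.5 / 1.5 ~= 0.333
+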
+
+def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr,
+ precision_thr):
+ """One-to-One match gt and det with icdar2013 standards.
+
+ Args:
+ gt_id (int): The ground truth id index.
+ det_id (int): The detection result id index.
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the recall ratio of gt i to det j.
+ precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the precision ratio of gt i to det j.
+ recall_thr (float): The recall threshold.
+ precision_thr (float): The precision threshold.
+ Returns:
+ True|False: Whether the gt and det are matched.
+ """
+ assert isinstance(gt_id, int)
+ assert isinstance(det_id, int)
+ assert isinstance(recall_mat, np.ndarray)
+ assert isinstance(precision_mat, np.ndarray)
+ assert 0 <= recall_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ cont = 0
+ for i in range(recall_mat.shape[1]):
+ if recall_mat[gt_id,
+ i] > recall_thr and precision_mat[gt_id,
+ i] > precision_thr:
+ cont += 1
+ if cont != 1:
+ return False
+
+ cont = 0
+ for i in range(recall_mat.shape[0]):
+ if recall_mat[i, det_id] > recall_thr and precision_mat[
+ i, det_id] > precision_thr:
+ cont += 1
+ if cont != 1:
+ return False
+
+ if recall_mat[gt_id, det_id] > recall_thr and precision_mat[
+ gt_id, det_id] > precision_thr:
+ return True
+
+ return False
+
+
+def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_match_flag, det_match_flag,
+ det_ignored_index):
+ """One-to-Many match gt and detections with icdar2013 standards.
+
+ Args:
+ gt_id (int): gt index.
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the recall ratio of gt i to det j.
+ precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the precision ratio of gt i to det j.
+ recall_thr (float): The recall threshold.
+ precision_thr (float): The precision threshold.
+ gt_match_flag (list): A flag list indicating whether each gt has
+ been matched already.
+ det_match_flag (list): A flag list indicating whether each det box
+ has been matched already.
+ det_ignored_index (list): Indices of the det boxes that should be
+ ignored.
+
+ Returns:
+ tuple (True|False, list): The first indicates the gt is matched or not;
+ the second is the matched detection ids.
+ """
+ assert isinstance(gt_id, int)
+ assert isinstance(recall_mat, np.ndarray)
+ assert isinstance(precision_mat, np.ndarray)
+ assert 0 <= recall_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ assert isinstance(gt_match_flag, list)
+ assert isinstance(det_match_flag, list)
+ assert isinstance(det_ignored_index, list)
+
+ many_sum = 0.
+ det_ids = []
+ for det_id in range(recall_mat.shape[1]):
+ if gt_match_flag[gt_id] == 0 and det_match_flag[
+ det_id] == 0 and det_id not in det_ignored_index:
+ if precision_mat[gt_id, det_id] >= precision_thr:
+ many_sum += recall_mat[gt_id, det_id]
+ det_ids.append(det_id)
+ if many_sum >= recall_thr:
+ return True, det_ids
+ return False, []
+
+
+def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_match_flag, det_match_flag,
+ gt_ignored_index):
+ """Many-to-One match gt and detections with icdar2013 standards.
+
+ Args:
+ det_id (int): Detection index.
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the recall ratio of gt i to det j.
+ precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the precision ratio of gt i to det j.
+ recall_thr (float): The recall threshold.
+ precision_thr (float): The precision threshold.
+ gt_match_flag (list): A flag list indicating whether each gt has
+ been matched already.
+ det_match_flag (list): A flag list indicating whether each det box
+ has been matched already.
+ gt_ignored_index (list): Indices of the gt boxes that should be
+ ignored.
+
+ Returns:
+ tuple (True|False, list): The first indicates the detection is matched
+ or not; the second is the matched gt ids.
+ """
+ assert isinstance(det_id, int)
+ assert isinstance(recall_mat, np.ndarray)
+ assert isinstance(precision_mat, np.ndarray)
+ assert 0 <= recall_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ assert isinstance(gt_match_flag, list)
+ assert isinstance(det_match_flag, list)
+ assert isinstance(gt_ignored_index, list)
+ many_sum = 0.
+ gt_ids = []
+ for gt_id in range(recall_mat.shape[0]):
+ if gt_match_flag[gt_id] == 0 and det_match_flag[
+ det_id] == 0 and gt_id not in gt_ignored_index:
+ if recall_mat[gt_id, det_id] >= recall_thr:
+ many_sum += precision_mat[gt_id, det_id]
+ gt_ids.append(gt_id)
+ if many_sum >= precision_thr:
+ return True, gt_ids
+ return False, []
+
+
+def points_center(points):
+
+ assert isinstance(points, np.ndarray)
+ assert points.size % 2 == 0
+
+ points = points.reshape([-1, 2])
+ return np.mean(points, axis=0)
+
+
+def point_distance(p1, p2):
+ assert isinstance(p1, np.ndarray)
+ assert isinstance(p2, np.ndarray)
+
+ assert p1.size == 2
+ assert p2.size == 2
+
+ dist = np.square(p2 - p1)
+ dist = np.sum(dist)
+ dist = np.sqrt(dist)
+ return dist
+
+
+def box_center_distance(b1, b2):
+ assert isinstance(b1, np.ndarray)
+ assert isinstance(b2, np.ndarray)
+ return point_distance(points_center(b1), points_center(b2))
+
+
+def box_diag(box):
+ assert isinstance(box, np.ndarray)
+ assert box.size == 8
+
+ return point_distance(box[0:2], box[4:6])
+
+
+def filter_2dlist_result(results, scores, score_thr):
+ """Find out detected results whose score > score_thr.
+
+ Args:
+ results (list[list[float]]): The result list.
+ scores (list): The score list.
+ score_thr (float): The score threshold.
+ Returns:
+ valid_results (list[list[float]]): The valid results.
+ valid_score (list[float]): The scores which correspond to the valid
+ results.
+ """
+ assert isinstance(results, list)
+ assert len(results) == len(scores)
+ assert isinstance(score_thr, float)
+ assert 0 <= score_thr <= 1
+
+ inds = np.array(scores) > score_thr
+ valid_results = [results[idx] for idx in np.where(inds)[0].tolist()]
+ valid_scores = [scores[idx] for idx in np.where(inds)[0].tolist()]
+ return valid_results, valid_scores
+
+
+def filter_result(results, scores, score_thr):
+ """Find out detected results whose score > score_thr.
+
+ Args:
+ results (ndarray): The results matrix of shape (n, k).
+ scores (ndarray): The score vector of shape (n,).
+ score_thr (float): The score threshold.
+ Returns:
+ valid_results (ndarray): The valid results of shape (m,k) with m<=n.
+ valid_score (ndarray): The scores which correspond to the
+ valid results.
+ """
+ assert results.ndim == 2
+ assert scores.shape[0] == results.shape[0]
+ assert isinstance(score_thr, float)
+ assert 0 <= score_thr <= 1
+
+ inds = scores > score_thr
+ valid_results = results[inds, :]
+ valid_scores = scores[inds]
+ return valid_results, valid_scores
+
+
+def select_top_boundary(boundaries_list, scores_list, score_thr):
+ """Select poly boundaries with scores >= score_thr.
+
+ Args:
+ boundaries_list (list[list[list[float]]]): List of boundaries.
+ The 1st, 2nd, and 3rd indices are for image, text and
+ vertex, respectively.
+ scores_list (list(list[float])): List of lists of scores.
+ score_thr (float): The score threshold to filter out bboxes.
+
+ Returns:
+ selected_bboxes (list[list[list[float]]]): List of boundaries.
+ The 1st, 2nd, and 3rd indices are for image, text and vertex,
+ respectively.
+ """
+ assert isinstance(boundaries_list, list)
+ assert isinstance(scores_list, list)
+ assert isinstance(score_thr, float)
+ assert len(boundaries_list) == len(scores_list)
+ assert 0 <= score_thr <= 1
+
+ selected_boundaries = []
+ for boundary, scores in zip(boundaries_list, scores_list):
+ if len(scores) > 0:
+ assert len(scores) == len(boundary)
+ inds = [
+ i for i in range(len(scores)) if scores[i] >= score_thr
+ ]
+ selected_boundaries.append([boundary[i] for i in inds])
+ else:
+ selected_boundaries.append(boundary)
+ return selected_boundaries
+
+
+def select_bboxes_via_score(bboxes_list, scores_list, score_thr):
+ """Select bboxes with scores >= score_thr.
+
+ Args:
+ bboxes_list (list[ndarray]): List of bboxes. Each element is ndarray of
+ shape (n,8)
+ scores_list (list(list[float])): List of lists of scores.
+ score_thr (float): The score threshold to filter out bboxes.
+
+ Returns:
+ selected_bboxes (list[ndarray]): List of bboxes. Each element is
+ ndarray of shape (m,8) with m<=n.
+ """
+ assert isinstance(bboxes_list, list)
+ assert isinstance(scores_list, list)
+ assert isinstance(score_thr, float)
+ assert len(bboxes_list) == len(scores_list)
+ assert 0 <= score_thr <= 1
+
+ selected_bboxes = []
+ for bboxes, scores in zip(bboxes_list, scores_list):
+ if len(scores) > 0:
+ assert len(scores) == bboxes.shape[0]
+ inds = [
+ i for i in range(len(scores)) if scores[i] >= score_thr
+ ]
+ selected_bboxes.append(bboxes[inds, :])
+ else:
+ selected_bboxes.append(bboxes)
+ return selected_bboxes
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/mask.py b/cv/ocr/dbnet/pytorch/dbnet/core/mask.py
new file mode 100755
index 0000000000000000000000000000000000000000..dc5753621a0d286510307d2884079adf3202bf8b
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/mask.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+import dbnet.utils as utils
+
+
+def points2boundary(points, text_repr_type, text_score=None, min_width=-1):
+ """Convert a text mask represented by point coordinates sequence into a
+ text boundary.
+
+ Args:
+ points (ndarray): Mask index of size (n, 2).
+ text_repr_type (str): Text instance encoding type
+ ('quad' for quadrangle or 'poly' for polygon).
+ text_score (float): Text score.
+ min_width (int): Minimum length of the shorter side of the
+ quadrangle. Only used when ``text_repr_type`` is 'quad';
+ narrower boxes are dropped and None is returned.
+
+ Returns:
+ boundary (list[float]): The text boundary point coordinates (x, y)
+ list. Return None if no text boundary found.
+ """
+ assert isinstance(points, np.ndarray)
+ assert points.shape[1] == 2
+ assert text_repr_type in ['quad', 'poly']
+ assert text_score is None or 0 <= text_score <= 1
+
+ if text_repr_type == 'quad':
+ rect = cv2.minAreaRect(points)
+ vertices = cv2.boxPoints(rect)
+ boundary = []
+ if min(rect[1]) > min_width:
+ boundary = [p for p in vertices.flatten().tolist()]
+
+ elif text_repr_type == 'poly':
+
+ height = np.max(points[:, 1]) + 10
+ width = np.max(points[:, 0]) + 10
+
+ mask = np.zeros((height, width), np.uint8)
+ mask[points[:, 1], points[:, 0]] = 255
+
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_SIMPLE)
+ boundary = list(contours[0].flatten().tolist())
+
+ if text_score is not None:
+ boundary = boundary + [text_score]
+ if len(boundary) < 8:
+ return None
+
+ return boundary
+
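+# Quick example (comments only): the (x, y) pixel indices of an 11x6 filled
+# rectangle collapse to a 4-corner quadrangle, with the score appended last.
+#
+#   ys, xs = np.mgrid[0:6, 0:11]
+#   pts = np.stack([xs.ravel(), ys.ravel()], axis=1).astype(np.int32)
+#   points2boundary(pts, 'quad', text_score=0.9)
+#   # -> 8 corner coordinates followed by 0.9
+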
+
+def seg2boundary(seg, text_repr_type, text_score=None):
+ """Convert a segmentation mask to a text boundary.
+
+ Args:
+ seg (ndarray): The segmentation mask.
+ text_repr_type (str): Text instance encoding type
+ ('quad' for quadrangle or 'poly' for polygon).
+ text_score (float): The text score.
+
+ Returns:
+ boundary (list): The text boundary. Return None if no text found.
+ """
+ assert isinstance(seg, np.ndarray)
+ assert isinstance(text_repr_type, str)
+ assert text_score is None or 0 <= text_score <= 1
+
+ points = np.where(seg)
+ # x, y order
+ points = np.concatenate([points[1], points[0]]).reshape(2, -1).transpose()
+ boundary = None
+ if len(points) != 0:
+ boundary = points2boundary(points, text_repr_type, text_score)
+
+ return boundary
+
+
+def extract_boundary(result):
+ """Extract boundaries and their scores from result.
+
+ Args:
+ result (dict): The detection result with the key 'boundary_result'
+ of one image.
+
+ Returns:
+ boundaries_with_scores (list[list[float]]): The boundary and score
+ list.
+ boundaries (list[list[float]]): The boundary list.
+ scores (list[float]): The boundary score list.
+ """
+ assert isinstance(result, dict)
+ assert 'boundary_result' in result.keys()
+
+ boundaries_with_scores = result['boundary_result']
+ assert utils.is_2dlist(boundaries_with_scores)
+
+ boundaries = [b[:-1] for b in boundaries_with_scores]
+ scores = [b[-1] for b in boundaries_with_scores]
+
+ return (boundaries_with_scores, boundaries, scores)
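+
+
+# Quick example (comments only): 'boundary_result' stores each boundary with
+# its score appended, and this helper simply splits the two apart.
+#
+#   res = dict(boundary_result=[[0, 0, 1, 0, 1, 1, 0, 1, 0.9]])
+#   with_scores, boundaries, scores = extract_boundary(res)
+#   # boundaries == [[0, 0, 1, 0, 1, 1, 0, 1]], scores == [0.9]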
diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/visualize.py b/cv/ocr/dbnet/pytorch/dbnet/core/visualize.py
new file mode 100755
index 0000000000000000000000000000000000000000..2934255b099be20121f92b9ac6467be40f158dbc
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/core/visualize.py
@@ -0,0 +1,888 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import os
+import shutil
+import urllib
+import warnings
+
+import cv2
+import dbnet_cv
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from PIL import Image, ImageDraw, ImageFont
+
+import dbnet.utils as utils
+
+
+def overlay_mask_img(img, mask):
+ """Draw mask boundaries on image for visualization.
+
+ Args:
+ img (ndarray): The input image.
+ mask (ndarray): The instance mask.
+
+ Returns:
+ img (ndarray): The output image with instance boundaries on it.
+ """
+ assert isinstance(img, np.ndarray)
+ assert isinstance(mask, np.ndarray)
+
+ contours, _ = cv2.findContours(
+ mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ cv2.drawContours(img, contours, -1, (0, 255, 0), 1)
+
+ return img
+
+
+def show_feature(features, names, to_uint8, out_file=None):
+ """Visualize a list of feature maps.
+
+ Args:
+ features (list(ndarray)): The feature map list.
+ names (list(str)): The visualized title list.
+ to_uint8 (list(1|0)): The list indicating whether to convert
+ feature maps to uint8.
+ out_file (str): The output file name. If set to None,
+ the output image will be shown without saving.
+ """
+ assert utils.is_type_list(features, np.ndarray)
+ assert utils.is_type_list(names, str)
+ assert utils.is_type_list(to_uint8, int)
+ assert utils.is_none_or_type(out_file, str)
+ assert utils.equal_len(features, names, to_uint8)
+
+ num = len(features)
+ row = col = math.ceil(math.sqrt(num))
+
+ for i, (f, n) in enumerate(zip(features, names)):
+ plt.subplot(row, col, i + 1)
+ plt.title(n)
+ if to_uint8[i]:
+ f = f.astype(np.uint8)
+ plt.imshow(f)
+ if out_file is None:
+ plt.show()
+ else:
+ plt.savefig(out_file)
+
+
+def show_img_boundary(img, boundary):
+ """Show image and instance boundaires.
+
+ Args:
+ img (ndarray): The input image.
+ boundary (list[float or int]): The input boundary.
+ """
+ assert isinstance(img, np.ndarray)
+ assert utils.is_type_list(boundary, (int, float))
+
+ cv2.polylines(
+ img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
+ True,
+ color=(0, 255, 0),
+ thickness=1)
+ plt.imshow(img)
+ plt.show()
+
+
+def show_pred_gt(preds,
+ gts,
+ show=False,
+ win_name='',
+ wait_time=0,
+ out_file=None):
+ """Show detection and ground truth for one image.
+
+ Args:
+ preds (list[list[float]]): The detection boundary list.
+ gts (list[list[float]]): The ground truth boundary list.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): The value of waitKey param.
+ out_file (str): The filename of the output.
+ """
+ assert utils.is_2dlist(preds)
+ assert utils.is_2dlist(gts)
+ assert isinstance(show, bool)
+ assert isinstance(win_name, str)
+ assert isinstance(wait_time, int)
+ assert utils.is_none_or_type(out_file, str)
+
+ p_xy = [p for boundary in preds for p in boundary]
+ gt_xy = [g for gt in gts for g in gt]
+
+ max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0)
+
+ width = int(max_xy[0]) + 100
+ height = int(max_xy[1]) + 100
+
+    img = np.ones((height, width, 3), np.uint8) * 255
+ pred_color = dbnet_cv.color_val('red')
+ gt_color = dbnet_cv.color_val('blue')
+ thickness = 1
+
+ for boundary in preds:
+ cv2.polylines(
+ img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
+ True,
+ color=pred_color,
+ thickness=thickness)
+ for gt in gts:
+ cv2.polylines(
+ img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)],
+ True,
+ color=gt_color,
+ thickness=thickness)
+ if show:
+ dbnet_cv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ dbnet_cv.imwrite(img, out_file)
+
+ return img
+
+
+def imshow_pred_boundary(img,
+ boundaries_with_scores,
+ labels,
+ score_thr=0,
+ boundary_color='blue',
+ text_color='blue',
+ thickness=1,
+ font_scale=0.5,
+ show=True,
+ win_name='',
+ wait_time=0,
+ out_file=None,
+ show_score=False):
+ """Draw boundaries and class labels (with scores) on an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ boundaries_with_scores (list[list[float]]): Boundaries with scores.
+ labels (list[int]): Labels of boundaries.
+ score_thr (float): Minimum score of boundaries to be shown.
+ boundary_color (str or tuple or :obj:`Color`): Color of boundaries.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str or None): The filename of the output.
+ show_score (bool): Whether to show text instance score.
+ """
+ assert isinstance(img, (str, np.ndarray))
+ assert utils.is_2dlist(boundaries_with_scores)
+ assert utils.is_type_list(labels, int)
+ assert utils.equal_len(boundaries_with_scores, labels)
+ if len(boundaries_with_scores) == 0:
+        warnings.warn(f'0 text found in {out_file}')
+ return None
+
+ utils.valid_boundary(boundaries_with_scores[0])
+ img = dbnet_cv.imread(img)
+
+ scores = np.array([b[-1] for b in boundaries_with_scores])
+ inds = scores > score_thr
+ boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]]
+ scores = [scores[i] for i in np.where(inds)[0]]
+ labels = [labels[i] for i in np.where(inds)[0]]
+
+ boundary_color = dbnet_cv.color_val(boundary_color)
+ text_color = dbnet_cv.color_val(text_color)
+ font_scale = 0.5
+
+ for boundary, score in zip(boundaries, scores):
+ boundary_int = np.array(boundary).astype(np.int32)
+
+ cv2.polylines(
+ img, [boundary_int.reshape(-1, 1, 2)],
+ True,
+ color=boundary_color,
+ thickness=thickness)
+
+ if show_score:
+ label_text = f'{score:.02f}'
+ cv2.putText(img, label_text,
+ (boundary_int[0], boundary_int[1] - 2),
+ cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+ if show:
+ dbnet_cv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ dbnet_cv.imwrite(img, out_file)
+
+ return img
+
+
+def imshow_text_char_boundary(img,
+ text_quads,
+ boundaries,
+ char_quads,
+ chars,
+ show=False,
+ thickness=1,
+ font_scale=0.5,
+ win_name='',
+ wait_time=-1,
+ out_file=None):
+ """Draw text boxes and char boxes on img.
+
+ Args:
+ img (str or ndarray): The img to be displayed.
+ text_quads (list[list[int|float]]): The text boxes.
+ boundaries (list[list[int|float]]): The boundary list.
+ char_quads (list[list[list[int|float]]]): A 2d list of char boxes.
+ char_quads[i] is for the ith text, and char_quads[i][j] is the jth
+ char of the ith text.
+        chars (list[list[str]]): The string for each text box.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str or None): The filename of the output.
+ """
+ assert isinstance(img, (np.ndarray, str))
+ assert utils.is_2dlist(text_quads)
+ assert utils.is_2dlist(boundaries)
+ assert utils.is_3dlist(char_quads)
+ assert utils.is_2dlist(chars)
+ assert utils.equal_len(text_quads, char_quads, boundaries)
+
+ img = dbnet_cv.imread(img)
+ char_color = [dbnet_cv.color_val('blue'), dbnet_cv.color_val('green')]
+ text_color = dbnet_cv.color_val('red')
+ text_inx = 0
+ for text_box, boundary, char_box, txt in zip(text_quads, boundaries,
+ char_quads, chars):
+ text_box = np.array(text_box)
+ boundary = np.array(boundary)
+
+ text_box = text_box.reshape(-1, 2).astype(np.int32)
+ cv2.polylines(
+ img, [text_box.reshape(-1, 1, 2)],
+ True,
+ color=text_color,
+ thickness=thickness)
+ if boundary.shape[0] > 0:
+ cv2.polylines(
+ img, [boundary.reshape(-1, 1, 2)],
+ True,
+ color=text_color,
+ thickness=thickness)
+
+ for b in char_box:
+ b = np.array(b)
+ c = char_color[text_inx % 2]
+ b = b.astype(np.int32)
+ cv2.polylines(
+ img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness)
+
+ label_text = ''.join(txt)
+ cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2),
+ cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+ text_inx = text_inx + 1
+
+ if show:
+ dbnet_cv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ dbnet_cv.imwrite(img, out_file)
+
+ return img
+
+
+def tile_image(images):
+ """Combined multiple images to one vertically.
+
+ Args:
+ images (list[np.ndarray]): Images to be combined.
+ """
+ assert isinstance(images, list)
+ assert len(images) > 0
+
+ for i, _ in enumerate(images):
+ if len(images[i].shape) == 2:
+ images[i] = cv2.cvtColor(images[i], cv2.COLOR_GRAY2BGR)
+
+ widths = [img.shape[1] for img in images]
+ heights = [img.shape[0] for img in images]
+ h, w = sum(heights), max(widths)
+ vis_img = np.zeros((h, w, 3), dtype=np.uint8)
+
+ offset_y = 0
+ for image in images:
+ img_h, img_w = image.shape[:2]
+ vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image
+ offset_y += img_h
+
+ return vis_img
+
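+# Sketch of the expected behaviour (shapes are made-up examples): images of
+# different widths are stacked top-to-bottom and right-padded with black.
+#   >>> a = np.full((10, 20, 3), 255, dtype=np.uint8)
+#   >>> b = np.full((5, 30, 3), 128, dtype=np.uint8)
+#   >>> tile_image([a, b]).shape
+#   (15, 30, 3)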
+
+def imshow_text_label(img,
+ pred_label,
+ gt_label,
+ show=False,
+ win_name='',
+ wait_time=-1,
+ out_file=None):
+ """Draw predicted texts and ground truth texts on images.
+
+ Args:
+ img (str or np.ndarray): Image filename or loaded image.
+ pred_label (str): Predicted texts.
+ gt_label (str): Ground truth texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str): The filename of the output.
+ """
+ assert isinstance(img, (np.ndarray, str))
+ assert isinstance(pred_label, str)
+ assert isinstance(gt_label, str)
+ assert isinstance(show, bool)
+ assert isinstance(win_name, str)
+ assert isinstance(wait_time, int)
+
+ img = dbnet_cv.imread(img)
+
+ src_h, src_w = img.shape[:2]
+ resize_height = 64
+ resize_width = int(1.0 * src_w / src_h * resize_height)
+ img = cv2.resize(img, (resize_width, resize_height))
+ h, w = img.shape[:2]
+
+ if is_contain_chinese(pred_label):
+ pred_img = draw_texts_by_pil(img, [pred_label], None)
+ else:
+ pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+ cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
+ 0.9, (0, 0, 255), 2)
+ images = [pred_img, img]
+
+ if gt_label != '':
+ if is_contain_chinese(gt_label):
+ gt_img = draw_texts_by_pil(img, [gt_label], None)
+ else:
+ gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+ cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
+ 0.9, (255, 0, 0), 2)
+ images.append(gt_img)
+
+ img = tile_image(images)
+
+ if show:
+ dbnet_cv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ dbnet_cv.imwrite(img, out_file)
+
+ return img
+
+
+def imshow_node(img,
+ result,
+ boxes,
+ idx_to_cls={},
+ show=False,
+ win_name='',
+ wait_time=-1,
+ out_file=None):
+
+ img = dbnet_cv.imread(img)
+ h, w = img.shape[:2]
+
+ max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
+ node_pred_label = max_idx.numpy().tolist()
+ node_pred_score = max_value.numpy().tolist()
+
+ texts, text_boxes = [], []
+ for i, box in enumerate(boxes):
+ new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
+ [box[0], box[3]]]
+ Pts = np.array([new_box], np.int32)
+ cv2.polylines(
+ img, [Pts.reshape((-1, 1, 2))],
+ True,
+ color=(255, 255, 0),
+ thickness=1)
+ x_min = int(min([point[0] for point in new_box]))
+ y_min = int(min([point[1] for point in new_box]))
+
+ # text
+ pred_label = str(node_pred_label[i])
+ if pred_label in idx_to_cls:
+ pred_label = idx_to_cls[pred_label]
+ pred_score = '{:.2f}'.format(node_pred_score[i])
+ text = pred_label + '(' + pred_score + ')'
+ texts.append(text)
+
+ # text box
+ font_size = int(
+ min(
+ abs(new_box[3][1] - new_box[0][1]),
+ abs(new_box[1][0] - new_box[0][0])))
+ char_num = len(text)
+ text_box = [
+ x_min * 2, y_min, x_min * 2 + font_size * char_num, y_min,
+ x_min * 2 + font_size * char_num, y_min + font_size, x_min * 2,
+ y_min + font_size
+ ]
+ text_boxes.append(text_box)
+
+ pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
+ pred_img = draw_texts_by_pil(
+ pred_img, texts, text_boxes, draw_box=False, on_ori_img=True)
+
+ vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
+ vis_img[:, :w] = img
+ vis_img[:, w:] = pred_img
+
+ if show:
+ dbnet_cv.imshow(vis_img, win_name, wait_time)
+ if out_file is not None:
+ dbnet_cv.imwrite(vis_img, out_file)
+
+ return vis_img
+
+
+def gen_color():
+ """Generate BGR color schemes."""
+ color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249),
+ (123, 151, 138), (187, 200, 178), (148, 137, 69),
+ (169, 200, 200), (155, 175, 131), (154, 194, 182),
+ (178, 190, 137), (140, 211, 222), (83, 156, 222)]
+ return color_list
+
+
+def draw_polygons(img, polys):
+ """Draw polygons on image.
+
+ Args:
+ img (np.ndarray): The original image.
+ polys (list[list[float]]): Detected polygons.
+ Return:
+ out_img (np.ndarray): Visualized image.
+ """
+ dst_img = img.copy()
+ color_list = gen_color()
+ out_img = dst_img
+ for idx, poly in enumerate(polys):
+ poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
+ cv2.drawContours(
+ img,
+ np.array([poly]),
+ -1,
+ color_list[idx % len(color_list)],
+ thickness=cv2.FILLED)
+ out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0)
+ return out_img
+
+
+def get_optimal_font_scale(text, width):
+ """Get optimal font scale for cv2.putText.
+
+ Args:
+ text (str): Text in one box.
+ width (int): The box width.
+ """
+ for scale in reversed(range(0, 60, 1)):
+ textSize = cv2.getTextSize(
+ text,
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+ fontScale=scale / 10,
+ thickness=1)
+ new_width = textSize[0][0]
+ if new_width <= width:
+ return scale / 10
+ return 1
+
+
+def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False):
+ """Draw boxes and texts on empty img.
+
+ Args:
+ img (np.ndarray): The original image.
+ texts (list[str]): Recognized texts.
+ boxes (list[list[float]]): Detected bounding boxes.
+        draw_box (bool): Whether to draw boxes. If False, draw text only.
+ on_ori_img (bool): If True, draw box and text on input image,
+ else, on a new empty image.
+ Return:
+ out_img (np.ndarray): Visualized image.
+ """
+ color_list = gen_color()
+ h, w = img.shape[:2]
+ if boxes is None:
+ boxes = [[0, 0, w, 0, w, h, 0, h]]
+ assert len(texts) == len(boxes)
+
+ if on_ori_img:
+ out_img = img
+ else:
+ out_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+ for idx, (box, text) in enumerate(zip(boxes, texts)):
+ if draw_box:
+ new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])]
+ Pts = np.array([new_box], np.int32)
+ cv2.polylines(
+ out_img, [Pts.reshape((-1, 1, 2))],
+ True,
+ color=color_list[idx % len(color_list)],
+ thickness=1)
+ min_x = int(min(box[0::2]))
+ max_y = int(
+ np.mean(np.array(box[1::2])) + 0.2 *
+ (max(box[1::2]) - min(box[1::2])))
+ font_scale = get_optimal_font_scale(
+ text, int(max(box[0::2]) - min(box[0::2])))
+ cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
+ font_scale, (0, 0, 0), 1)
+
+ return out_img
+
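+# Minimal usage sketch (the box and text below are hypothetical, not taken
+# from the repository):
+#   >>> canvas = np.full((60, 200, 3), 255, dtype=np.uint8)
+#   >>> vis = draw_texts(canvas, ['hello'], [[10, 10, 120, 10, 120, 40, 10, 40]])
+# With the default on_ori_img=False the text is rendered on a fresh white
+# canvas of the same size rather than on the input image.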
+
+def draw_texts_by_pil(img,
+ texts,
+ boxes=None,
+ draw_box=True,
+ on_ori_img=False,
+ font_size=None,
+ fill_color=None,
+ draw_pos=None,
+ return_text_size=False):
+ """Draw boxes and texts on empty image, especially for Chinese.
+
+ Args:
+ img (np.ndarray): The original image.
+ texts (list[str]): Recognized texts.
+ boxes (list[list[float]]): Detected bounding boxes.
+        draw_box (bool): Whether to draw boxes. If False, draw text only.
+ on_ori_img (bool): If True, draw box and text on input image,
+ else on a new empty image.
+        font_size (int, optional): Font size used to create the PIL font object.
+ fill_color (tuple(int), optional): Fill color for text.
+ draw_pos (list[tuple(int)], optional): Start point to draw each text.
+ return_text_size (bool): If True, return the list of text size.
+
+ Returns:
+ (np.ndarray, list[tuple]) or np.ndarray: Return a tuple
+ ``(out_img, text_sizes)``, where ``out_img`` is the output image
+ with texts drawn on it and ``text_sizes`` are the size of drawing
+ texts. If ``return_text_size`` is False, only the output image will be
+ returned.
+ """
+
+ color_list = gen_color()
+ h, w = img.shape[:2]
+ if boxes is None:
+ boxes = [[0, 0, w, 0, w, h, 0, h]]
+ if draw_pos is None:
+ draw_pos = [None for _ in texts]
+ assert len(boxes) == len(texts) == len(draw_pos)
+
+ if fill_color is None:
+ fill_color = (0, 0, 0)
+
+ if on_ori_img:
+ out_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ else:
+ out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
+ out_draw = ImageDraw.Draw(out_img)
+
+ text_sizes = []
+ for idx, (box, text, ori_point) in enumerate(zip(boxes, texts, draw_pos)):
+ if len(text) == 0:
+ continue
+ min_x, max_x = min(box[0::2]), max(box[0::2])
+ min_y, max_y = min(box[1::2]), max(box[1::2])
+ color = tuple(list(color_list[idx % len(color_list)])[::-1])
+ if draw_box:
+ out_draw.line(box, fill=color, width=1)
+ dirname, _ = os.path.split(os.path.abspath(__file__))
+ font_path = os.path.join(dirname, 'font.TTF')
+ if not os.path.exists(font_path):
+ url = ('https://download.openmmlab.com/mmocr/data/font.TTF')
+ print(f'Downloading {url} ...')
+ local_filename, _ = urllib.request.urlretrieve(url)
+ shutil.move(local_filename, font_path)
+ tmp_font_size = font_size
+ if tmp_font_size is None:
+ box_width = max(max_x - min_x, max_y - min_y)
+ tmp_font_size = int(0.9 * box_width / len(text))
+ fnt = ImageFont.truetype(font_path, tmp_font_size)
+ if ori_point is None:
+ ori_point = (min_x + 1, min_y + 1)
+ out_draw.text(ori_point, text, font=fnt, fill=fill_color)
+ text_sizes.append(fnt.getsize(text))
+
+ del out_draw
+
+ out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)
+
+ if return_text_size:
+ return out_img, text_sizes
+
+ return out_img
+
+
+def is_contain_chinese(check_str):
+ """Check whether string contains Chinese or not.
+
+ Args:
+ check_str (str): String to be checked.
+
+ Return True if contains Chinese, else False.
+ """
+ for ch in check_str:
+ if u'\u4e00' <= ch <= u'\u9fff':
+ return True
+ return False
+
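+# Example (doctest-style, values chosen for illustration):
+#   >>> is_contain_chinese('hello')
+#   False
+#   >>> is_contain_chinese('你好 world')
+#   True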
+
+def det_recog_show_result(img, end2end_res, out_file=None):
+ """Draw `result`(boxes and texts) on `img`.
+
+ Args:
+ img (str or np.ndarray): The image to be displayed.
+ end2end_res (dict): Text detect and recognize results.
+ out_file (str): Image path where the visualized image should be saved.
+ Return:
+ out_img (np.ndarray): Visualized image.
+ """
+ img = dbnet_cv.imread(img)
+ boxes, texts = [], []
+ for res in end2end_res['result']:
+ boxes.append(res['box'])
+ texts.append(res['text'])
+ box_vis_img = draw_polygons(img, boxes)
+
+ if is_contain_chinese(''.join(texts)):
+ text_vis_img = draw_texts_by_pil(img, texts, boxes)
+ else:
+ text_vis_img = draw_texts(img, texts, boxes)
+
+ h, w = img.shape[:2]
+ out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
+ out_img[:, :w, :] = box_vis_img
+ out_img[:, w:, :] = text_vis_img
+
+ if out_file:
+ dbnet_cv.imwrite(out_img, out_file)
+
+ return out_img
+
+
+def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5):
+ """Draw text and their relationship on empty images.
+
+ Args:
+ img (np.ndarray): The original image.
+ result (dict): The result of model forward_test, including:
+ - img_metas (list[dict]): List of meta information dictionary.
+ - nodes (Tensor): Node prediction with size:
+ number_node * node_classes.
+ - edges (Tensor): Edge prediction with size: number_edge * 2.
+ edge_thresh (float): Score threshold for edge classification.
+ keynode_thresh (float): Score threshold for node
+ (``key``) classification.
+
+ Returns:
+ np.ndarray: The image with key, value and relation drawn on it.
+ """
+
+ h, w = img.shape[:2]
+
+ vis_area_width = w // 3 * 2
+ vis_area_height = h
+ dist_key_to_value = vis_area_width // 2
+ dist_pair_to_pair = 30
+
+ bbox_x1 = dist_pair_to_pair
+ bbox_y1 = 0
+
+ new_w = vis_area_width
+ new_h = vis_area_height
+ pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255
+
+ nodes = result['nodes'].detach().cpu()
+ texts = result['img_metas'][0]['ori_texts']
+ num_nodes = result['nodes'].size(0)
+ edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes)
+
+ # (i, j) will be a valid pair
+ # either edge_score(node_i->node_j) > edge_thresh
+ # or edge_score(node_j->node_i) > edge_thresh
+ pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True)
+ pairs = (pairs[0].numpy().tolist(), pairs[1].numpy().tolist())
+
+ # 1. "for n1, n2 in zip(*pairs) if n1 < n2":
+ # Only (n1, n2) will be included if n1 < n2 but not (n2, n1), to
+ # avoid duplication.
+ # 2. "(n1, n2) if nodes[n1, 1] > nodes[n1, 2]":
+ # nodes[n1, 1] is the score that this node is predicted as key,
+ # nodes[n1, 2] is the score that this node is predicted as value.
+ # If nodes[n1, 1] > nodes[n1, 2], n1 will be the index of key,
+ # so that n2 will be the index of value.
+ result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1)
+ for n1, n2 in zip(*pairs) if n1 < n2]
+
+ result_pairs.sort()
+ result_pairs_score = [
+ torch.max(edges[n1, n2], edges[n2, n1]) for n1, n2 in result_pairs
+ ]
+
+ key_current_idx = -1
+ pos_current = (-1, -1)
+ newline_flag = False
+
+ key_font_size = 15
+ value_font_size = 15
+ key_font_color = (0, 0, 0)
+ value_font_color = (0, 0, 255)
+ arrow_color = (0, 0, 255)
+ score_color = (0, 255, 0)
+ for pair, pair_score in zip(result_pairs, result_pairs_score):
+ key_idx = pair[0]
+ if nodes[key_idx, 1] < keynode_thresh:
+ continue
+ if key_idx != key_current_idx:
+ # move y-coords down for a new key
+ bbox_y1 += 10
+ # enlarge blank area to show key-value info
+ if newline_flag:
+ bbox_x1 += vis_area_width
+ tmp_img = np.ones(
+ (new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255
+ tmp_img[:new_h, :new_w] = pred_edge_img
+ pred_edge_img = tmp_img
+ new_w += vis_area_width
+ newline_flag = False
+ bbox_y1 = 10
+ key_text = texts[key_idx]
+ key_pos = (bbox_x1, bbox_y1)
+ value_idx = pair[1]
+ value_text = texts[value_idx]
+ value_pos = (bbox_x1 + dist_key_to_value, bbox_y1)
+ if key_idx != key_current_idx:
+ # draw text for a new key
+ key_current_idx = key_idx
+ pred_edge_img, text_sizes = draw_texts_by_pil(
+ pred_edge_img, [key_text],
+ draw_box=False,
+ on_ori_img=True,
+ font_size=key_font_size,
+ fill_color=key_font_color,
+ draw_pos=[key_pos],
+ return_text_size=True)
+ pos_right_bottom = (key_pos[0] + text_sizes[0][0],
+ key_pos[1] + text_sizes[0][1])
+ pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10)
+ pred_edge_img = cv2.arrowedLine(
+ pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10),
+ (bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color,
+ 1)
+ score_pos_x = int(
+ (pos_right_bottom[0] + bbox_x1 + dist_key_to_value) / 2.)
+ score_pos_y = bbox_y1 + 10 - int(key_font_size * 0.3)
+ else:
+ # draw arrow from key to value
+ if newline_flag:
+ tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3),
+ dtype=np.uint8) * 255
+ tmp_img[:new_h, :new_w] = pred_edge_img
+ pred_edge_img = tmp_img
+ new_h += dist_pair_to_pair
+ pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current,
+ (bbox_x1 + dist_key_to_value - 5,
+ bbox_y1 + 10), arrow_color, 1)
+ score_pos_x = int(
+ (pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.)
+ score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.)
+ # draw edge score
+ cv2.putText(pred_edge_img, '{:.2f}'.format(pair_score),
+ (score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4,
+ score_color)
+ # draw text for value
+ pred_edge_img = draw_texts_by_pil(
+ pred_edge_img, [value_text],
+ draw_box=False,
+ on_ori_img=True,
+ font_size=value_font_size,
+ fill_color=value_font_color,
+ draw_pos=[value_pos],
+ return_text_size=False)
+ bbox_y1 += dist_pair_to_pair
+ if bbox_y1 + dist_pair_to_pair >= new_h:
+ newline_flag = True
+
+ return pred_edge_img
+
+
+def imshow_edge(img,
+ result,
+ boxes,
+ show=False,
+ win_name='',
+ wait_time=-1,
+ out_file=None):
+ """Display the prediction results of the nodes and edges of the KIE model.
+
+ Args:
+ img (np.ndarray): The original image.
+ result (dict): The result of model forward_test, including:
+ - img_metas (list[dict]): List of meta information dictionary.
+ - nodes (Tensor): Node prediction with size: \
+ number_node * node_classes.
+ - edges (Tensor): Edge prediction with size: number_edge * 2.
+ boxes (list): The text boxes corresponding to the nodes.
+ show (bool): Whether to show the image. Default: False.
+ win_name (str): The window name. Default: ''
+        wait_time (int): Value of waitKey param. Default: -1.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+
+ Returns:
+ np.ndarray: The image with key, value and relation drawn on it.
+ """
+ img = dbnet_cv.imread(img)
+ h, w = img.shape[:2]
+ color_list = gen_color()
+
+ for i, box in enumerate(boxes):
+ new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
+ [box[0], box[3]]]
+ Pts = np.array([new_box], np.int32)
+ cv2.polylines(
+ img, [Pts.reshape((-1, 1, 2))],
+ True,
+ color=color_list[i % len(color_list)],
+ thickness=1)
+
+ pred_img_h = h
+ pred_img_w = w
+
+ pred_edge_img = draw_edge_result(img, result)
+ pred_img_h = max(pred_img_h, pred_edge_img.shape[0])
+ pred_img_w += pred_edge_img.shape[1]
+
+ vis_img = np.zeros((pred_img_h, pred_img_w, 3), dtype=np.uint8)
+ vis_img[:h, :w] = img
+ vis_img[:, w:] = 255
+
+ height_t, width_t = pred_edge_img.shape[:2]
+ vis_img[:height_t, w:(w + width_t)] = pred_edge_img
+
+ if show:
+ dbnet_cv.imshow(vis_img, win_name, wait_time)
+ if out_file is not None:
+ dbnet_cv.imwrite(vis_img, out_file)
+ res_dic = {
+ 'boxes': boxes,
+ 'nodes': result['nodes'].detach().cpu(),
+ 'edges': result['edges'].detach().cpu(),
+ 'metas': result['img_metas'][0]
+ }
+ dbnet_cv.dump(res_dic, f'{out_file}_res.pkl')
+
+ return vis_img
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..53e86ff549fab879e5adf004188e9a847cbaca3c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dbnet_det.datasets.builder import DATASETS, build_dataloader, build_dataset
+
+from . import utils
+from .base_dataset import BaseDataset
+from .icdar_dataset import IcdarDataset
+# from .kie_dataset import KIEDataset
+# from .ner_dataset import NerDataset
+# from .ocr_dataset import OCRDataset
+# from .ocr_seg_dataset import OCRSegDataset
+# from .openset_kie_dataset import OpensetKIEDataset
+from .pipelines import CustomFormatBundle, DBNetTargets
+# from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
+from .text_det_dataset import TextDetDataset
+from .uniform_concat_dataset import UniformConcatDataset
+from .utils import * # NOQA
+
+# __all__ = [
+# 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
+# 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
+# 'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
+# 'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset'
+# ]
+
+# __all__ += utils.__all__
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/base_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/base_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..879454dad2d3fd4bcc190edd70cf6060db7abff9
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/base_dataset.py
@@ -0,0 +1,169 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from dbnet_cv.utils import print_log
+from dbnet_det.datasets.builder import DATASETS
+from dbnet_det.datasets.pipelines import Compose
+from torch.utils.data import Dataset
+
+from dbnet.datasets.builder import build_loader
+
+
+@DATASETS.register_module()
+class BaseDataset(Dataset):
+ """Custom dataset for text detection, text recognition, and their
+ downstream tasks.
+
+    1. The text detection annotation format is as follows.
+       The `annotations` field is optional for testing.
+       (The example below is one line of the annotation file,
+       with its JSON string rendered as a dict for readability.)
+
+ .. code-block:: json
+
+ {
+ "file_name": "sample.jpg",
+ "height": 1080,
+ "width": 960,
+ "annotations":
+ [
+ {
+ "iscrowd": 0,
+ "category_id": 1,
+ "bbox": [357.0, 667.0, 804.0, 100.0],
+ "segmentation": [[361, 667, 710, 670,
+ 72, 767, 357, 763]]
+ }
+ ]
+ }
+
+ 2. The two text recognition annotation formats are as follows:
+ The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
+ augmentation during training.
+
+ format1: sample.jpg hello
+ format2: sample.jpg 20 20 100 20 100 40 20 40 hello
+
+ Args:
+ ann_file (str): Annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ loader (dict): Dictionary to construct loader
+ to load annotation infos.
+ img_prefix (str, optional): Image prefix to generate full
+ image path.
+ test_mode (bool, optional): If set True, try...except will
+ be turned off in __getitem__.
+ """
+
+ def __init__(self,
+ ann_file,
+ loader,
+ pipeline,
+ img_prefix='',
+ test_mode=False):
+ super().__init__()
+ self.test_mode = test_mode
+ self.img_prefix = img_prefix
+ self.ann_file = ann_file
+ # load annotations
+ loader.update(ann_file=ann_file)
+ self.data_infos = build_loader(loader)
+ # processing pipeline
+ self.pipeline = Compose(pipeline)
+ # set group flag and class, no meaning
+ # for text detect and recognize
+ self._set_group_flag()
+ self.CLASSES = 0
+
+ def __len__(self):
+ return len(self.data_infos)
+
+ def _set_group_flag(self):
+ """Set flag."""
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['img_prefix'] = self.img_prefix
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+ img_info = self.data_infos[index]
+ results = dict(img_info=img_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, img_info):
+ """Get testing data from pipeline.
+
+        Args:
+            img_info (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by
+ pipeline.
+ """
+ return self.prepare_train_img(img_info)
+
+ def _log_error_index(self, index):
+ """Logging data info of bad index."""
+ try:
+ data_info = self.data_infos[index]
+ img_prefix = self.img_prefix
+ print_log(f'Warning: skip broken file {data_info} '
+ f'with img_prefix {img_prefix}')
+ except Exception as e:
+ print_log(f'load index {index} with error {e}')
+
+ def _get_next_index(self, index):
+ """Get next index from dataset."""
+ self._log_error_index(index)
+ index = (index + 1) % len(self)
+ return index
+
+ def __getitem__(self, index):
+ """Get training/test data from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training/test data.
+ """
+ if self.test_mode:
+ return self.prepare_test_img(index)
+
+ while True:
+ try:
+ data = self.prepare_train_img(index)
+ if data is None:
+ raise Exception('prepared train data empty')
+ break
+ except Exception as e:
+ print_log(f'prepare index {index} with error {e}')
+ index = self._get_next_index(index)
+ return data
+
+ def format_results(self, results, **kwargs):
+ """Placeholder to format result to dataset-specific output."""
+ pass
+
+ def evaluate(self, results, metric=None, logger=None, **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ Returns:
+ dict[str: float]
+ """
+ raise NotImplementedError
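+# Config sketch (the loader/parser type names below are hypothetical
+# placeholders for whatever is registered in the LOADERS/PARSERS registries;
+# only the overall structure is implied by this class):
+#   dataset = dict(
+#       type='BaseDataset',
+#       ann_file='data/train_label.txt',
+#       img_prefix='data/imgs',
+#       loader=dict(type='SomeAnnFileLoader',
+#                   parser=dict(type='SomeLineStrParser')),
+#       pipeline=[...],
+#       test_mode=False)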
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/builder.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/builder.py
new file mode 100755
index 0000000000000000000000000000000000000000..a4dcfbc766e617b771eb64185ff747294d4b9aed
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/builder.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dbnet_cv.utils import Registry, build_from_cfg
+
+LOADERS = Registry('loader')
+PARSERS = Registry('parser')
+
+
+def build_loader(cfg):
+ """Build anno file loader."""
+ return build_from_cfg(cfg, LOADERS)
+
+
+def build_parser(cfg):
+ """Build anno file parser."""
+ return build_from_cfg(cfg, PARSERS)
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/icdar_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/icdar_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..f18cece80bc1a400c23009b65fda3652f3a0b14f
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/icdar_dataset.py
@@ -0,0 +1,191 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dbnet_cv
+import numpy as np
+from dbnet_det.datasets.api_wrappers import COCO
+from dbnet_det.datasets.builder import DATASETS
+from dbnet_det.datasets.coco import CocoDataset
+
+import dbnet.utils as utils
+from dbnet import digit_version
+from dbnet.core.evaluation.hmean import eval_hmean
+
+
+@DATASETS.register_module()
+class IcdarDataset(CocoDataset):
+ """Dataset for text detection while ann_file in coco format.
+
+ Args:
+ ann_file_backend (str): Storage backend for annotation file,
+            should be one of ['disk', 'petrel', 'http']. Defaults to 'disk'.
+ """
+    CLASSES = ('text', )
+
+ def __init__(self,
+ ann_file,
+ pipeline,
+ classes=None,
+ data_root=None,
+ img_prefix='',
+ seg_prefix=None,
+ proposal_file=None,
+ test_mode=False,
+ filter_empty_gt=True,
+ select_first_k=-1,
+ ann_file_backend='disk'):
+ # select first k images for fast debugging.
+ self.select_first_k = select_first_k
+ assert ann_file_backend in ['disk', 'petrel', 'http']
+ self.ann_file_backend = ann_file_backend
+
+ super().__init__(ann_file, pipeline, classes, data_root, img_prefix,
+ seg_prefix, proposal_file, test_mode, filter_empty_gt)
+
+ # Set dummy flags just to be compatible with dbnet_det
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+ if self.ann_file_backend == 'disk':
+ self.coco = COCO(ann_file)
+ else:
+ dbnet_cv_version = digit_version(dbnet_cv.__version__)
+ if dbnet_cv_version < digit_version('1.3.16'):
+ raise Exception('Please update dbnet_cv to 1.3.16 or higher '
+ 'to enable "get_local_path" of "FileClient".')
+ file_client = dbnet_cv.FileClient(backend=self.ann_file_backend)
+ with file_client.get_local_path(ann_file) as local_path:
+ self.coco = COCO(local_path)
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+
+ count = 0
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ count = count + 1
+ if count > self.select_first_k and self.select_first_k > 0:
+ break
+ return data_infos
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+        Args:
+            img_info (dict): Info dict of the image.
+            ann_info (list[dict]): Annotation info of an image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,
+ labels, masks, masks_ignore, seg_map. "masks" and
+ "masks_ignore" are represented by polygon boundary
+ point sequences.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ignore = []
+ gt_masks_ann = []
+
+ for ann in ann_info:
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+                gt_masks_ignore.append(ann.get(
+                    'segmentation', None))  # to float32 for later processing
+
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ gt_masks_ann.append(ann.get('segmentation', None))
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ seg_map = img_info['filename'].replace('jpg', 'png')
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks_ignore=gt_masks_ignore,
+ masks=gt_masks_ann,
+ seg_map=seg_map)
+
+ return ann
+
+ def evaluate(self,
+ results,
+ metric='hmean-iou',
+ logger=None,
+ score_thr=None,
+ min_score_thr=0.3,
+ max_score_thr=0.9,
+ step=0.1,
+ rank_list=None,
+ **kwargs):
+ """Evaluate the hmean metric.
+
+ Args:
+ results (list[dict]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ score_thr (float): Deprecated. Please use min_score_thr instead.
+ min_score_thr (float): Minimum score threshold of prediction map.
+ max_score_thr (float): Maximum score threshold of prediction map.
+ step (float): The spacing between score thresholds.
+ rank_list (str): json file used to save eval result
+ of each image after ranking.
+ Returns:
+ dict[dict[str: float]]: The evaluation results.
+ """
+ assert utils.is_type_list(results, dict)
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['hmean-iou', 'hmean-ic13']
+ metrics = set(metrics) & set(allowed_metrics)
+
+ img_infos = []
+ ann_infos = []
+ for i in range(len(self)):
+ img_info = {'filename': self.data_infos[i]['file_name']}
+ img_infos.append(img_info)
+ ann_infos.append(self.get_ann_info(i))
+
+ eval_results = eval_hmean(
+ results,
+ img_infos,
+ ann_infos,
+ metrics=metrics,
+ score_thr=score_thr,
+ min_score_thr=min_score_thr,
+ max_score_thr=max_score_thr,
+ step=step,
+ logger=logger,
+ rank_list=rank_list)
+
+ return eval_results
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..ac3606371b7b032e08d9d52f334ffec5bea7e749
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .box_utils import sort_vertex, sort_vertex8
+from .custom_format_bundle import CustomFormatBundle
+from .dbnet_transforms import EastRandomCrop, ImgAug
+# from .kie_transforms import KIEFormatBundle, ResizeNoImg
+from .loading import (LoadImageFromLMDB, LoadImageFromNdarray,
+ LoadTextAnnotations)
+# from .ner_transforms import NerTransform, ToTensorNER
+# from .ocr_seg_targets import OCRSegTargets
+# from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
+# OpencvToPil, PilToOpencv, RandomPaddingOCR,
+# RandomRotateImageBox, ResizeOCR, ToTensorOCR)
+# from .test_time_aug import MultiRotateAugOCR
+from .textdet_targets import DBNetTargets
+# from .textdet_targets import (DBNetTargets
+# # FCENetTargets, PANetTargets,
+# # TextSnakeTargets
+# )
+from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
+from .transforms import (ColorJitter, PyramidRescale, RandomCropFlip,
+ RandomCropInstances, RandomCropPolyInstances,
+ RandomRotatePolyInstances, RandomRotateTextDet,
+ RandomScaling, ScaleAspectJitter, SquareResizePad)
+
+# __all__ = [
+# 'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
+# 'ToTensorOCR', 'CustomFormatBundle', 'DBNetTargets', 'PANetTargets',
+# 'ColorJitter', 'RandomCropInstances', 'RandomRotateTextDet',
+# 'ScaleAspectJitter', 'MultiRotateAugOCR', 'OCRSegTargets', 'FancyPCA',
+# 'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
+# 'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
+# 'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
+# 'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets',
+# 'RandomScaling', 'RandomCropFlip', 'NerTransform', 'ToTensorNER',
+# 'ResizeNoImg', 'PyramidRescale', 'OneOfWrapper', 'RandomWrapper',
+# 'TorchVisionWrapper', 'LoadImageFromLMDB'
+# ]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/box_utils.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/box_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..9e887e3e87334acf25315cecc5744b88d809d586
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/box_utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import dbnet.utils as utils
+
+
+def sort_vertex(points_x, points_y):
+ """Sort box vertices in clockwise order from left-top first.
+
+ Args:
+ points_x (list[float]): x of four vertices.
+ points_y (list[float]): y of four vertices.
+ Returns:
+ sorted_points_x (list[float]): x of sorted four vertices.
+ sorted_points_y (list[float]): y of sorted four vertices.
+ """
+ assert utils.is_type_list(points_x, (float, int))
+ assert utils.is_type_list(points_y, (float, int))
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+ vertices = np.stack((points_x, points_y), axis=-1).astype(np.float32)
+ vertices = _sort_vertex(vertices)
+ sorted_points_x = list(vertices[:, 0])
+ sorted_points_y = list(vertices[:, 1])
+ return sorted_points_x, sorted_points_y
+
+
+def _sort_vertex(vertices):
+ assert vertices.ndim == 2
+ assert vertices.shape[-1] == 2
+ N = vertices.shape[0]
+ if N == 0:
+ return vertices
+
+ center = np.mean(vertices, axis=0)
+ directions = vertices - center
+ angles = np.arctan2(directions[:, 1], directions[:, 0])
+ sort_idx = np.argsort(angles)
+ vertices = vertices[sort_idx]
+
+ left_top = np.min(vertices, axis=0)
+ dists = np.linalg.norm(left_top - vertices, axis=-1, ord=2)
+ lefttop_idx = np.argmin(dists)
+    indexes = (np.arange(N, dtype=np.int64) + lefttop_idx) % N
+ return vertices[indexes]
+
+
+def sort_vertex8(points):
+ """Sort vertex with 8 points [x1 y1 x2 y2 x3 y3 x4 y4]"""
+ assert len(points) == 8
+ vertices = _sort_vertex(np.array(points, dtype=np.float32).reshape(-1, 2))
+ sorted_box = list(vertices.flatten())
+ return sorted_box
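+# Illustrative example (made-up coordinates): vertices supplied in an
+# arbitrary order come back clockwise, starting from the left-top corner.
+#   >>> sort_vertex8([10, 0, 10, 10, 0, 10, 0, 0])
+# returns (approximately) [0, 0, 10, 0, 10, 10, 0, 10] as float32 values.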
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/crop.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/crop.py
new file mode 100755
index 0000000000000000000000000000000000000000..153365f55bd789578ff3b026930e566c7b5db620
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/crop.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+from shapely.geometry import LineString, Point
+
+import dbnet.utils as utils
+from .box_utils import sort_vertex
+
+
+def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
+ """Jitter on the coordinates of bounding box.
+
+ Args:
+        points_x (list[float | int]): List of x for four vertices.
+        points_y (list[float | int]): List of y for four vertices.
+ jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
+ jitter_ratio_y (float): Vertical jitter ratio relative to the height.
+ """
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+ assert isinstance(jitter_ratio_x, float)
+ assert isinstance(jitter_ratio_y, float)
+ assert 0 <= jitter_ratio_x < 1
+ assert 0 <= jitter_ratio_y < 1
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+ line_list = [
+ LineString([points[i], points[i + 1 if i < 3 else 0]])
+ for i in range(4)
+ ]
+
+ tmp_h = max(line_list[1].length, line_list[3].length)
+
+ for i in range(4):
+ jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
+ jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
+ points_x[i] += jitter_pixel_x
+ points_y[i] += jitter_pixel_y
+
+
+def warp_img(src_img,
+ box,
+ jitter_flag=False,
+ jitter_ratio_x=0.5,
+ jitter_ratio_y=0.1):
+ """Crop box area from image using opencv warpPerspective w/o box jitter.
+
+ Args:
+ src_img (np.array): Image before cropping.
+ box (list[float | int]): Coordinates of quadrangle.
+ """
+ assert utils.is_type_list(box, (float, int))
+ assert len(box) == 8
+
+ h, w = src_img.shape[:2]
+ points_x = [min(max(x, 0), w) for x in box[0:8:2]]
+ points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+
+ points_x, points_y = sort_vertex(points_x, points_y)
+
+ if jitter_flag:
+ box_jitter(
+ points_x,
+ points_y,
+ jitter_ratio_x=jitter_ratio_x,
+ jitter_ratio_y=jitter_ratio_y)
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+ edges = [
+ LineString([points[i], points[i + 1 if i < 3 else 0]])
+ for i in range(4)
+ ]
+
+ pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)])
+ box_width = max(edges[0].length, edges[2].length)
+ box_height = max(edges[1].length, edges[3].length)
+
+ pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height],
+ [0, box_height]])
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+ dst_img = cv2.warpPerspective(src_img, M,
+ (int(box_width), int(box_height)))
+
+ return dst_img
+
+
+def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
+ """Crop text region with their bounding box.
+
+ Args:
+ src_img (np.array): The original image.
+ box (list[float | int]): Points of quadrangle.
+ long_edge_pad_ratio (float): Box pad ratio for long edge
+ corresponding to font size.
+ short_edge_pad_ratio (float): Box pad ratio for short edge
+ corresponding to font size.
+ """
+ assert utils.is_type_list(box, (float, int))
+ assert len(box) == 8
+ assert 0. <= long_edge_pad_ratio < 1.0
+ assert 0. <= short_edge_pad_ratio < 1.0
+
+ h, w = src_img.shape[:2]
+ points_x = np.clip(np.array(box[0::2]), 0, w)
+ points_y = np.clip(np.array(box[1::2]), 0, h)
+
+ box_width = np.max(points_x) - np.min(points_x)
+ box_height = np.max(points_y) - np.min(points_y)
+ font_size = min(box_height, box_width)
+
+ if box_height < box_width:
+ horizontal_pad = long_edge_pad_ratio * font_size
+ vertical_pad = short_edge_pad_ratio * font_size
+ else:
+ horizontal_pad = short_edge_pad_ratio * font_size
+ vertical_pad = long_edge_pad_ratio * font_size
+
+ left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
+ top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
+ right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
+ bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)
+
+ dst_img = src_img[top:bottom, left:right]
+
+ return dst_img
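+# Worked example with assumed numbers: for a horizontal box of width 100 and
+# height 20, font_size = min(100, 20) = 20, so the long (horizontal) edges
+# are padded by 0.4 * 20 = 8 px and the short (vertical) edges by
+# 0.2 * 20 = 4 px before the axis-aligned crop, clipped to the image borders.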
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/custom_format_bundle.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/custom_format_bundle.py
new file mode 100755
index 0000000000000000000000000000000000000000..e64938b4129332767e077bc0ecf495543b2375b9
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/custom_format_bundle.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from dbnet_cv.parallel import DataContainer as DC
+from dbnet_det.datasets.builder import PIPELINES
+from dbnet_det.datasets.pipelines.formatting import DefaultFormatBundle
+
+from dbnet.core.visualize import overlay_mask_img, show_feature
+
+
+@PIPELINES.register_module()
+class CustomFormatBundle(DefaultFormatBundle):
+ """Custom formatting bundle.
+
+ It formats common fields such as 'img' and 'proposals' as done in
+ DefaultFormatBundle, while other fields such as 'gt_kernels' and
+ 'gt_effective_region_mask' will be formatted to DC as follows:
+
+ - gt_kernels: to DataContainer (cpu_only=True)
+ - gt_effective_mask: to DataContainer (cpu_only=True)
+
+ Args:
+ keys (list[str]): Fields to be formatted to DC only.
+ call_super (bool): If True, format common fields
+ by DefaultFormatBundle, else format fields in keys above only.
+ visualize (dict): If flag=True, visualize gt mask for debugging.
+ """
+
+ def __init__(self,
+ keys=[],
+ call_super=True,
+ visualize=dict(flag=False, boundary_key=None)):
+
+ super().__init__()
+ self.visualize = visualize
+ self.keys = keys
+ self.call_super = call_super
+
+ def __call__(self, results):
+
+ if self.visualize['flag']:
+ img = results['img'].astype(np.uint8)
+ boundary_key = self.visualize['boundary_key']
+ if boundary_key is not None:
+ img = overlay_mask_img(img, results[boundary_key].masks[0])
+
+ features = [img]
+ names = ['img']
+ to_uint8 = [1]
+
+ for k in results['mask_fields']:
+                for idx in range(len(results[k].masks)):
+                    features.append(results[k].masks[idx])
+                    names.append(k + str(idx))
+ to_uint8.append(0)
+ show_feature(features, names, to_uint8)
+
+ if self.call_super:
+ results = super().__call__(results)
+
+ for k in self.keys:
+ results[k] = DC(results[k], cpu_only=True)
+
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/dbnet_transforms.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/dbnet_transforms.py
new file mode 100755
index 0000000000000000000000000000000000000000..006fc12aca59d0b241cda6291fad3a02333e344f
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/dbnet_transforms.py
@@ -0,0 +1,325 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import imgaug
+import imgaug.augmenters as iaa
+import dbnet_cv
+import numpy as np
+from dbnet_det.core.mask import PolygonMasks
+from dbnet_det.datasets.builder import PIPELINES
+
+
+class AugmenterBuilder:
+ """Build imgaug object according ImgAug argmentations."""
+
+ def __init__(self):
+ pass
+
+ def build(self, args, root=True):
+ if args is None:
+ return None
+ if isinstance(args, (int, float, str)):
+ return args
+ if isinstance(args, list):
+ if root:
+ sequence = [self.build(value, root=False) for value in args]
+ return iaa.Sequential(sequence)
+ arg_list = [self.to_tuple_if_list(a) for a in args[1:]]
+ return getattr(iaa, args[0])(*arg_list)
+ if isinstance(args, dict):
+ if 'cls' in args:
+ cls = getattr(iaa, args['cls'])
+ return cls(
+ **{
+ k: self.to_tuple_if_list(v)
+ for k, v in args.items() if not k == 'cls'
+ })
+ else:
+ return {
+ key: self.build(value, root=False)
+ for key, value in args.items()
+ }
+ raise RuntimeError('unknown augmenter arg: ' + str(args))
+
+ def to_tuple_if_list(self, obj):
+ if isinstance(obj, list):
+ return tuple(obj)
+ return obj
+
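+# Usage sketch (the argument format mirrors the ImgAug docstring below; the
+# concrete values are only an example):
+#   augmenter = AugmenterBuilder().build(
+#       [['Fliplr', 0.5],
+#        dict(cls='Affine', rotate=[-10, 10]),
+#        ['Resize', [0.5, 3.0]]])
+# which yields iaa.Sequential([iaa.Fliplr(0.5),
+#                              iaa.Affine(rotate=(-10, 10)),
+#                              iaa.Resize((0.5, 3.0))]).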
+
+@PIPELINES.register_module()
+class ImgAug:
+ """A wrapper to use imgaug https://github.com/aleju/imgaug.
+
+ Args:
+        args ([list[list|dict]]): The augmentation list. For details, please
+ refer to imgaug document. Take args=[['Fliplr', 0.5],
+ dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]] as an
+ example. The args horizontally flip images with probability 0.5,
+ followed by random rotation with angles in range [-10, 10], and
+ resize with an independent scale in range [0.5, 3.0] for each
+ side of images.
+ clip_invalid_polys (bool): Whether to clip invalid polygons after
+            transformation. False preserves the original DBNet behavior.
+ """
+
+ def __init__(self, args=None, clip_invalid_ploys=True):
+ self.augmenter_args = args
+ self.augmenter = AugmenterBuilder().build(self.augmenter_args)
+ self.clip_invalid_polys = clip_invalid_ploys
+
+ def __call__(self, results):
+ # img is bgr
+ image = results['img']
+ aug = None
+ shape = image.shape
+
+ if self.augmenter:
+ aug = self.augmenter.to_deterministic()
+ results['img'] = aug.augment_image(image)
+ results['img_shape'] = results['img'].shape
+ results['flip'] = 'unknown' # it's unknown
+ results['flip_direction'] = 'unknown' # it's unknown
+ target_shape = results['img_shape']
+
+ self.may_augment_annotation(aug, shape, target_shape, results)
+
+ return results
+
+ def may_augment_annotation(self, aug, shape, target_shape, results):
+ if aug is None:
+ return results
+
+ # augment polygon mask
+ for key in results['mask_fields']:
+ if self.clip_invalid_polys:
+ masks = self.may_augment_poly(aug, shape, results[key])
+ results[key] = PolygonMasks(masks, *target_shape[:2])
+ else:
+ masks = self.may_augment_poly_legacy(aug, shape, results[key])
+ if len(masks) > 0:
+ results[key] = PolygonMasks(masks, *target_shape[:2])
+
+ # augment bbox
+ for key in results['bbox_fields']:
+ bboxes = self.may_augment_bbox(aug, shape, results[key])
+ results[key] = np.zeros(0)
+ if len(bboxes) > 0:
+ results[key] = np.stack(bboxes)
+
+ return results
+
+ def may_augment_bbox(self, aug, ori_shape, bboxes):
+ imgaug_bboxes = []
+ for bbox in bboxes:
+ x1, y1, x2, y2 = bbox
+ imgaug_bboxes.append(
+ imgaug.BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2))
+ imgaug_bboxes = aug.augment_bounding_boxes([
+ imgaug.BoundingBoxesOnImage(imgaug_bboxes, shape=ori_shape)
+ ])[0].clip_out_of_image()
+
+ new_bboxes = []
+ for box in imgaug_bboxes.bounding_boxes:
+ new_bboxes.append(
+ np.array([box.x1, box.y1, box.x2, box.y2], dtype=np.float32))
+
+ return new_bboxes
+
+ def may_augment_poly(self, aug, img_shape, polys):
+ imgaug_polys = []
+ for poly in polys:
+ poly = poly[0]
+ poly = poly.reshape(-1, 2)
+ imgaug_polys.append(imgaug.Polygon(poly))
+ imgaug_polys = aug.augment_polygons(
+ [imgaug.PolygonsOnImage(imgaug_polys,
+ shape=img_shape)])[0].clip_out_of_image()
+
+ new_polys = []
+ for poly in imgaug_polys.polygons:
+ new_poly = []
+ for point in poly:
+ new_poly.append(np.array(point, dtype=np.float32))
+ new_poly = np.array(new_poly, dtype=np.float32).flatten()
+ new_polys.append([new_poly])
+
+ return new_polys
+
+ def may_augment_poly_legacy(self, aug, img_shape, polys):
+ key_points, poly_point_nums = [], []
+ for poly in polys:
+ poly = poly[0]
+ poly = poly.reshape(-1, 2)
+ key_points.extend([imgaug.Keypoint(p[0], p[1]) for p in poly])
+ poly_point_nums.append(poly.shape[0])
+        # Warning: we do not clip the out-of-boundary polygons
+ key_points = aug.augment_keypoints(
+ [imgaug.KeypointsOnImage(keypoints=key_points,
+ shape=img_shape)])[0].keypoints
+
+ new_polys = []
+ start_idx = 0
+ for poly_point_num in poly_point_nums:
+ new_poly = []
+ for key_point in key_points[start_idx:(start_idx +
+ poly_point_num)]:
+ new_poly.append([key_point.x, key_point.y])
+ start_idx += poly_point_num
+ new_poly = np.array(new_poly).flatten()
+ new_polys.append([new_poly])
+
+ return new_polys
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class EastRandomCrop:
+
+ def __init__(self,
+ target_size=(640, 640),
+ max_tries=10,
+ min_crop_side_ratio=0.1):
+ self.target_size = target_size
+ self.max_tries = max_tries
+ self.min_crop_side_ratio = min_crop_side_ratio
+
+ def __call__(self, results):
+ # sampling crop
+ # crop image, boxes, masks
+ img = results['img']
+ crop_x, crop_y, crop_w, crop_h = self.crop_area(
+ img, results['gt_masks'])
+ scale_w = self.target_size[0] / crop_w
+ scale_h = self.target_size[1] / crop_h
+ scale = min(scale_w, scale_h)
+ h = int(crop_h * scale)
+ w = int(crop_w * scale)
+ padded_img = np.zeros(
+ (self.target_size[1], self.target_size[0], img.shape[2]),
+ img.dtype)
+ padded_img[:h, :w] = dbnet_cv.imresize(
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+
+ # for bboxes
+ for key in results['bbox_fields']:
+ lines = []
+ for box in results[key]:
+ box = box.reshape(2, 2)
+ poly = ((box - (crop_x, crop_y)) * scale)
+ if not self.is_poly_outside_rect(poly, 0, 0, w, h):
+ lines.append(poly.flatten())
+ results[key] = np.array(lines)
+ # for masks
+ for key in results['mask_fields']:
+ polys = []
+ polys_label = []
+ for poly in results[key]:
+ poly = np.array(poly).reshape(-1, 2)
+ poly = ((poly - (crop_x, crop_y)) * scale)
+ if not self.is_poly_outside_rect(poly, 0, 0, w, h):
+ polys.append([poly])
+ polys_label.append(0)
+ results[key] = PolygonMasks(polys, *self.target_size)
+ if key == 'gt_masks':
+ results['gt_labels'] = polys_label
+
+ results['img'] = padded_img
+ results['img_shape'] = padded_img.shape
+
+ return results
+
+ def is_poly_in_rect(self, poly, x, y, w, h):
+ poly = np.array(poly)
+ if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
+ return False
+ if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
+ return False
+ return True
+
+ def is_poly_outside_rect(self, poly, x, y, w, h):
+ poly = np.array(poly).reshape(-1, 2)
+ if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
+ return True
+ if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
+ return True
+ return False
+
+ def split_regions(self, axis):
+ regions = []
+ min_axis = 0
+ for i in range(1, axis.shape[0]):
+ if axis[i] != axis[i - 1] + 1:
+ region = axis[min_axis:i]
+ min_axis = i
+ regions.append(region)
+ return regions
+
+ def random_select(self, axis, max_size):
+ xx = np.random.choice(axis, size=2)
+ xmin = np.min(xx)
+ xmax = np.max(xx)
+ xmin = np.clip(xmin, 0, max_size - 1)
+ xmax = np.clip(xmax, 0, max_size - 1)
+ return xmin, xmax
+
+ def region_wise_random_select(self, regions):
+ selected_index = list(np.random.choice(len(regions), 2))
+ selected_values = []
+ for index in selected_index:
+ axis = regions[index]
+ xx = int(np.random.choice(axis, size=1))
+ selected_values.append(xx)
+ xmin = min(selected_values)
+ xmax = max(selected_values)
+ return xmin, xmax
+
+ def crop_area(self, img, polys):
+ h, w, _ = img.shape
+ h_array = np.zeros(h, dtype=np.int32)
+ w_array = np.zeros(w, dtype=np.int32)
+ for points in polys:
+ points = np.round(
+ points, decimals=0).astype(np.int32).reshape(-1, 2)
+ min_x = np.min(points[:, 0])
+ max_x = np.max(points[:, 0])
+ w_array[min_x:max_x] = 1
+ min_y = np.min(points[:, 1])
+ max_y = np.max(points[:, 1])
+ h_array[min_y:max_y] = 1
+        # ensure the cropped area does not cross a text instance
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return 0, 0, w, h
+
+ h_regions = self.split_regions(h_axis)
+ w_regions = self.split_regions(w_axis)
+
+ for i in range(self.max_tries):
+ if len(w_regions) > 1:
+ xmin, xmax = self.region_wise_random_select(w_regions)
+ else:
+ xmin, xmax = self.random_select(w_axis, w)
+ if len(h_regions) > 1:
+ ymin, ymax = self.region_wise_random_select(h_regions)
+ else:
+ ymin, ymax = self.random_select(h_axis, h)
+
+ if (xmax - xmin < self.min_crop_side_ratio * w
+ or ymax - ymin < self.min_crop_side_ratio * h):
+ # area too small
+ continue
+ num_poly_in_rect = 0
+ for poly in polys:
+ if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
+ ymax - ymin):
+ num_poly_in_rect += 1
+ break
+
+ if num_poly_in_rect > 0:
+ return xmin, ymin, xmax - xmin, ymax - ymin
+
+ return 0, 0, w, h
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/loading.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/loading.py
new file mode 100755
index 0000000000000000000000000000000000000000..d22aa0d8988b63e89636e773a2e399278fb93b65
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/loading.py
@@ -0,0 +1,189 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import lmdb
+import dbnet_cv
+import numpy as np
+from dbnet_det.core import BitmapMasks, PolygonMasks
+from dbnet_det.datasets.builder import PIPELINES
+from dbnet_det.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile
+
+
+@PIPELINES.register_module()
+class LoadTextAnnotations(LoadAnnotations):
+ """Load annotations for text detection.
+
+ Args:
+ with_bbox (bool): Whether to parse and load the bbox annotation.
+ Default: True.
+ with_label (bool): Whether to parse and load the label annotation.
+ Default: True.
+ with_mask (bool): Whether to parse and load the mask annotation.
+ Default: False.
+ with_seg (bool): Whether to parse and load the semantic segmentation
+ annotation. Default: False.
+ poly2mask (bool): Whether to convert the instance masks from polygons
+ to bitmaps. Default: True.
+ use_img_shape (bool): Use the shape of loaded image from
+ previous pipeline ``LoadImageFromFile`` to generate mask.
+ """
+
+ def __init__(self,
+ with_bbox=True,
+ with_label=True,
+ with_mask=False,
+ with_seg=False,
+ poly2mask=True,
+ use_img_shape=False):
+ super().__init__(
+ with_bbox=with_bbox,
+ with_label=with_label,
+ with_mask=with_mask,
+ with_seg=with_seg,
+ poly2mask=poly2mask)
+
+ self.use_img_shape = use_img_shape
+
+ def process_polygons(self, polygons):
+ """Convert polygons to list of ndarray and filter invalid polygons.
+
+ Args:
+ polygons (list[list]): Polygons of one instance.
+
+ Returns:
+ list[numpy.ndarray]: Processed polygons.
+ """
+
+ polygons = [np.array(p).astype(np.float32) for p in polygons]
+ valid_polygons = []
+ for polygon in polygons:
+ if len(polygon) % 2 == 0 and len(polygon) >= 6:
+ valid_polygons.append(polygon)
+ return valid_polygons
+
+ def _load_masks(self, results):
+ ann_info = results['ann_info']
+ h, w = results['img_info']['height'], results['img_info']['width']
+ if self.use_img_shape:
+ if results.get('ori_shape', None):
+ h, w = results['ori_shape'][:2]
+ results['img_info']['height'] = h
+ results['img_info']['width'] = w
+ else:
+ warnings.warn('"ori_shape" not in results, use the shape '
+ 'in "img_info" instead.')
+ gt_masks = ann_info['masks']
+ if self.poly2mask:
+ gt_masks = BitmapMasks(
+ [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
+ else:
+ gt_masks = PolygonMasks(
+ [self.process_polygons(polygons) for polygons in gt_masks], h,
+ w)
+ gt_masks_ignore = ann_info.get('masks_ignore', None)
+ if gt_masks_ignore is not None:
+ if self.poly2mask:
+ gt_masks_ignore = BitmapMasks(
+ [self._poly2mask(mask, h, w) for mask in gt_masks_ignore],
+ h, w)
+ else:
+ gt_masks_ignore = PolygonMasks([
+ self.process_polygons(polygons)
+ for polygons in gt_masks_ignore
+ ], h, w)
+ results['gt_masks_ignore'] = gt_masks_ignore
+ results['mask_fields'].append('gt_masks_ignore')
+
+ results['gt_masks'] = gt_masks
+ results['mask_fields'].append('gt_masks')
+ return results
+
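+ # Illustrative usage sketch (not part of the original file): in a dataset
+ # pipeline config, LoadTextAnnotations is typically declared roughly as
+ # below; the concrete flag values are examples, and poly2mask=False keeps
+ # the polygon boundaries as PolygonMasks instead of bitmaps.
+ #
+ #   dict(type='LoadTextAnnotations',
+ #        with_bbox=True, with_label=True,
+ #        with_mask=True, poly2mask=False)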
+
+@PIPELINES.register_module()
+class LoadImageFromNdarray(LoadImageFromFile):
+ """Load an image from np.ndarray.
+
+ Similar to :obj:`LoadImageFromFile`, but the image is read from
+ ``results['img']``, which is an np.ndarray.
+ """
+
+ def __call__(self, results):
+ """Call functions to add image meta information.
+
+ Args:
+ results (dict): Result dict with Webcam read image in
+ ``results['img']``.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+ assert results['img'].dtype == 'uint8'
+
+ img = results['img']
+ if self.color_type == 'grayscale' and img.shape[2] == 3:
+ img = dbnet_cv.bgr2gray(img, keepdim=True)
+ if self.color_type == 'color' and img.shape[2] == 1:
+ img = dbnet_cv.gray2bgr(img)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = None
+ results['ori_filename'] = None
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+
+@PIPELINES.register_module()
+class LoadImageFromLMDB(object):
+ """Load an image from lmdb file.
+
+ Similar to :obj:`LoadImageFromFile`, but the image is read from
+ ``results['img_info']['filename']``, which is a data key of the lmdb file.
+ """
+
+ def __init__(self, color_type='color'):
+ self.color_type = color_type
+ self.env = None
+ self.txn = None
+
+ def __call__(self, results):
+ img_key = results['img_info']['filename']
+ lmdb_path = results['img_prefix']
+
+ # lmdb env
+ if self.env is None:
+ self.env = lmdb.open(
+ lmdb_path,
+ max_readers=1,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ # read image
+ with self.env.begin(write=False) as txn:
+ imgbuf = txn.get(img_key.encode('utf-8'))
+ try:
+ img = dbnet_cv.imfrombytes(imgbuf, flag=self.color_type)
+ except IOError:
+ print('Corrupted image for {}'.format(img_key))
+ return None
+
+ results['filename'] = img_key
+ results['ori_filename'] = img_key
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+ def __repr__(self):
+ return '{} (color_type={})'.format(self.__class__.__name__,
+ self.color_type)
+
+ def __del__(self):
+ if self.env is not None:
+ self.env.close()
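+
+# Illustrative usage sketch (not part of the original file): for an
+# lmdb-backed dataset, the image-loading step of a pipeline could be declared
+# roughly as below; the surrounding pipeline entries are hypothetical.
+#
+#   train_pipeline = [
+#       dict(type='LoadImageFromLMDB', color_type='color'),
+#       # ... resize / normalize / collect steps ...
+#   ]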
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..ddbf10ed539f44db4f12f707c92dd4a8d2eec5f3
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_textdet_targets import BaseTextDetTargets
+from .dbnet_targets import DBNetTargets
+# from .drrg_targets import DRRGTargets
+# from .fcenet_targets import FCENetTargets
+# from .panet_targets import PANetTargets
+# from .psenet_targets import PSENetTargets
+# from .textsnake_targets import TextSnakeTargets
+
+# __all__ = [
+# 'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
+# 'FCENetTargets', 'TextSnakeTargets', 'DRRGTargets'
+# ]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/base_textdet_targets.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/base_textdet_targets.py
new file mode 100755
index 0000000000000000000000000000000000000000..a034ce77a9e5f058876ddec9cb99cb9bcbacd658
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/base_textdet_targets.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+
+import cv2
+import numpy as np
+import pyclipper
+from dbnet_cv.utils import print_log
+from shapely.geometry import Polygon as plg
+
+import dbnet.utils.check_argument as check_argument
+
+
+class BaseTextDetTargets:
+ """Generate text detector ground truths."""
+
+ def __init__(self):
+ pass
+
+ def point2line(self, xs, ys, point_1, point_2):
+ """Compute the distance from point to a line. This is adapted from
+ https://github.com/MhLiao/DB.
+
+ Args:
+ xs (ndarray): The x coordinates of size hxw.
+ ys (ndarray): The y coordinates of size hxw.
+ point_1 (ndarray): The first point with shape 1x2.
+ point_2 (ndarray): The second point with shape 1x2.
+
+ Returns:
+ result (ndarray): The distance matrix of size hxw.
+ """
+ # suppose a triangle with three edge abc with c=point_1 point_2
+ # a^2
+ a_square = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
+ # b^2
+ b_square = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
+ # c^2
+ c_square = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] -
+ point_2[1])
+ # -cosC=(c^2-a^2-b^2)/2(ab)
+ neg_cos_c = (
+ (c_square - a_square - b_square) /
+ (np.finfo(np.float32).eps + 2 * np.sqrt(a_square * b_square)))
+ # sinC^2=1-cosC^2
+ square_sin = 1 - np.square(neg_cos_c)
+ square_sin = np.nan_to_num(square_sin)
+ # distance=a*b*sinC/c=a*h/c=2*area/c
+ result = np.sqrt(a_square * b_square * square_sin /
+ (np.finfo(np.float32).eps + c_square))
+ # set result to minimum edge if C < pi/2
+ result[neg_cos_c < 0] = np.sqrt(np.fmin(a_square, b_square))[neg_cos_c < 0]
+ return result
+
+ if len(padded_polygon) > 0:
+ padded_polygon = np.array(padded_polygon[0])
+ else:
+ print(f'padding {polygon} with {distance} gets {padded_polygon}')
+ padded_polygon = polygon.copy().astype(np.int32)
+
+ x_min = padded_polygon[:, 0].min()
+ x_max = padded_polygon[:, 0].max()
+ y_min = padded_polygon[:, 1].min()
+ y_max = padded_polygon[:, 1].max()
+
+ width = x_max - x_min + 1
+ height = y_max - y_min + 1
+
+ polygon[:, 0] = polygon[:, 0] - x_min
+ polygon[:, 1] = polygon[:, 1] - y_min
+
+ xs = np.broadcast_to(
+ np.linspace(0, width - 1, num=width).reshape(1, width),
+ (height, width))
+ ys = np.broadcast_to(
+ np.linspace(0, height - 1, num=height).reshape(height, 1),
+ (height, width))
+
+ distance_map = np.zeros((polygon.shape[0], height, width),
+ dtype=np.float32)
+ for i in range(polygon.shape[0]):
+ j = (i + 1) % polygon.shape[0]
+ absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j])
+ distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
+ distance_map = distance_map.min(axis=0)
+
+ x_min_valid = min(max(0, x_min), canvas.shape[1] - 1)
+ x_max_valid = min(max(0, x_max), canvas.shape[1] - 1)
+ y_min_valid = min(max(0, y_min), canvas.shape[0] - 1)
+ y_max_valid = min(max(0, y_max), canvas.shape[0] - 1)
+
+ if x_min_valid - x_min >= width or y_min_valid - y_min >= height:
+ return
+
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+ canvas[y_min_valid:y_max_valid + 1,
+ x_min_valid:x_max_valid + 1] = np.fmax(
+ 1 - distance_map[y_min_valid - y_min:y_max_valid - y_max +
+ height, x_min_valid - x_min:x_max_valid -
+ x_max + width],
+ canvas[y_min_valid:y_max_valid + 1,
+ x_min_valid:x_max_valid + 1])
+
+ def generate_targets(self, results):
+ """Generate the gt targets for DBNet.
+
+ Args:
+ results (dict): The input result dictionary.
+
+ Returns:
+ results (dict): The output result dictionary.
+ """
+ assert isinstance(results, dict)
+
+ if 'bbox_fields' in results:
+ results['bbox_fields'].clear()
+
+ ignore_tags = self.find_invalid(results)
+ results, ignore_tags = self.ignore_texts(results, ignore_tags)
+
+ h, w, _ = results['img_shape']
+ polygons = results['gt_masks'].masks
+
+ # generate gt_shrink_kernel
+ gt_shrink, ignore_tags = self.generate_kernels((h, w),
+ polygons,
+ self.shrink_ratio,
+ ignore_tags=ignore_tags)
+
+ results, ignore_tags = self.ignore_texts(results, ignore_tags)
+ # genenrate gt_shrink_mask
+ polygons_ignore = results['gt_masks_ignore'].masks
+ gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore)
+
+ # generate gt_threshold and gt_threshold_mask
+ polygons = results['gt_masks'].masks
+ gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons)
+
+ results['mask_fields'].clear() # rm gt_masks encoded by polygons
+ results.pop('gt_labels', None)
+ results.pop('gt_masks', None)
+ results.pop('gt_bboxes', None)
+ results.pop('gt_bboxes_ignore', None)
+
+ mapping = {
+ 'gt_shrink': gt_shrink,
+ 'gt_shrink_mask': gt_shrink_mask,
+ 'gt_thr': gt_thr,
+ 'gt_thr_mask': gt_thr_mask
+ }
+ for key, value in mapping.items():
+ value = value if isinstance(value, list) else [value]
+ results[key] = BitmapMasks(value, h, w)
+ results['mask_fields'].append(key)
+
+ return results
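+
+# Illustrative note (not part of the original file): after generate_targets
+# runs, ``results`` carries four BitmapMasks entries consumed by the DBNet
+# losses:
+#
+#   results['gt_shrink'], results['gt_shrink_mask']
+#   results['gt_thr'], results['gt_thr_mask']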
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transform_wrappers.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transform_wrappers.py
new file mode 100755
index 0000000000000000000000000000000000000000..dc5af00e2d22af3ae6824772571f7b009b51ba37
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transform_wrappers.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import random
+
+import dbnet_cv
+import numpy as np
+import torchvision.transforms as torchvision_transforms
+from dbnet_cv.utils import build_from_cfg
+from dbnet_det.datasets.builder import PIPELINES
+from dbnet_det.datasets.pipelines import Compose
+from PIL import Image
+
+
+@PIPELINES.register_module()
+class OneOfWrapper:
+ """Randomly select and apply one of the transforms, each with the equal
+ chance.
+
+ Warning:
+ Different from albumentations, this wrapper only runs the selected
+ transform, but doesn't guarantee the transform can always be applied to
+ the input if the transform comes with a probability to run.
+
+ Args:
+ transforms (list[dict|callable]): Candidate transforms to be applied.
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, list) or isinstance(transforms, tuple)
+ assert len(transforms) > 0, 'Need at least one transform.'
+ self.transforms = []
+ for t in transforms:
+ if isinstance(t, dict):
+ self.transforms.append(build_from_cfg(t, PIPELINES))
+ elif callable(t):
+ self.transforms.append(t)
+ else:
+ raise TypeError('transform must be callable or a dict')
+
+ def __call__(self, results):
+ return random.choice(self.transforms)(results)
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms})'
+ return repr_str
+
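+# Illustrative usage sketch (not part of the original file): OneOfWrapper
+# picks one candidate transform per sample; the two inner transforms below
+# are registered elsewhere in this repository and serve only as placeholders.
+#
+#   dict(type='OneOfWrapper',
+#        transforms=[dict(type='ColorJitter', brightness=0.3),
+#                    dict(type='AffineJitter', degrees=4)])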
+
+@PIPELINES.register_module()
+class RandomWrapper:
+ """Run a transform or a sequence of transforms with probability p.
+
+ Args:
+ transforms (list[dict|callable]): Transform(s) to be applied.
+ p (int|float): Probability of running transform(s).
+ """
+
+ def __init__(self, transforms, p):
+ assert 0 <= p <= 1
+ self.transforms = Compose(transforms)
+ self.p = p
+
+ def __call__(self, results):
+ return results if np.random.uniform() > self.p else self.transforms(
+ results)
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'p={self.p})'
+ return repr_str
+
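+# Illustrative usage sketch (not part of the original file): RandomWrapper
+# runs its sub-pipeline with probability ``p``; the wrapped transform is a
+# placeholder.
+#
+#   dict(type='RandomWrapper', p=0.25,
+#        transforms=[dict(type='ColorJitter', contrast=0.2)])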
+
+@PIPELINES.register_module()
+class TorchVisionWrapper:
+ """A wrapper of torchvision trasnforms. It applies specific transform to
+ ``img`` and updates ``img_shape`` accordingly.
+
+ Warning:
+ This transform only affects the image but not its associated
+ annotations, such as word bounding boxes and polygon masks. Therefore,
+ it may only be applicable to text recognition tasks.
+
+ Args:
+ op (str): The name of any transform class in
+ :func:`torchvision.transforms`.
+ **kwargs: Arguments that will be passed to initializer of torchvision
+ transform.
+
+ :Required Keys:
+ - | ``img`` (ndarray): The input image.
+
+ :Affected Keys:
+ :Modified:
+ - | ``img`` (ndarray): The modified image.
+ :Added:
+ - | ``img_shape`` (tuple(int)): Size of the modified image.
+ """
+
+ def __init__(self, op, **kwargs):
+ assert type(op) is str
+
+ if dbnet_cv.is_str(op):
+ obj_cls = getattr(torchvision_transforms, op)
+ elif inspect.isclass(op):
+ obj_cls = op
+ else:
+ raise TypeError(
+ f'op must be a str or a valid type, but got {type(op)}')
+ self.transform = obj_cls(**kwargs)
+ self.kwargs = kwargs
+
+ def __call__(self, results):
+ assert 'img' in results
+ # BGR -> RGB
+ img = results['img'][..., ::-1]
+ img = Image.fromarray(img)
+ img = self.transform(img)
+ img = np.asarray(img)
+ img = img[..., ::-1]
+ results['img'] = img
+ results['img_shape'] = img.shape
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transform={self.transform})'
+ return repr_str
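+
+# Illustrative usage sketch (not part of the original file): TorchVisionWrapper
+# forwards ``op`` and ``**kwargs`` to ``torchvision.transforms``, so, for
+# example, a Gaussian blur (kernel size chosen arbitrarily here) could be
+# configured as
+#
+#   dict(type='TorchVisionWrapper', op='GaussianBlur', kernel_size=5)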
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transforms.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transforms.py
new file mode 100755
index 0000000000000000000000000000000000000000..e4a7a08ff976356fc76154349d530aaad3bfeb85
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transforms.py
@@ -0,0 +1,1020 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import cv2
+import dbnet_cv
+import numpy as np
+import torchvision.transforms as transforms
+from dbnet_det.core import BitmapMasks, PolygonMasks
+from dbnet_det.datasets.builder import PIPELINES
+from dbnet_det.datasets.pipelines.transforms import Resize
+from PIL import Image
+from shapely.geometry import Polygon as plg
+
+import dbnet.core.evaluation.utils as eval_utils
+from dbnet.utils import check_argument
+
+
+@PIPELINES.register_module()
+class RandomCropInstances:
+ """Randomly crop images and make sure to contain text instances.
+
+ Args:
+ target_size (tuple or int): (height, width)
+ positive_sample_ratio (float): The probability of sampling regions
+ that go through positive regions.
+ """
+
+ def __init__(
+ self,
+ target_size,
+ instance_key,
+ mask_type='inx0', # 'inx0' or 'union_all'
+ positive_sample_ratio=5.0 / 8.0):
+
+ assert mask_type in ['inx0', 'union_all']
+
+ self.mask_type = mask_type
+ self.instance_key = instance_key
+ self.positive_sample_ratio = positive_sample_ratio
+ self.target_size = target_size if (target_size is None or isinstance(
+ target_size, tuple)) else (target_size, target_size)
+
+ def sample_offset(self, img_gt, img_size):
+ h, w = img_size
+ t_h, t_w = self.target_size
+
+ # in case the target size is bigger than the original size
+ t_h = t_h if t_h < h else h
+ t_w = t_w if t_w < w else w
+ if (img_gt is not None
+ and np.random.random_sample() < self.positive_sample_ratio
+ and np.max(img_gt) > 0):
+
+ # make sure to crop the positive region
+
+ # the minimum top left to crop positive region (h,w)
+ tl = np.min(np.where(img_gt > 0), axis=1) - (t_h, t_w)
+ tl[tl < 0] = 0
+ # the maximum top left to crop positive region
+ br = np.max(np.where(img_gt > 0), axis=1) - (t_h, t_w)
+ br[br < 0] = 0
+ # clamp br so the crop does not go outside the image
+ br[0] = min(br[0], h - t_h)
+ br[1] = min(br[1], w - t_w)
+ #
+ h = np.random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
+ w = np.random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
+ else:
+ # make sure not to crop outside of img
+
+ h = np.random.randint(0, h - t_h) if h - t_h > 0 else 0
+ w = np.random.randint(0, w - t_w) if w - t_w > 0 else 0
+
+ return (h, w)
+
+ @staticmethod
+ def crop_img(img, offset, target_size):
+ h, w = img.shape[:2]
+ br = np.min(
+ np.stack((np.array(offset) + np.array(target_size), np.array(
+ (h, w)))),
+ axis=0)
+ return img[offset[0]:br[0], offset[1]:br[1]], np.array(
+ [offset[1], offset[0], br[1], br[0]])
+
+ def crop_bboxes(self, bboxes, canvas_bbox):
+ kept_bboxes = []
+ kept_inx = []
+ canvas_poly = eval_utils.box2polygon(canvas_bbox)
+ tl = canvas_bbox[0:2]
+
+ for idx, bbox in enumerate(bboxes):
+ poly = eval_utils.box2polygon(bbox)
+ area, inters = eval_utils.poly_intersection(
+ poly, canvas_poly, return_poly=True)
+ if area == 0:
+ continue
+ xmin, ymin, xmax, ymax = inters.bounds
+ kept_bboxes += [
+ np.array(
+ [xmin - tl[0], ymin - tl[1], xmax - tl[0], ymax - tl[1]],
+ dtype=np.float32)
+ ]
+ kept_inx += [idx]
+
+ if len(kept_inx) == 0:
+ return np.array([]).astype(np.float32).reshape(0, 4), kept_inx
+
+ return np.stack(kept_bboxes), kept_inx
+
+ @staticmethod
+ def generate_mask(gt_mask, type):
+
+ if type == 'inx0':
+ return gt_mask.masks[0]
+ if type == 'union_all':
+ mask = gt_mask.masks[0].copy()
+ for idx in range(1, len(gt_mask.masks)):
+ mask = np.logical_or(mask, gt_mask.masks[idx])
+ return mask
+
+ raise NotImplementedError
+
+ def __call__(self, results):
+
+ gt_mask = results[self.instance_key]
+ mask = None
+ if len(gt_mask.masks) > 0:
+ mask = self.generate_mask(gt_mask, self.mask_type)
+ results['crop_offset'] = self.sample_offset(mask,
+ results['img'].shape[:2])
+
+ # crop img. bbox = [x1,y1,x2,y2]
+ img, bbox = self.crop_img(results['img'], results['crop_offset'],
+ self.target_size)
+ results['img'] = img
+ img_shape = img.shape
+ results['img_shape'] = img_shape
+
+ # crop masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].crop(bbox)
+
+ # for mask rcnn
+ for key in results.get('bbox_fields', []):
+ results[key], kept_inx = self.crop_bboxes(results[key], bbox)
+ if key == 'gt_bboxes':
+ # ignore gt_labels accordingly
+ if 'gt_labels' in results:
+ ori_labels = results['gt_labels']
+ ori_inst_num = len(ori_labels)
+ results['gt_labels'] = [
+ ori_labels[idx] for idx in range(ori_inst_num)
+ if idx in kept_inx
+ ]
+ # ignore gt_masks accordingly
+ if 'gt_masks' in results:
+ ori_mask = results['gt_masks'].masks
+ kept_mask = [
+ ori_mask[idx] for idx in range(ori_inst_num)
+ if idx in kept_inx
+ ]
+ target_h, target_w = bbox[3] - bbox[1], bbox[2] - bbox[0]
+ if len(kept_inx) > 0:
+ kept_mask = np.stack(kept_mask)
+ else:
+ kept_mask = np.empty((0, target_h, target_w),
+ dtype=np.float32)
+ results['gt_masks'] = BitmapMasks(kept_mask, target_h,
+ target_w)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
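+# Illustrative usage sketch (not part of the original file): a plausible
+# configuration of RandomCropInstances; the 640x640 target size and the
+# 'gt_kernels' instance key are assumptions, not values taken from a config
+# in this repository.
+#
+#   dict(type='RandomCropInstances',
+#        target_size=(640, 640),
+#        instance_key='gt_kernels',
+#        mask_type='inx0')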
+
+@PIPELINES.register_module()
+class RandomRotateTextDet:
+ """Randomly rotate images."""
+
+ def __init__(self, rotate_ratio=1.0, max_angle=10):
+ self.rotate_ratio = rotate_ratio
+ self.max_angle = max_angle
+
+ @staticmethod
+ def sample_angle(max_angle):
+ angle = np.random.random_sample() * 2 * max_angle - max_angle
+ return angle
+
+ @staticmethod
+ def rotate_img(img, angle):
+ h, w = img.shape[:2]
+ rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
+ img_target = cv2.warpAffine(
+ img, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST)
+ assert img_target.shape == img.shape
+ return img_target
+
+ def __call__(self, results):
+ if np.random.random_sample() < self.rotate_ratio:
+ # rotate imgs
+ results['rotated_angle'] = self.sample_angle(self.max_angle)
+ img = self.rotate_img(results['img'], results['rotated_angle'])
+ results['img'] = img
+ img_shape = img.shape
+ results['img_shape'] = img_shape
+
+ # rotate masks
+ for key in results.get('mask_fields', []):
+ masks = results[key].masks
+ mask_list = []
+ for m in masks:
+ rotated_m = self.rotate_img(m, results['rotated_angle'])
+ mask_list.append(rotated_m)
+ results[key] = BitmapMasks(mask_list, *(img_shape[:2]))
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ColorJitter:
+ """An interface for torch color jitter so that it can be invoked in
+ dbnet_detection pipeline."""
+
+ def __init__(self, **kwargs):
+ self.transform = transforms.ColorJitter(**kwargs)
+
+ def __call__(self, results):
+ # img is bgr
+ img = results['img'][..., ::-1]
+ img = Image.fromarray(img)
+ img = self.transform(img)
+ img = np.asarray(img)
+ img = img[..., ::-1]
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ScaleAspectJitter(Resize):
+ """Resize image and segmentation mask encoded by coordinates.
+
+ Allowed resize types are `around_min_img_scale`, `long_short_bound`, and
+ `indep_sample_in_range`.
+ """
+
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=False,
+ resize_type='around_min_img_scale',
+ aspect_ratio_range=None,
+ long_size_bound=None,
+ short_size_bound=None,
+ scale_range=None):
+ super().__init__(
+ img_scale=img_scale,
+ multiscale_mode=multiscale_mode,
+ ratio_range=ratio_range,
+ keep_ratio=keep_ratio)
+ assert not keep_ratio
+ assert resize_type in [
+ 'around_min_img_scale', 'long_short_bound', 'indep_sample_in_range'
+ ]
+ self.resize_type = resize_type
+
+ if resize_type == 'indep_sample_in_range':
+ assert ratio_range is None
+ assert aspect_ratio_range is None
+ assert short_size_bound is None
+ assert long_size_bound is None
+ assert scale_range is not None
+ else:
+ assert scale_range is None
+ assert isinstance(ratio_range, tuple)
+ assert isinstance(aspect_ratio_range, tuple)
+ assert check_argument.equal_len(ratio_range, aspect_ratio_range)
+
+ if resize_type in ['long_short_bound']:
+ assert short_size_bound is not None
+ assert long_size_bound is not None
+
+ self.aspect_ratio_range = aspect_ratio_range
+ self.long_size_bound = long_size_bound
+ self.short_size_bound = short_size_bound
+ self.scale_range = scale_range
+
+ @staticmethod
+ def sample_from_range(range):
+ assert len(range) == 2
+ min_value, max_value = min(range), max(range)
+ value = np.random.random_sample() * (max_value - min_value) + min_value
+
+ return value
+
+ def _random_scale(self, results):
+
+ if self.resize_type == 'indep_sample_in_range':
+ w = self.sample_from_range(self.scale_range)
+ h = self.sample_from_range(self.scale_range)
+ results['scale'] = (int(w), int(h)) # (w,h)
+ results['scale_idx'] = None
+ return
+ h, w = results['img'].shape[0:2]
+ if self.resize_type == 'long_short_bound':
+ scale1 = 1
+ if max(h, w) > self.long_size_bound:
+ scale1 = self.long_size_bound / max(h, w)
+ scale2 = self.sample_from_range(self.ratio_range)
+ scale = scale1 * scale2
+ if min(h, w) * scale <= self.short_size_bound:
+ scale = (self.short_size_bound + 10) * 1.0 / min(h, w)
+ elif self.resize_type == 'around_min_img_scale':
+ short_size = min(self.img_scale[0])
+ ratio = self.sample_from_range(self.ratio_range)
+ scale = (ratio * short_size) / min(h, w)
+ else:
+ raise NotImplementedError
+
+ aspect = self.sample_from_range(self.aspect_ratio_range)
+ h_scale = scale * math.sqrt(aspect)
+ w_scale = scale / math.sqrt(aspect)
+ results['scale'] = (int(w * w_scale), int(h * h_scale)) # (w,h)
+ results['scale_idx'] = None
+
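+# Illustrative usage sketch (not part of the original file): a plausible
+# DBNet-style configuration of ScaleAspectJitter with the default
+# 'around_min_img_scale' resize type; all numbers below are examples only.
+#
+#   dict(type='ScaleAspectJitter',
+#        img_scale=[(3000, 736)],
+#        ratio_range=(0.7, 3),
+#        aspect_ratio_range=(0.9, 1.1),
+#        multiscale_mode='value',
+#        keep_ratio=False)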
+
+@PIPELINES.register_module()
+class AffineJitter:
+ """An interface for torchvision random affine so that it can be invoked in
+ dbnet_det pipeline."""
+
+ def __init__(self,
+ degrees=4,
+ translate=(0.02, 0.04),
+ scale=(0.9, 1.1),
+ shear=None,
+ resample=False,
+ fillcolor=0):
+ self.transform = transforms.RandomAffine(
+ degrees=degrees,
+ translate=translate,
+ scale=scale,
+ shear=shear,
+ resample=resample,
+ fillcolor=fillcolor)
+
+ def __call__(self, results):
+ # img is bgr
+ img = results['img'][..., ::-1]
+ img = Image.fromarray(img)
+ img = self.transform(img)
+ img = np.asarray(img)
+ img = img[..., ::-1]
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCropPolyInstances:
+ """Randomly crop images and make sure to contain at least one intact
+ instance."""
+
+ def __init__(self,
+ instance_key='gt_masks',
+ crop_ratio=5.0 / 8.0,
+ min_side_ratio=0.4):
+ super().__init__()
+ self.instance_key = instance_key
+ self.crop_ratio = crop_ratio
+ self.min_side_ratio = min_side_ratio
+
+ def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
+
+ assert isinstance(min_len, int)
+ assert len(valid_array) > min_len
+
+ start_array = valid_array.copy()
+ max_start = min(len(start_array) - min_len, max_start)
+ start_array[max_start:] = 0
+ start_array[0] = 1
+ diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
+ region_starts = np.where(diff_array < 0)[0]
+ region_ends = np.where(diff_array > 0)[0]
+ region_ind = np.random.randint(0, len(region_starts))
+ start = np.random.randint(region_starts[region_ind],
+ region_ends[region_ind])
+
+ end_array = valid_array.copy()
+ min_end = max(start + min_len, min_end)
+ end_array[:min_end] = 0
+ end_array[-1] = 1
+ diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
+ region_starts = np.where(diff_array < 0)[0]
+ region_ends = np.where(diff_array > 0)[0]
+ region_ind = np.random.randint(0, len(region_starts))
+ end = np.random.randint(region_starts[region_ind],
+ region_ends[region_ind])
+ return start, end
+
+ def sample_crop_box(self, img_size, results):
+ """Generate crop box and make sure not to crop the polygon instances.
+
+ Args:
+ img_size (tuple(int)): The image size (h, w).
+ results (dict): The results dict.
+ """
+
+ assert isinstance(img_size, tuple)
+ h, w = img_size[:2]
+
+ key_masks = results[self.instance_key].masks
+ x_valid_array = np.ones(w, dtype=np.int32)
+ y_valid_array = np.ones(h, dtype=np.int32)
+
+ selected_mask = key_masks[np.random.randint(0, len(key_masks))]
+ selected_mask = selected_mask[0].reshape((-1, 2)).astype(np.int32)
+ max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
+ min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
+ max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
+ min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
+
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ masks = results[key].masks
+ for mask in masks:
+ assert len(mask) == 1
+ mask = mask[0].reshape((-1, 2)).astype(np.int32)
+ clip_x = np.clip(mask[:, 0], 0, w - 1)
+ clip_y = np.clip(mask[:, 1], 0, h - 1)
+ min_x, max_x = np.min(clip_x), np.max(clip_x)
+ min_y, max_y = np.min(clip_y), np.max(clip_y)
+
+ x_valid_array[min_x - 2:max_x + 3] = 0
+ y_valid_array[min_y - 2:max_y + 3] = 0
+
+ min_w = int(w * self.min_side_ratio)
+ min_h = int(h * self.min_side_ratio)
+
+ x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
+ min_x_end)
+ y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
+ min_y_end)
+
+ return np.array([x1, y1, x2, y2])
+
+ def crop_img(self, img, bbox):
+ assert img.ndim == 3
+ h, w, _ = img.shape
+ assert 0 <= bbox[1] < bbox[3] <= h
+ assert 0 <= bbox[0] < bbox[2] <= w
+ return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
+
+ def __call__(self, results):
+ if len(results[self.instance_key].masks) < 1:
+ return results
+ if np.random.random_sample() < self.crop_ratio:
+ crop_box = self.sample_crop_box(results['img'].shape, results)
+ results['crop_region'] = crop_box
+ img = self.crop_img(results['img'], crop_box)
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ # crop and filter masks
+ x1, y1, x2, y2 = crop_box
+ w = max(x2 - x1, 1)
+ h = max(y2 - y1, 1)
+ labels = results['gt_labels']
+ valid_labels = []
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ results[key] = results[key].crop(crop_box)
+ # filter out polygons beyond crop box.
+ masks = results[key].masks
+ valid_masks_list = []
+
+ for ind, mask in enumerate(masks):
+ assert len(mask) == 1
+ polygon = mask[0].reshape((-1, 2))
+ if (polygon[:, 0] >
+ -4).all() and (polygon[:, 0] < w + 4).all() and (
+ polygon[:, 1] > -4).all() and (polygon[:, 1] <
+ h + 4).all():
+ mask[0][::2] = np.clip(mask[0][::2], 0, w)
+ mask[0][1::2] = np.clip(mask[0][1::2], 0, h)
+ if key == self.instance_key:
+ valid_labels.append(labels[ind])
+ valid_masks_list.append(mask)
+
+ results[key] = PolygonMasks(valid_masks_list, h, w)
+ results['gt_labels'] = np.array(valid_labels)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
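+# Illustrative usage sketch (not part of the original file); the values below
+# simply spell out the constructor defaults:
+#
+#   dict(type='RandomCropPolyInstances',
+#        instance_key='gt_masks',
+#        crop_ratio=5.0 / 8.0,
+#        min_side_ratio=0.4)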
+
+@PIPELINES.register_module()
+class RandomRotatePolyInstances:
+
+ def __init__(self,
+ rotate_ratio=0.5,
+ max_angle=10,
+ pad_with_fixed_color=False,
+ pad_value=(0, 0, 0)):
+ """Randomly rotate images and polygon masks.
+
+ Args:
+ rotate_ratio (float): The probability of rotating a sample.
+ max_angle (int): The maximum rotation angle.
+ pad_with_fixed_color (bool): Whether to pad the rotated image with a
+ fixed color. If set to False, the rotated image is padded onto a
+ randomly cropped patch of the original image.
+ pad_value (tuple(int)): The color value used for padding.
+ """
+ self.rotate_ratio = rotate_ratio
+ self.max_angle = max_angle
+ self.pad_with_fixed_color = pad_with_fixed_color
+ self.pad_value = pad_value
+
+ def rotate(self, center, points, theta, center_shift=(0, 0)):
+ # rotate points.
+ (center_x, center_y) = center
+ center_y = -center_y
+ x, y = points[::2], points[1::2]
+ y = -y
+
+ theta = theta / 180 * math.pi
+ cos = math.cos(theta)
+ sin = math.sin(theta)
+
+ x = (x - center_x)
+ y = (y - center_y)
+
+ _x = center_x + x * cos - y * sin + center_shift[0]
+ _y = -(center_y + x * sin + y * cos) + center_shift[1]
+
+ points[::2], points[1::2] = _x, _y
+ return points
+
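+ # The transform above is an ordinary 2D rotation about ``center`` expressed
+ # in image coordinates (y pointing down), followed by ``center_shift``.
+ # Equivalent vectorised sketch with theta in radians (illustrative only,
+ # not used by this class):
+ #
+ #   c, s = math.cos(theta), math.sin(theta)
+ #   R = np.array([[c, s], [-s, c]])
+ #   pts = points.reshape(-1, 2).astype(float)
+ #   rotated = (pts - center) @ R.T + np.asarray(center) + center_shift
+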
+ def cal_canvas_size(self, ori_size, degree):
+ assert isinstance(ori_size, tuple)
+ angle = degree * math.pi / 180.0
+ h, w = ori_size[:2]
+
+ cos = math.cos(angle)
+ sin = math.sin(angle)
+ canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
+ canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
+
+ canvas_size = (canvas_h, canvas_w)
+ return canvas_size
+
+ def sample_angle(self, max_angle):
+ angle = np.random.random_sample() * 2 * max_angle - max_angle
+ return angle
+
+ def rotate_img(self, img, angle, canvas_size):
+ h, w = img.shape[:2]
+ rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
+ rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
+ rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
+
+ if self.pad_with_fixed_color:
+ target_img = cv2.warpAffine(
+ img,
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
+ flags=cv2.INTER_NEAREST,
+ borderValue=self.pad_value)
+ else:
+ mask = np.zeros_like(img)
+ (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
+ np.random.randint(0, w * 7 // 8))
+ img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
+ img_cut = dbnet_cv.imresize(img_cut, (canvas_size[1], canvas_size[0]))
+ mask = cv2.warpAffine(
+ mask,
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
+ borderValue=[1, 1, 1])
+ target_img = cv2.warpAffine(
+ img,
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
+ borderValue=[0, 0, 0])
+ target_img = target_img + img_cut * mask
+
+ return target_img
+
+ def __call__(self, results):
+ if np.random.random_sample() < self.rotate_ratio:
+ img = results['img']
+ h, w = img.shape[:2]
+ angle = self.sample_angle(self.max_angle)
+ canvas_size = self.cal_canvas_size((h, w), angle)
+ center_shift = (int(
+ (canvas_size[1] - w) / 2), int((canvas_size[0] - h) / 2))
+
+ # rotate image
+ results['rotated_poly_angle'] = angle
+ img = self.rotate_img(img, angle, canvas_size)
+ results['img'] = img
+ img_shape = img.shape
+ results['img_shape'] = img_shape
+
+ # rotate polygons
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ masks = results[key].masks
+ rotated_masks = []
+ for mask in masks:
+ rotated_mask = self.rotate((w / 2, h / 2), mask[0], angle,
+ center_shift)
+ rotated_masks.append([rotated_mask])
+
+ results[key] = PolygonMasks(rotated_masks, *(img_shape[:2]))
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class SquareResizePad:
+
+ def __init__(self,
+ target_size,
+ pad_ratio=0.6,
+ pad_with_fixed_color=False,
+ pad_value=(0, 0, 0)):
+ """Resize or pad images to be square shape.
+
+ Args:
+ target_size (int): The target size of square shaped image.
+ pad_with_fixed_color (bool): The flag for whether to pad rotated
+ image with fixed value. If set to False, the rescales image will
+ be padded onto cropped image.
+ pad_value (tuple(int)): The color value for padding rotated image.
+ """
+ assert isinstance(target_size, int)
+ assert isinstance(pad_ratio, float)
+ assert isinstance(pad_with_fixed_color, bool)
+ assert isinstance(pad_value, tuple)
+
+ self.target_size = target_size
+ self.pad_ratio = pad_ratio
+ self.pad_with_fixed_color = pad_with_fixed_color
+ self.pad_value = pad_value
+
+ def resize_img(self, img, keep_ratio=True):
+ h, w, _ = img.shape
+ if keep_ratio:
+ t_h = self.target_size if h >= w else int(h * self.target_size / w)
+ t_w = self.target_size if h <= w else int(w * self.target_size / h)
+ else:
+ t_h = t_w = self.target_size
+ img = dbnet_cv.imresize(img, (t_w, t_h))
+ return img, (t_h, t_w)
+
+ def square_pad(self, img):
+ h, w = img.shape[:2]
+ if h == w:
+ return img, (0, 0)
+ pad_size = max(h, w)
+ if self.pad_with_fixed_color:
+ expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
+ expand_img[:] = self.pad_value
+ else:
+ (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
+ np.random.randint(0, w * 7 // 8))
+ img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
+ expand_img = dbnet_cv.imresize(img_cut, (pad_size, pad_size))
+ if h > w:
+ y0, x0 = 0, (h - w) // 2
+ else:
+ y0, x0 = (w - h) // 2, 0
+ expand_img[y0:y0 + h, x0:x0 + w] = img
+ offset = (x0, y0)
+
+ return expand_img, offset
+
+ def square_pad_mask(self, points, offset):
+ x0, y0 = offset
+ pad_points = points.copy()
+ pad_points[::2] = pad_points[::2] + x0
+ pad_points[1::2] = pad_points[1::2] + y0
+ return pad_points
+
+ def __call__(self, results):
+ img = results['img']
+
+ if np.random.random_sample() < self.pad_ratio:
+ img, out_size = self.resize_img(img, keep_ratio=True)
+ img, offset = self.square_pad(img)
+ else:
+ img, out_size = self.resize_img(img, keep_ratio=False)
+ offset = (0, 0)
+
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ results[key] = results[key].resize(out_size)
+ masks = results[key].masks
+ processed_masks = []
+ for mask in masks:
+ square_pad_mask = self.square_pad_mask(mask[0], offset)
+ processed_masks.append([square_pad_mask])
+
+ results[key] = PolygonMasks(processed_masks, *(img.shape[:2]))
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomScaling:
+
+ def __init__(self, size=800, scale=(3. / 4, 5. / 2)):
+ """Random scale the image while keeping aspect.
+
+ Args:
+ size (int) : Base size before scaling.
+ scale (tuple(float)) : The range of scaling.
+ """
+ assert isinstance(size, int)
+ assert isinstance(scale, float) or isinstance(scale, tuple)
+ self.size = size
+ self.scale = scale if isinstance(scale, tuple) \
+ else (1 - scale, 1 + scale)
+
+ def __call__(self, results):
+ image = results['img']
+ h, w, _ = results['img_shape']
+
+ aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
+ scales = self.size * 1.0 / max(h, w) * aspect_ratio
+ scales = np.array([scales, scales])
+ out_size = (int(h * scales[1]), int(w * scales[0]))
+ image = dbnet_cv.imresize(image, out_size[::-1])
+
+ results['img'] = image
+ results['img_shape'] = image.shape
+
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ results[key] = results[key].resize(out_size)
+
+ return results
+
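+# Worked example (illustrative only, not part of the original file): with the
+# default size=800 and a sampled factor of 1.0, a 1200x900 (w x h) image gives
+# scales = 800 / 1200 * 1.0 ~= 0.667, so out_size = (600, 800) and the long
+# side of the resized image becomes 800 pixels.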
+
+@PIPELINES.register_module()
+class RandomCropFlip:
+
+ def __init__(self,
+ pad_ratio=0.1,
+ crop_ratio=0.5,
+ iter_num=1,
+ min_area_ratio=0.2):
+ """Random crop and flip a patch of the image.
+
+ Args:
+ pad_ratio (float): The ratio of padding added to each image border
+ when generating crop candidates.
+ crop_ratio (float): The probability of applying the crop-and-flip
+ operation.
+ iter_num (int): Number of crop-and-flip operations.
+ min_area_ratio (float): Minimal area ratio between the cropped patch
+ and the original image.
+ """
+ assert isinstance(crop_ratio, float)
+ assert isinstance(iter_num, int)
+ assert isinstance(min_area_ratio, float)
+
+ self.pad_ratio = pad_ratio
+ self.epsilon = 1e-2
+ self.crop_ratio = crop_ratio
+ self.iter_num = iter_num
+ self.min_area_ratio = min_area_ratio
+
+ def __call__(self, results):
+ for i in range(self.iter_num):
+ results = self.random_crop_flip(results)
+ return results
+
+ def random_crop_flip(self, results):
+ image = results['img']
+ polygons = results['gt_masks'].masks
+ ignore_polygons = results['gt_masks_ignore'].masks
+ all_polygons = polygons + ignore_polygons
+ if len(polygons) == 0:
+ return results
+
+ if np.random.random() >= self.crop_ratio:
+ return results
+
+ h, w, _ = results['img_shape']
+ area = h * w
+ pad_h = int(h * self.pad_ratio)
+ pad_w = int(w * self.pad_ratio)
+ h_axis, w_axis = self.generate_crop_target(image, all_polygons, pad_h,
+ pad_w)
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return results
+
+ attempt = 0
+ while attempt < 10:
+ attempt += 1
+ polys_keep = []
+ polys_new = []
+ ign_polys_keep = []
+ ign_polys_new = []
+ xx = np.random.choice(w_axis, size=2)
+ xmin = np.min(xx) - pad_w
+ xmax = np.max(xx) - pad_w
+ xmin = np.clip(xmin, 0, w - 1)
+ xmax = np.clip(xmax, 0, w - 1)
+ yy = np.random.choice(h_axis, size=2)
+ ymin = np.min(yy) - pad_h
+ ymax = np.max(yy) - pad_h
+ ymin = np.clip(ymin, 0, h - 1)
+ ymax = np.clip(ymax, 0, h - 1)
+ if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
+ # area too small
+ continue
+
+ pts = np.stack([[xmin, xmax, xmax, xmin],
+ [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
+ pp = plg(pts)
+ fail_flag = False
+ for polygon in polygons:
+ ppi = plg(polygon[0].reshape(-1, 2))
+ ppiou = eval_utils.poly_intersection(ppi, pp)
+ if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
+ np.abs(ppiou) > self.epsilon:
+ fail_flag = True
+ break
+ elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
+ polys_new.append(polygon)
+ else:
+ polys_keep.append(polygon)
+
+ for polygon in ignore_polygons:
+ ppi = plg(polygon[0].reshape(-1, 2))
+ ppiou = eval_utils.poly_intersection(ppi, pp)
+ if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
+ np.abs(ppiou) > self.epsilon:
+ fail_flag = True
+ break
+ elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
+ ign_polys_new.append(polygon)
+ else:
+ ign_polys_keep.append(polygon)
+
+ if fail_flag:
+ continue
+ else:
+ break
+
+ cropped = image[ymin:ymax, xmin:xmax, :]
+ select_type = np.random.randint(3)
+ if select_type == 0:
+ img = np.ascontiguousarray(cropped[:, ::-1])
+ elif select_type == 1:
+ img = np.ascontiguousarray(cropped[::-1, :])
+ else:
+ img = np.ascontiguousarray(cropped[::-1, ::-1])
+ image[ymin:ymax, xmin:xmax, :] = img
+ results['img'] = image
+
+ if len(polys_new) + len(ign_polys_new) != 0:
+ height, width, _ = cropped.shape
+ if select_type == 0:
+ for idx, polygon in enumerate(polys_new):
+ poly = polygon[0].reshape(-1, 2)
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
+ polys_new[idx] = [poly.reshape(-1, )]
+ for idx, polygon in enumerate(ign_polys_new):
+ poly = polygon[0].reshape(-1, 2)
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
+ ign_polys_new[idx] = [poly.reshape(-1, )]
+ elif select_type == 1:
+ for idx, polygon in enumerate(polys_new):
+ poly = polygon[0].reshape(-1, 2)
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
+ polys_new[idx] = [poly.reshape(-1, )]
+ for idx, polygon in enumerate(ign_polys_new):
+ poly = polygon[0].reshape(-1, 2)
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
+ ign_polys_new[idx] = [poly.reshape(-1, )]
+ else:
+ for idx, polygon in enumerate(polys_new):
+ poly = polygon[0].reshape(-1, 2)
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
+ polys_new[idx] = [poly.reshape(-1, )]
+ for idx, polygon in enumerate(ign_polys_new):
+ poly = polygon[0].reshape(-1, 2)
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
+ ign_polys_new[idx] = [poly.reshape(-1, )]
+ polygons = polys_keep + polys_new
+ ignore_polygons = ign_polys_keep + ign_polys_new
+ results['gt_masks'] = PolygonMasks(polygons, *(image.shape[:2]))
+ results['gt_masks_ignore'] = PolygonMasks(ignore_polygons,
+ *(image.shape[:2]))
+
+ return results
+
+ def generate_crop_target(self, image, all_polys, pad_h, pad_w):
+ """Generate crop target and make sure not to crop the polygon
+ instances.
+
+ Args:
+ image (ndarray): The image to be cropped.
+ all_polys (list[list[ndarray]]): All polygons including ground
+ truth polygons and ground truth ignored polygons.
+ pad_h (int): Padding length of height.
+ pad_w (int): Padding length of width.
+ Returns:
+ h_axis (ndarray): Vertical cropping range.
+ w_axis (ndarray): Horizontal cropping range.
+ """
+ h, w, _ = image.shape
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+
+ text_polys = []
+ for polygon in all_polys:
+ rect = cv2.minAreaRect(polygon[0].astype(np.int32).reshape(-1, 2))
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ text_polys.append([box[0], box[1], box[2], box[3]])
+
+ polys = np.array(text_polys, dtype=np.int32)
+ for poly in polys:
+ poly = np.round(poly, decimals=0).astype(np.int32)
+ minx = np.min(poly[:, 0])
+ maxx = np.max(poly[:, 0])
+ w_array[minx + pad_w:maxx + pad_w] = 1
+ miny = np.min(poly[:, 1])
+ maxy = np.max(poly[:, 1])
+ h_array[miny + pad_h:maxy + pad_h] = 1
+
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+ return h_axis, w_axis
+
+
+@PIPELINES.register_module()
+class PyramidRescale:
+ """Resize the image to the base shape, downsample it with gaussian pyramid,
+ and rescale it back to original size.
+
+ Adapted from https://github.com/FangShancheng/ABINet.
+
+ Args:
+ factor (int): The decay factor from base size, or the number of
+ downsampling operations from the base layer.
+ base_shape (tuple(int)): The shape of the base layer of the pyramid.
+ randomize_factor (bool): If True, the final factor would be a random
+ integer in [0, factor].
+
+ :Required Keys:
+ - | ``img`` (ndarray): The input image.
+
+ :Affected Keys:
+ :Modified:
+ - | ``img`` (ndarray): The modified image.
+ """
+
+ def __init__(self, factor=4, base_shape=(128, 512), randomize_factor=True):
+ assert isinstance(factor, int)
+ assert isinstance(base_shape, list) or isinstance(base_shape, tuple)
+ assert len(base_shape) == 2
+ assert isinstance(randomize_factor, bool)
+ self.factor = factor if not randomize_factor else np.random.randint(
+ 0, factor + 1)
+ self.base_w, self.base_h = base_shape
+
+ def __call__(self, results):
+ assert 'img' in results
+ if self.factor == 0:
+ return results
+ img = results['img']
+ src_h, src_w = img.shape[:2]
+ scale_img = dbnet_cv.imresize(img, (self.base_w, self.base_h))
+ for _ in range(self.factor):
+ scale_img = cv2.pyrDown(scale_img)
+ scale_img = dbnet_cv.imresize(scale_img, (src_w, src_h))
+ results['img'] = scale_img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(factor={self.factor}, '
+ repr_str += f'base_w={self.base_w}, base_h={self.base_h})'
+ return repr_str
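+
+# Illustrative sketch (not part of the original file) of what PyramidRescale
+# does, written with plain OpenCV calls; ``base_wh`` is passed in (width,
+# height) order as cv2.resize expects.
+#
+#   def pyramid_rescale(img, factor, base_wh):
+#       out = cv2.resize(img, base_wh)   # shrink to the pyramid base layer
+#       for _ in range(factor):
+#           out = cv2.pyrDown(out)       # Gaussian blur + 2x downsampling
+#       return cv2.resize(out, (img.shape[1], img.shape[0]))  # back to source size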
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/text_det_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/text_det_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..75fa273c56f55df970c355d8928771d8cf9de072
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/text_det_dataset.py
@@ -0,0 +1,131 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from dbnet_det.datasets.builder import DATASETS
+
+from dbnet.core.evaluation.hmean import eval_hmean
+from dbnet.datasets.base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class TextDetDataset(BaseDataset):
+
+ def _parse_anno_info(self, annotations):
+ """Parse bbox and mask annotation.
+ Args:
+ annotations (dict): Annotations of one image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,
+ labels, masks, masks_ignore. "masks" and
+ "masks_ignore" are represented by polygon boundary
+ point sequences.
+ """
+ gt_bboxes, gt_bboxes_ignore = [], []
+ gt_masks, gt_masks_ignore = [], []
+ gt_labels = []
+ for ann in annotations:
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(ann['bbox'])
+ gt_masks_ignore.append(ann.get('segmentation', None))
+ else:
+ gt_bboxes.append(ann['bbox'])
+ gt_labels.append(ann['category_id'])
+ gt_masks.append(ann.get('segmentation', None))
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks_ignore=gt_masks_ignore,
+ masks=gt_masks)
+
+ return ann
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+ img_ann_info = self.data_infos[index]
+ img_info = {
+ 'filename': img_ann_info['file_name'],
+ 'height': img_ann_info['height'],
+ 'width': img_ann_info['width']
+ }
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ results = dict(img_info=img_info, ann_info=ann_info)
+ results['bbox_fields'] = []
+ results['mask_fields'] = []
+ results['seg_fields'] = []
+ self.pre_pipeline(results)
+
+ return self.pipeline(results)
+
+ def evaluate(self,
+ results,
+ metric='hmean-iou',
+ score_thr=None,
+ min_score_thr=0.3,
+ max_score_thr=0.9,
+ step=0.1,
+ rank_list=None,
+ logger=None,
+ **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ score_thr (float): Deprecated. Please use min_score_thr instead.
+ min_score_thr (float): Minimum score threshold of prediction map.
+ max_score_thr (float): Maximum score threshold of prediction map.
+ step (float): The spacing between score thresholds.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ rank_list (str): json file used to save eval result
+ of each image after ranking.
+ Returns:
+ dict[str: float]
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['hmean-iou', 'hmean-ic13']
+ metrics = set(metrics) & set(allowed_metrics)
+
+ img_infos = []
+ ann_infos = []
+ for i in range(len(self)):
+ img_ann_info = self.data_infos[i]
+ img_info = {'filename': img_ann_info['file_name']}
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ img_infos.append(img_info)
+ ann_infos.append(ann_info)
+
+ eval_results = eval_hmean(
+ results,
+ img_infos,
+ ann_infos,
+ metrics=metrics,
+ score_thr=score_thr,
+ min_score_thr=min_score_thr,
+ max_score_thr=max_score_thr,
+ step=step,
+ logger=logger,
+ rank_list=rank_list)
+
+ return eval_results
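+
+# Illustrative usage sketch (not part of the original file): after running
+# inference with a text detector, the result list can be scored over a range
+# of mask score thresholds; the exact keys of the returned dict depend on
+# eval_hmean.
+#
+#   metrics = dataset.evaluate(results, metric='hmean-iou',
+#                              min_score_thr=0.3, max_score_thr=0.9, step=0.1)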
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/uniform_concat_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/uniform_concat_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..104dd3fdec6e4fa1d05ffc103230824554d874c2
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/uniform_concat_dataset.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from collections import defaultdict
+
+import numpy as np
+from dbnet_cv.utils import print_log
+from dbnet_det.datasets import DATASETS, ConcatDataset, build_dataset
+
+from dbnet.utils import is_2dlist, is_type_list
+
+
+@DATASETS.register_module()
+class UniformConcatDataset(ConcatDataset):
+ """A wrapper of ConcatDataset which support dataset pipeline assignment and
+ replacement.
+
+ Args:
+ datasets (list[dict] | list[list[dict]]): A list of datasets cfgs.
+ separate_eval (bool): Whether to evaluate the results
+ separately if it is used as validation dataset.
+ Defaults to True.
+ show_mean_scores (str | bool): Whether to compute the mean evaluation
+ results, only applicable when ``separate_eval=True``. Options are
+ [True, False, ``auto``]. If ``True``, mean results will be added to
+ the result dictionary with keys in the form of
+ ``mean_{metric_name}``. If 'auto', mean results will be shown only
+ when more than 1 dataset is wrapped.
+ pipeline (None | list[dict] | list[list[dict]]): If ``None``,
+ each dataset in datasets use its own pipeline;
+ If ``list[dict]``, it will be assigned to the dataset whose
+ pipeline is None in datasets;
+ If ``list[list[dict]]``, pipeline of dataset which is None
+ in datasets will be replaced by the corresponding pipeline
+ in the list.
+ force_apply (bool): If True, apply the pipeline above to each dataset
+ even if it has its own pipeline. Default: False.
+ """
+
+ def __init__(self,
+ datasets,
+ separate_eval=True,
+ show_mean_scores='auto',
+ pipeline=None,
+ force_apply=False,
+ **kwargs):
+ new_datasets = []
+ if pipeline is not None:
+ assert isinstance(
+ pipeline,
+ list), 'pipeline must be list[dict] or list[list[dict]].'
+ if is_type_list(pipeline, dict):
+ self._apply_pipeline(datasets, pipeline, force_apply)
+ new_datasets = datasets
+ elif is_2dlist(pipeline):
+ assert is_2dlist(datasets)
+ assert len(datasets) == len(pipeline)
+ for sub_datasets, tmp_pipeline in zip(datasets, pipeline):
+ self._apply_pipeline(sub_datasets, tmp_pipeline,
+ force_apply)
+ new_datasets.extend(sub_datasets)
+ else:
+ if is_2dlist(datasets):
+ for sub_datasets in datasets:
+ new_datasets.extend(sub_datasets)
+ else:
+ new_datasets = datasets
+ datasets = [build_dataset(c, kwargs) for c in new_datasets]
+ super().__init__(datasets, separate_eval)
+
+ if not separate_eval:
+ raise NotImplementedError(
+ 'Evaluating datasets as a whole is not'
+ ' supported yet. Please use "separate_eval=True"')
+
+ assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto'
+ if show_mean_scores == 'auto':
+ show_mean_scores = len(self.datasets) > 1
+ self.show_mean_scores = show_mean_scores
+ if show_mean_scores is True or show_mean_scores == 'auto' and len(
+ self.datasets) > 1:
+ if len(set([type(ds) for ds in self.datasets])) != 1:
+ raise NotImplementedError(
+ 'To compute mean evaluation scores, all datasets '
+ 'must have the same type')
+
+ @staticmethod
+ def _apply_pipeline(datasets, pipeline, force_apply=False):
+ from_cfg = all(isinstance(x, dict) for x in datasets)
+ assert from_cfg, 'datasets should be config dicts'
+ assert all(isinstance(x, dict) for x in pipeline)
+ for dataset in datasets:
+ if dataset['pipeline'] is None or force_apply:
+ dataset['pipeline'] = copy.deepcopy(pipeline)
+
+ def evaluate(self, results, logger=None, **kwargs):
+ """Evaluate the results.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+ dict[str: float]: Results of each separate
+ dataset if `self.separate_eval=True`.
+ """
+ assert len(results) == self.cumulative_sizes[-1], \
+ ('Dataset and results have different sizes: '
+ f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
+
+ # Check whether all the datasets support evaluation
+ for dataset in self.datasets:
+ assert hasattr(dataset, 'evaluate'), \
+ f'{type(dataset)} does not implement evaluate function'
+
+ if self.separate_eval:
+ dataset_idx = -1
+
+ total_eval_results = dict()
+
+ if self.show_mean_scores:
+ mean_eval_results = defaultdict(list)
+
+ for dataset in self.datasets:
+ start_idx = 0 if dataset_idx == -1 else \
+ self.cumulative_sizes[dataset_idx]
+ end_idx = self.cumulative_sizes[dataset_idx + 1]
+
+ results_per_dataset = results[start_idx:end_idx]
+ print_log(
+ f'\nEvaluating {dataset.ann_file} with '
+ f'{len(results_per_dataset)} images now',
+ logger=logger)
+
+ eval_results_per_dataset = dataset.evaluate(
+ results_per_dataset, logger=logger, **kwargs)
+ dataset_idx += 1
+ for k, v in eval_results_per_dataset.items():
+ total_eval_results.update({f'{dataset_idx}_{k}': v})
+ if self.show_mean_scores:
+ mean_eval_results[k].append(v)
+
+ if self.show_mean_scores:
+ for k, v in mean_eval_results.items():
+ total_eval_results[f'mean_{k}'] = np.mean(v)
+
+ return total_eval_results
+ else:
+ raise NotImplementedError(
+ 'Evaluating datasets as a whole is not'
+ ' supported yet. Please use "separate_eval=True"')
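+
+# Illustrative config sketch (not part of the original file): wrapping two
+# hypothetical dataset configs and sharing one pipeline between them; the
+# shared pipeline only replaces datasets whose own pipeline is None unless
+# force_apply=True.
+#
+#   train = dict(
+#       type='UniformConcatDataset',
+#       datasets=[icdar2015_train, icdar2017_train],  # hypothetical cfg dicts
+#       pipeline=train_pipeline)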
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..f2fc30a528236846177a621f73a3f10220d679df
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .loader import AnnFileLoader, HardDiskLoader, LmdbLoader
+from .parser import LineJsonParser, LineStrParser
+
+__all__ = [
+ 'HardDiskLoader', 'LmdbLoader', 'AnnFileLoader', 'LineStrParser',
+ 'LineJsonParser'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/backend.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/backend.py
new file mode 100755
index 0000000000000000000000000000000000000000..f72b6f63408a9e2d270af89ec932f7fc618736ce
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/backend.py
@@ -0,0 +1,198 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+import shutil
+import warnings
+
+import dbnet_cv
+
+from dbnet import digit_version
+from dbnet.utils import list_from_file
+
+
+class LmdbAnnFileBackend:
+ """Lmdb storage backend for annotation file.
+
+ Args:
+ lmdb_path (str): Lmdb file path.
+ """
+
+ def __init__(self, lmdb_path, encoding='utf8'):
+ """Currently we support two lmdb formats, one is the lmdb file with
+ only labels generated by txt2lmdb (deprecated), and one is the lmdb
+ file generated by recog2lmdb.
+
+ The former stores string in 'filename text' format directly in lmdb,
+ while the latter uses a more reasonable image_key as well as label_key
+ for querying.
+ """
+ self.lmdb_path = lmdb_path
+ self.encoding = encoding
+ self.deprecated_format = False
+ env = self._get_env()
+ with env.begin(write=False) as txn:
+ try:
+ self.total_number = int(
+ txn.get('num-samples'.encode('utf-8')).decode(
+ self.encoding))
+ except AttributeError:
+ warnings.warn(
+ 'DeprecationWarning: The lmdb dataset generated with '
+                    'txt2lmdb will be deprecated, please use the latest '
+ 'tools/data/utils/recog2lmdb to generate lmdb dataset. '
+ 'See https://mmocr.readthedocs.io/en/latest/tools.html#'
+ 'convert-text-recognition-dataset-to-lmdb-format for '
+ 'details.')
+ self.total_number = int(
+ txn.get('total_number'.encode('utf-8')).decode(
+ self.encoding))
+ self.deprecated_format = True
+ # The lmdb file may contain only the label, or it may contain both
+ # the label and the image, so we use image_key here for probing.
+ image_key = f'image-{1:09d}'
+ if txn.get(image_key.encode(encoding)) is None:
+ self.label_only = True
+ else:
+ self.label_only = False
+
+ def __getitem__(self, index):
+ """Retrieve one line from lmdb file by index.
+
+        To support text containing spaces, the returned line is a JSON
+        string, e.g. '{"filename": "image1.jpg", "text": "HELLO"}'.
+ """
+ if not hasattr(self, 'env'):
+ self.env = self._get_env()
+
+ with self.env.begin(write=False) as txn:
+ if self.deprecated_format:
+ line = txn.get(str(index).encode('utf-8')).decode(
+ self.encoding)
+                keys = line.strip('\n').split(' ')
+ if len(keys) == 4:
+ filename, height, width, annotations = keys
+ line = json.dumps(
+ dict(
+ filename=filename,
+ height=height,
+ width=width,
+ annotations=annotations),
+ ensure_ascii=False)
+ elif len(keys) == 2:
+ filename, text = keys
+ line = json.dumps(
+ dict(filename=filename, text=text), ensure_ascii=False)
+ else:
+ index = index + 1
+ label_key = f'label-{index:09d}'
+ if self.label_only:
+ line = txn.get(label_key.encode('utf-8')).decode(
+ self.encoding)
+ else:
+ img_key = f'image-{index:09d}'
+ text = txn.get(label_key.encode('utf-8')).decode(
+ self.encoding)
+ line = json.dumps(
+ dict(filename=img_key, text=text), ensure_ascii=False)
+ return line
+
+ def __len__(self):
+ return self.total_number
+
+ def _get_env(self):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError(
+ 'Please install lmdb to enable LmdbAnnFileBackend.')
+ return lmdb.open(
+ self.lmdb_path,
+ max_readers=1,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+
+ def close(self):
+ self.env.close()
+
+
+class HardDiskAnnFileBackend:
+ """Load annotation file with raw hard disks storage backend."""
+
+ def __init__(self, file_format='txt'):
+ assert file_format in ['txt', 'lmdb']
+ self.file_format = file_format
+
+ def __call__(self, ann_file):
+ if self.file_format == 'lmdb':
+ return LmdbAnnFileBackend(ann_file)
+
+ return list_from_file(ann_file)
+
+
+class PetrelAnnFileBackend:
+ """Load annotation file with petrel storage backend."""
+
+ def __init__(self, file_format='txt', save_dir='tmp_dir'):
+ assert file_format in ['txt', 'lmdb']
+ self.file_format = file_format
+ self.save_dir = save_dir
+
+ def __call__(self, ann_file):
+ file_client = dbnet_cv.FileClient(backend='petrel')
+
+ if self.file_format == 'lmdb':
+ dbnet_cv_version = digit_version(dbnet_cv.__version__)
+ if dbnet_cv_version < digit_version('1.3.16'):
+ raise Exception('Please update dbnet_cv to 1.3.16 or higher '
+ 'to enable "get_local_path" of "FileClient".')
+ assert file_client.isdir(ann_file)
+ files = file_client.list_dir_or_file(ann_file)
+
+ ann_file_rel_path = ann_file.split('s3://')[-1]
+ ann_file_dir = osp.dirname(ann_file_rel_path)
+ ann_file_name = osp.basename(ann_file_rel_path)
+ local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name)
+ if osp.exists(local_dir):
+                warnings.warn(
+                    f'local_ann_file: {local_dir} already exists and '
+                    'will be used. If it is not the correct ann_file '
+                    f'corresponding to {ann_file}, please remove it or '
+                    'change "save_dir" first then try again.')
+ else:
+ os.makedirs(local_dir, exist_ok=True)
+ print(f'Fetching {ann_file} to {local_dir}...')
+ for each_file in files:
+ tmp_file_path = file_client.join_path(ann_file, each_file)
+ with file_client.get_local_path(
+ tmp_file_path) as local_path:
+ shutil.copy(local_path, osp.join(local_dir, each_file))
+
+ return LmdbAnnFileBackend(local_dir)
+
+ lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')
+
+ return [x for x in lines if x.strip() != '']
+
+
+class HTTPAnnFileBackend:
+ """Load annotation file with http storage backend."""
+
+ def __init__(self, file_format='txt'):
+ assert file_format in ['txt', 'lmdb']
+ self.file_format = file_format
+
+ def __call__(self, ann_file):
+ file_client = dbnet_cv.FileClient(backend='http')
+
+ if self.file_format == 'lmdb':
+ raise NotImplementedError(
+ 'Loading lmdb file on http is not supported yet.')
+
+ lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')
+
+ return [x for x in lines if x.strip() != '']
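+
+
+# Usage sketch (illustrative only; paths are hypothetical): the txt backends
+# return a list of annotation lines, while the lmdb backend returns a lazily
+# indexed LmdbAnnFileBackend:
+#
+#     backend = HardDiskAnnFileBackend(file_format='txt')
+#     lines = backend('data/annotations.txt')       # list[str]
+#
+#     lmdb_backend = HardDiskAnnFileBackend(file_format='lmdb')
+#     ann = lmdb_backend('data/train_lmdb')          # LmdbAnnFileBackend
+#     json_line = ann[0]                             # one JSON string per sample
+#     ann.close()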
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/loader.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/loader.py
new file mode 100755
index 0000000000000000000000000000000000000000..dff51c6cd8f8c8320516ae198134bcfd82b8431c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/loader.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from dbnet.datasets.builder import LOADERS, build_parser
+from .backend import (HardDiskAnnFileBackend, HTTPAnnFileBackend,
+ PetrelAnnFileBackend)
+
+
+@LOADERS.register_module()
+class AnnFileLoader:
+    """Annotation file loader that loads annotations from ``ann_file`` and
+    parses each raw annotation into dict format with the specified parser.
+
+ Args:
+ ann_file (str): Annotation file path.
+        parser (dict): Config dict used to construct the parser that
+            parses each raw annotation into a dict.
+ repeat (int|float): Repeated times of dataset.
+ file_storage_backend (str): The storage backend type for annotation
+ file. Options are "disk", "http" and "petrel". Default: "disk".
+ file_format (str): The format of annotation file. Options are
+ "txt" and "lmdb". Default: "txt".
+ """
+
+ _backends = {
+ 'disk': HardDiskAnnFileBackend,
+ 'petrel': PetrelAnnFileBackend,
+ 'http': HTTPAnnFileBackend
+ }
+
+ def __init__(self,
+ ann_file,
+ parser,
+ repeat=1,
+ file_storage_backend='disk',
+ file_format='txt',
+ **kwargs):
+ assert isinstance(ann_file, str)
+ assert isinstance(repeat, (int, float))
+ assert isinstance(parser, dict)
+ assert repeat > 0
+ assert file_storage_backend in ['disk', 'http', 'petrel']
+ assert file_format in ['txt', 'lmdb']
+
+ if file_format == 'lmdb' and parser['type'] == 'LineStrParser':
+ raise ValueError('We only support using LineJsonParser '
+ 'to parse lmdb file. Please use LineJsonParser '
+ 'in the dataset config')
+ self.parser = build_parser(parser)
+ self.repeat = repeat
+ self.ann_file_backend = self._backends[file_storage_backend](
+ file_format, **kwargs)
+ self.ori_data_infos = self._load(ann_file)
+
+ def __len__(self):
+ return int(len(self.ori_data_infos) * self.repeat)
+
+ def _load(self, ann_file):
+ """Load annotation file."""
+
+ return self.ann_file_backend(ann_file)
+
+ def __getitem__(self, index):
+ """Retrieve anno info of one instance with dict format."""
+ return self.parser.get_item(self.ori_data_infos, index)
+
+ def __iter__(self):
+ self._n = 0
+ return self
+
+ def __next__(self):
+ if self._n < len(self):
+ data = self[self._n]
+ self._n += 1
+ return data
+ raise StopIteration
+
+ def close(self):
+ """For ann_file with lmdb format only."""
+ self.ori_data_infos.close()
+
+
+@LOADERS.register_module()
+class HardDiskLoader(AnnFileLoader):
+ """Load txt format annotation file from hard disks."""
+
+ def __init__(self, ann_file, parser, repeat=1):
+ warnings.warn(
+ 'HardDiskLoader is deprecated, please use '
+ 'AnnFileLoader instead.', UserWarning)
+ super().__init__(
+ ann_file,
+ parser,
+ repeat,
+ file_storage_backend='disk',
+ file_format='txt')
+
+
+@LOADERS.register_module()
+class LmdbLoader(AnnFileLoader):
+ """Load lmdb format annotation file from hard disks."""
+
+ def __init__(self, ann_file, parser, repeat=1):
+ warnings.warn(
+ 'LmdbLoader is deprecated, please use '
+ 'AnnFileLoader instead.', UserWarning)
+ super().__init__(
+ ann_file,
+ parser,
+ repeat,
+ file_storage_backend='disk',
+ file_format='lmdb')
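+
+
+# Usage sketch (illustrative only; the annotation file path is hypothetical and
+# assumed to contain one JSON string per line):
+#
+#     loader = AnnFileLoader(
+#         ann_file='data/instances_training.txt',
+#         parser=dict(type='LineJsonParser', keys=['filename', 'text']),
+#         repeat=1,
+#         file_storage_backend='disk',
+#         file_format='txt')
+#     info = loader[0]      # e.g. {'filename': 'img_1.jpg', 'text': 'HELLO'}
+#     len(loader)           # number of annotations scaled by `repeat`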
diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/parser.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/parser.py
new file mode 100755
index 0000000000000000000000000000000000000000..83abb3dc770c1a9428764a5eaba0e3b69901bca4
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/parser.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import warnings
+
+from dbnet.datasets.builder import PARSERS
+from dbnet.utils import StringStrip
+
+
+@PARSERS.register_module()
+class LineStrParser:
+ """Parse string of one line in annotation file to dict format.
+
+ Args:
+ keys (list[str]): Keys in result dict.
+ keys_idx (list[int]): Value index in sub-string list
+ for each key above.
+ separator (str): Separator to separate string to list of sub-string.
+ """
+
+ def __init__(self,
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ',
+ **kwargs):
+ assert isinstance(keys, list)
+ assert isinstance(keys_idx, list)
+ assert isinstance(separator, str)
+ assert len(keys) > 0
+ assert len(keys) == len(keys_idx)
+ self.keys = keys
+ self.keys_idx = keys_idx
+ self.separator = separator
+ self.strip_cls = StringStrip(**kwargs)
+
+ def get_item(self, data_ret, index):
+ map_index = index % len(data_ret)
+ line_str = data_ret[map_index]
+ line_str = self.strip_cls(line_str)
+ if len(line_str.split(' ')) > 2:
+            msg = 'More than one blank space was detected. '
+ msg += 'Please use LineJsonParser to handle '
+ msg += 'annotations with blanks. '
+ msg += 'Check Doc '
+ msg += 'https://mmocr.readthedocs.io/en/latest/'
+ msg += 'tutorials/blank_recog.html '
+ msg += 'for details.'
+ warnings.warn(msg)
+ line_str = line_str.split(self.separator)
+ if len(line_str) <= max(self.keys_idx):
+ raise Exception(
+ f'key index: {max(self.keys_idx)} out of range: {line_str}')
+
+ line_info = {}
+ for i, key in enumerate(self.keys):
+ line_info[key] = line_str[self.keys_idx[i]]
+ return line_info
+
+
+@PARSERS.register_module()
+class LineJsonParser:
+ """Parse json-string of one line in annotation file to dict format.
+
+ Args:
+ keys (list[str]): Keys in both json-string and result dict.
+ """
+
+ def __init__(self, keys=[]):
+ assert isinstance(keys, list)
+ assert len(keys) > 0
+ self.keys = keys
+
+ def get_item(self, data_ret, index):
+ map_index = index % len(data_ret)
+ json_str = data_ret[map_index]
+ line_json_obj = json.loads(json_str)
+ line_info = {}
+ for key in self.keys:
+ if key not in line_json_obj:
+ raise Exception(f'key {key} not in line json {line_json_obj}')
+ line_info[key] = line_json_obj[key]
+
+ return line_info
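+
+
+# Usage sketch (illustrative only):
+#
+#     str_parser = LineStrParser(keys=['filename', 'text'], keys_idx=[0, 1])
+#     str_parser.get_item(['img_1.jpg HELLO'], 0)
+#     # -> {'filename': 'img_1.jpg', 'text': 'HELLO'}
+#
+#     json_parser = LineJsonParser(keys=['filename', 'text'])
+#     json_parser.get_item(['{"filename": "img_1.jpg", "text": "HELLO"}'], 0)
+#     # -> {'filename': 'img_1.jpg', 'text': 'HELLO'}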
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..80a0426f38cedc3d5f7179a7e6d79ae41cc034f3
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from . import common, textdet
+from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS,
+ HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone,
+ build_convertor, build_decoder, build_detector,
+ build_encoder, build_loss, build_preprocessor)
+from .common import * # NOQA
+# from .kie import * # NOQA
+# from .ner import * # NOQA
+from .textdet import * # NOQA
+# from .textrecog import * # NOQA
+
+# __all__ = [
+# 'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone',
+# 'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS',
+# 'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder',
+# 'build_preprocessor'
+# ]
+# __all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/builder.py b/cv/ocr/dbnet/pytorch/dbnet/models/builder.py
new file mode 100755
index 0000000000000000000000000000000000000000..d5a1963674671aed82fcbec3e77c2cc7827d5057
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/builder.py
@@ -0,0 +1,152 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+from dbnet_cv.cnn import ACTIVATION_LAYERS as DBNET_CV_ACTIVATION_LAYERS
+from dbnet_cv.cnn import UPSAMPLE_LAYERS as DBNET_CV_UPSAMPLE_LAYERS
+from dbnet_cv.utils import Registry, build_from_cfg
+from dbnet_det.models.builder import BACKBONES as dbnet_det_BACKBONES
+
+CONVERTORS = Registry('convertor')
+ENCODERS = Registry('encoder')
+DECODERS = Registry('decoder')
+PREPROCESSOR = Registry('preprocessor')
+POSTPROCESSOR = Registry('postprocessor')
+
+UPSAMPLE_LAYERS = Registry('upsample layer', parent=DBNET_CV_UPSAMPLE_LAYERS)
+BACKBONES = Registry('models', parent=dbnet_det_BACKBONES)
+LOSSES = BACKBONES
+DETECTORS = BACKBONES
+ROI_EXTRACTORS = BACKBONES
+HEADS = BACKBONES
+NECKS = BACKBONES
+FUSERS = BACKBONES
+RECOGNIZERS = BACKBONES
+
+ACTIVATION_LAYERS = Registry('activation layer', parent=DBNET_CV_ACTIVATION_LAYERS)
+
+
+def build_recognizer(cfg, train_cfg=None, test_cfg=None):
+ """Build recognizer."""
+ return build_from_cfg(cfg, RECOGNIZERS,
+ dict(train_cfg=train_cfg, test_cfg=test_cfg))
+
+
+def build_convertor(cfg):
+ """Build label convertor for scene text recognizer."""
+ return build_from_cfg(cfg, CONVERTORS)
+
+
+def build_encoder(cfg):
+ """Build encoder for scene text recognizer."""
+ return build_from_cfg(cfg, ENCODERS)
+
+
+def build_decoder(cfg):
+ """Build decoder for scene text recognizer."""
+ return build_from_cfg(cfg, DECODERS)
+
+
+def build_preprocessor(cfg):
+ """Build preprocessor for scene text recognizer."""
+ return build_from_cfg(cfg, PREPROCESSOR)
+
+
+def build_postprocessor(cfg):
+ """Build postprocessor for scene text detector."""
+ return build_from_cfg(cfg, POSTPROCESSOR)
+
+
+def build_roi_extractor(cfg):
+ """Build roi extractor."""
+ return ROI_EXTRACTORS.build(cfg)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+
+
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+
+
+def build_fuser(cfg):
+ """Build fuser."""
+ return FUSERS.build(cfg)
+
+
+def build_upsample_layer(cfg, *args, **kwargs):
+ """Build upsample layer.
+
+ Args:
+ cfg (dict): The upsample layer config, which should contain:
+
+ - type (str): Layer type.
+ - scale_factor (int): Upsample ratio, which is not applicable to
+ deconv.
+            - layer args: Args needed to instantiate an upsample layer.
+        args (argument list): Arguments passed to the ``__init__``
+            method of the corresponding upsample layer.
+        kwargs (keyword arguments): Keyword arguments passed to the
+            ``__init__`` method of the corresponding upsample layer.
+
+ Returns:
+ nn.Module: Created upsample layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ raise KeyError(
+ f'the cfg dict must contain the key "type", but got {cfg}')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in UPSAMPLE_LAYERS:
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
+ else:
+ upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+ if upsample is nn.Upsample:
+ cfg_['mode'] = layer_type
+ layer = upsample(*args, **kwargs, **cfg_)
+ return layer
+
+
+def build_activation_layer(cfg):
+ """Build activation layer.
+
+ Args:
+ cfg (dict): The activation layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an activation layer.
+
+ Returns:
+ nn.Module: Created activation layer.
+ """
+ return build_from_cfg(cfg, ACTIVATION_LAYERS)
+
+
+def build_detector(cfg, train_cfg=None, test_cfg=None):
+ """Build detector."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn(
+ 'train_cfg and test_cfg is deprecated, '
+ 'please specify them in model', UserWarning)
+ assert cfg.get('train_cfg') is None or train_cfg is None, \
+ 'train_cfg specified in both outer field and model field '
+ assert cfg.get('test_cfg') is None or test_cfg is None, \
+ 'test_cfg specified in both outer field and model field '
+ return DETECTORS.build(
+ cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
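+
+
+# Usage sketch (illustrative only): components registered above are built from
+# config dicts, for example
+#
+#     backbone = build_backbone(dict(type='UNet', base_channels=16))
+#     loss = build_loss(dict(type='DiceLoss', eps=1e-6))
+#
+# build_detector() still accepts train_cfg/test_cfg for backward compatibility,
+# but they are expected to be specified inside the model config.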
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..066ee87676b6c35fa132218bdcc513ce86c90d61
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from . import backbones, losses
+from .backbones import * # NOQA
+from .losses import * # NOQA
+
+__all__ = backbones.__all__ + losses.__all__
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..3c384ba3010dd3fc81b562f7101c63ecaef1e0a6
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .unet import UNet
+
+__all__ = ['UNet']
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/unet.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/unet.py
new file mode 100755
index 0000000000000000000000000000000000000000..2a6d0f8cd8154dda5ecc0a6609d4797ec7d1ce01
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/unet.py
@@ -0,0 +1,516 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from dbnet_cv.cnn import ConvModule, build_norm_layer
+from dbnet_cv.runner import BaseModule
+from dbnet_cv.utils.parrots_wrapper import _BatchNorm
+
+from dbnet.models.builder import (BACKBONES, UPSAMPLE_LAYERS,
+ build_activation_layer, build_upsample_layer)
+
+
+class UpConvBlock(nn.Module):
+ """Upsample convolution block in decoder for UNet.
+
+ This upsample convolution block consists of one upsample module
+ followed by one convolution block. The upsample module expands the
+ high-level low-resolution feature map and the convolution block fuses
+ the upsampled high-level low-resolution feature map and the low-level
+ high-resolution feature map from encoder.
+
+ Args:
+ conv_block (nn.Sequential): Sequential of convolutional layers.
+        in_channels (int): Number of input channels of the high-level
+            low-resolution feature map for the upsample block.
+ skip_channels (int): Number of input channels of the low-level
+ high-resolution feature map from encoder.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers in the conv_block.
+ Default: 2.
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
+ dilation (int): Dilation rate of convolutional layer in conv_block.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv'). If the size of
+ high-level feature map is the same as that of skip feature map
+            (low-level feature map from encoder), it does not need to
+            upsample the high-level feature map and the upsample_cfg is None.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ conv_block,
+ in_channels,
+ skip_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ dcn=None,
+ plugins=None):
+ super().__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.conv_block = conv_block(
+ in_channels=2 * skip_channels,
+ out_channels=out_channels,
+ num_convs=num_convs,
+ stride=stride,
+ dilation=dilation,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None)
+ if upsample_cfg is not None:
+ self.upsample = build_upsample_layer(
+ cfg=upsample_cfg,
+ in_channels=in_channels,
+ out_channels=skip_channels,
+ with_cp=with_cp,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.upsample = ConvModule(
+ in_channels,
+ skip_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, skip, x):
+ """Forward function."""
+
+ x = self.upsample(x)
+ out = torch.cat([skip, x], dim=1)
+ out = self.conv_block(out)
+
+ return out
+
+
+class BasicConvBlock(nn.Module):
+ """Basic convolutional block for UNet.
+
+ This module consists of several plain convolutional layers.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers. Default: 2.
+        stride (int): Whether to use stride convolution to downsample
+ the input feature map. If stride=2, it only uses stride convolution
+ in the first convolutional layer to downsample the input feature
+ map. Options are 1 or 2. Default: 1.
+        dilation (int): Whether to use dilated convolution to expand the
+ receptive field. Set dilation rate of each convolutional layer and
+ the dilation rate of the first convolutional layer is always 1.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ dcn=None,
+ plugins=None):
+ super().__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.with_cp = with_cp
+ convs = []
+ for i in range(num_convs):
+ convs.append(
+ ConvModule(
+ in_channels=in_channels if i == 0 else out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride if i == 0 else 1,
+ dilation=1 if i == 0 else dilation,
+ padding=1 if i == 0 else dilation,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ self.convs = nn.Sequential(*convs)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.convs, x)
+ else:
+ out = self.convs(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+ """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+ This module uses deconvolution to upsample feature map in the decoder
+ of UNet.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ kernel_size (int): Kernel size of the convolutional layer. Default: 4.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ kernel_size=4,
+ scale_factor=2):
+ super().__init__()
+
+ assert (
+ kernel_size - scale_factor >= 0
+ and (kernel_size - scale_factor) % 2 == 0), (
+ f'kernel_size should be greater than or equal to scale_factor '
+ f'and (kernel_size - scale_factor) should be even numbers, '
+ f'while the kernel size is {kernel_size} and scale_factor is '
+ f'{scale_factor}.')
+
+ stride = scale_factor
+ padding = (kernel_size - scale_factor) // 2
+ self.with_cp = with_cp
+ deconv = nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+
+ _, norm = build_norm_layer(norm_cfg, out_channels)
+ activate = build_activation_layer(act_cfg)
+ self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.deconv_upsamping, x)
+ else:
+ out = self.deconv_upsamping(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+ """Interpolation upsample module in decoder for UNet.
+
+ This module uses interpolation to upsample feature map in the decoder
+ of UNet. It consists of one interpolation upsample layer and one
+ convolutional layer. It can be one interpolation upsample layer followed
+ by one convolutional layer (conv_first=False) or one convolutional layer
+ followed by one interpolation upsample layer (conv_first=True).
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ conv_first (bool): Whether convolutional layer or interpolation
+ upsample layer first. Default: False. It means interpolation
+ upsample layer followed by one convolutional layer.
+ kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+ stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+ upsample_cfg (dict): Interpolation config of the upsample layer.
+ Default: dict(
+ scale_factor=2, mode='bilinear', align_corners=False).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ conv_cfg=None,
+ conv_first=False,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ upsample_cfg=dict(
+ scale_factor=2, mode='bilinear', align_corners=False)):
+ super().__init__()
+
+ self.with_cp = with_cp
+ conv = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ upsample = nn.Upsample(**upsample_cfg)
+ if conv_first:
+ self.interp_upsample = nn.Sequential(conv, upsample)
+ else:
+ self.interp_upsample = nn.Sequential(upsample, conv)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.interp_upsample, x)
+ else:
+ out = self.interp_upsample(x)
+ return out
+
+
+@BACKBONES.register_module()
+class UNet(BaseModule):
+ """UNet backbone.
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
+ https://arxiv.org/pdf/1505.04597.pdf
+
+ Args:
+        in_channels (int): Number of input image channels. Default: 3.
+ base_channels (int): Number of base channels of each stage.
+ The output channels of the first stage. Default: 64.
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+ len(strides) is equal to num_stages. Normally the stride of the
+ first stage in encoder is 1. If strides[i]=2, it uses stride
+ convolution to downsample in the correspondence encoder stage.
+ Default: (1, 1, 1, 1, 1).
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence encoder stage.
+ Default: (2, 2, 2, 2, 2).
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence decoder stage.
+ Default: (2, 2, 2, 2).
+ downsamples (Sequence[int]): Whether use MaxPool to downsample the
+ feature map after the first stage of encoder
+ (stages: [1, num_stages)). If the correspondence encoder stage use
+ stride convolution (strides[i]=2), it will never use MaxPool to
+ downsample, even downsamples[i-1]=True.
+ Default: (True, True, True, True).
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+ Default: (1, 1, 1, 1, 1).
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+ Default: (1, 1, 1, 1).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+
+ Notice:
+ The input image size should be divisible by the whole downsample rate
+ of the encoder. More detail of the whole downsample rate can be found
+ in UNet._check_input_divisible.
+
+ """
+
+ def __init__(self,
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False,
+ dcn=None,
+ plugins=None,
+ init_cfg=[
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ layer=['_BatchNorm', 'GroupNorm'],
+ val=1)
+ ]):
+ super().__init__(init_cfg=init_cfg)
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert len(strides) == num_stages, (
+ 'The length of strides should be equal to num_stages, '
+ f'while the strides is {strides}, the length of '
+ f'strides is {len(strides)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(enc_num_convs) == num_stages, (
+ 'The length of enc_num_convs should be equal to num_stages, '
+ f'while the enc_num_convs is {enc_num_convs}, the length of '
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(dec_num_convs) == (num_stages - 1), (
+ 'The length of dec_num_convs should be equal to (num_stages-1), '
+ f'while the dec_num_convs is {dec_num_convs}, the length of '
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(downsamples) == (num_stages - 1), (
+ 'The length of downsamples should be equal to (num_stages-1), '
+ f'while the downsamples is {downsamples}, the length of '
+ f'downsamples is {len(downsamples)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(enc_dilations) == num_stages, (
+ 'The length of enc_dilations should be equal to num_stages, '
+ f'while the enc_dilations is {enc_dilations}, the length of '
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(dec_dilations) == (num_stages - 1), (
+ 'The length of dec_dilations should be equal to (num_stages-1), '
+ f'while the dec_dilations is {dec_dilations}, the length of '
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '
+ f'{num_stages}.')
+ self.num_stages = num_stages
+ self.strides = strides
+ self.downsamples = downsamples
+ self.norm_eval = norm_eval
+ self.base_channels = base_channels
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for i in range(num_stages):
+ enc_conv_block = []
+ if i != 0:
+ if strides[i] == 1 and downsamples[i - 1]:
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+ upsample = (strides[i] != 1 or downsamples[i - 1])
+ self.decoder.append(
+ UpConvBlock(
+ conv_block=BasicConvBlock,
+ in_channels=base_channels * 2**i,
+ skip_channels=base_channels * 2**(i - 1),
+ out_channels=base_channels * 2**(i - 1),
+ num_convs=dec_num_convs[i - 1],
+ stride=1,
+ dilation=dec_dilations[i - 1],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ upsample_cfg=upsample_cfg if upsample else None,
+ dcn=None,
+ plugins=None))
+
+ enc_conv_block.append(
+ BasicConvBlock(
+ in_channels=in_channels,
+ out_channels=base_channels * 2**i,
+ num_convs=enc_num_convs[i],
+ stride=strides[i],
+ dilation=enc_dilations[i],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None))
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
+ in_channels = base_channels * 2**i
+
+ def forward(self, x):
+ self._check_input_divisible(x)
+ enc_outs = []
+ for enc in self.encoder:
+ x = enc(x)
+ enc_outs.append(x)
+ dec_outs = [x]
+ for i in reversed(range(len(self.decoder))):
+ x = self.decoder[i](enc_outs[i], x)
+ dec_outs.append(x)
+
+ return dec_outs
+
+ def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+ super().train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def _check_input_divisible(self, x):
+ h, w = x.shape[-2:]
+ whole_downsample_rate = 1
+ for i in range(1, self.num_stages):
+ if self.strides[i] == 2 or self.downsamples[i - 1]:
+ whole_downsample_rate *= 2
+ assert (
+ h % whole_downsample_rate == 0 and w % whole_downsample_rate == 0
+ ), (f'The input image size {(h, w)} should be divisible by the whole '
+ f'downsample rate {whole_downsample_rate}, when num_stages is '
+ f'{self.num_stages}, strides is {self.strides}, and downsamples '
+ f'is {self.downsamples}.')
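+
+
+# Shape sketch (illustrative only): with the default five stages and all
+# downsamples enabled, the whole downsample rate is 16, so an input divisible
+# by 16 such as (1, 3, 64, 64) yields five decoder outputs whose resolution
+# grows back to the input size:
+#
+#     net = UNet(in_channels=3, base_channels=4)
+#     outs = net(torch.randn(1, 3, 64, 64))
+#     # [tuple(x.shape) for x in outs] ->
+#     # [(1, 64, 4, 4), (1, 32, 8, 8), (1, 16, 16, 16),
+#     #  (1, 8, 32, 32), (1, 4, 64, 64)]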
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..609824a1b0e67b0110b5b101151243bcd0e338ec
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .single_stage import SingleStageDetector
+
+__all__ = ['SingleStageDetector']
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/single_stage.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/single_stage.py
new file mode 100755
index 0000000000000000000000000000000000000000..dad1d98bcff9825d9e2205c0c72e06e7c5c14e77
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/single_stage.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from dbnet_det.models.detectors import \
+ SingleStageDetector as dbnet_det_SingleStageDetector
+
+from dbnet.models.builder import (DETECTORS, build_backbone, build_head,
+ build_neck)
+
+
+@DETECTORS.register_module()
+class SingleStageDetector(dbnet_det_SingleStageDetector):
+ """Base class for single-stage detectors.
+
+ Single-stage detectors directly and densely predict bounding boxes on the
+ output features of the backbone+neck.
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(dbnet_det_SingleStageDetector, self).__init__(init_cfg=init_cfg)
+ if pretrained:
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ backbone.pretrained = pretrained
+ self.backbone = build_backbone(backbone)
+ if neck is not None:
+ self.neck = build_neck(neck)
+ bbox_head.update(train_cfg=train_cfg)
+ bbox_head.update(test_cfg=test_cfg)
+ self.bbox_head = build_head(bbox_head)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..b34608a877888189ec539f3e5ddebadb77c203cf
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dice_loss import DiceLoss
+# from .focal_loss import FocalLoss
+
+# __all__ = ['DiceLoss', 'FocalLoss']
+__all__ = ['DiceLoss']
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/dice_loss.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/dice_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..d2ace254f659bd3940f0b147738da01ced2dee51
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/dice_loss.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from dbnet.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ assert isinstance(eps, float)
+ self.eps = eps
+
+ def forward(self, pred, target, mask=None):
+
+ pred = pred.contiguous().view(pred.size()[0], -1)
+ target = target.contiguous().view(target.size()[0], -1)
+
+ if mask is not None:
+ mask = mask.contiguous().view(mask.size()[0], -1)
+ pred = pred * mask
+ target = target * mask
+
+ a = torch.sum(pred * target)
+ b = torch.sum(pred)
+ c = torch.sum(target)
+ d = (2 * a) / (b + c + self.eps)
+
+ return 1 - d
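+
+
+# Usage sketch (illustrative only): the loss equals
+# 1 - 2 * sum(pred * target) / (sum(pred) + sum(target) + eps),
+# computed over the whole (optionally masked) batch:
+#
+#     criterion = DiceLoss(eps=1e-6)
+#     pred = torch.rand(2, 1, 8, 8)            # probabilities in [0, 1]
+#     target = (torch.rand(2, 1, 8, 8) > 0.5).float()
+#     loss = criterion(pred, target, mask=torch.ones_like(target))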
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/focal_loss.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/focal_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..1a42ab013e278832fe6c8eed20f4a4c879f4d8cf
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/focal_loss.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FocalLoss(nn.Module):
+ """Multi-class Focal loss implementation.
+
+ Args:
+ gamma (float): The larger the gamma, the smaller
+ the loss weight of easier samples.
+ weight (float): A manual rescaling weight given to each
+ class.
+ ignore_index (int): Specifies a target value that is ignored
+ and does not contribute to the input gradient.
+ """
+
+ def __init__(self, gamma=2, weight=None, ignore_index=-100):
+ super().__init__()
+ self.gamma = gamma
+ self.weight = weight
+ self.ignore_index = ignore_index
+
+ def forward(self, input, target):
+ logit = F.log_softmax(input, dim=1)
+ pt = torch.exp(logit)
+ logit = (1 - pt)**self.gamma * logit
+ loss = F.nll_loss(
+ logit, target, self.weight, ignore_index=self.ignore_index)
+ return loss
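+
+
+# Usage sketch (illustrative only): `input` holds raw class scores of shape
+# (N, C) and `target` holds integer class indices of shape (N,):
+#
+#     criterion = FocalLoss(gamma=2)
+#     scores = torch.randn(4, 10)
+#     labels = torch.randint(0, 10, (4,))
+#     loss = criterion(scores, labels)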
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..027e812f790d9572ec5d83b78ee9ce33a5ed415a
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from . import dense_heads, detectors, losses, necks, postprocess
+from .dense_heads import * # NOQA
+from .detectors import * # NOQA
+from .losses import * # NOQA
+from .necks import * # NOQA
+from .postprocess import * # NOQA
+
+__all__ = (
+ dense_heads.__all__ + detectors.__all__ + losses.__all__ + necks.__all__ +
+ postprocess.__all__)
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..fd782a503190bb7e437ca6220c761ec9464626ac
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .db_head import DBHead
+# from .drrg_head import DRRGHead
+# from .fce_head import FCEHead
+from .head_mixin import HeadMixin
+# from .pan_head import PANHead
+# from .pse_head import PSEHead
+# from .textsnake_head import TextSnakeHead
+
+
+__all__ = ['DBHead', 'HeadMixin']
+
+# __all__ = [
+# 'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'TextSnakeHead', 'DRRGHead',
+# 'HeadMixin'
+# ]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/db_head.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/db_head.py
new file mode 100755
index 0000000000000000000000000000000000000000..dc6a6899506e31c32228c9343ee49059a33201c6
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/db_head.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from dbnet_cv.runner import BaseModule, Sequential
+
+from dbnet.models.builder import HEADS
+from .head_mixin import HeadMixin
+
+
+@HEADS.register_module()
+class DBHead(HeadMixin, BaseModule):
+ """The class for DBNet head.
+
+ This was partially adapted from https://github.com/MhLiao/DB
+
+ Args:
+ in_channels (int): The number of input channels of the db head.
+ with_bias (bool): Whether add bias in Conv2d layer.
+ downsample_ratio (float): The downsample ratio of ground truths.
+ loss (dict): Config of loss for dbnet.
+ postprocessor (dict): Config of postprocessor for dbnet.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ with_bias=False,
+ downsample_ratio=1.0,
+ loss=dict(type='DBLoss'),
+ postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'),
+ init_cfg=[
+ dict(type='Kaiming', layer='Conv'),
+ dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
+ ],
+ train_cfg=None,
+ test_cfg=None,
+ **kwargs):
+ old_keys = ['text_repr_type', 'decoding_type']
+ for key in old_keys:
+ if kwargs.get(key, None):
+ postprocessor[key] = kwargs.get(key)
+ warnings.warn(
+ f'{key} is deprecated, please specify '
+ 'it in postprocessor config dict. See '
+ 'https://github.com/open-mmlab/mmocr/pull/640'
+ ' for details.', UserWarning)
+ BaseModule.__init__(self, init_cfg=init_cfg)
+ HeadMixin.__init__(self, loss, postprocessor)
+
+ assert isinstance(in_channels, int)
+
+ self.in_channels = in_channels
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.downsample_ratio = downsample_ratio
+
+ self.binarize = Sequential(
+ nn.Conv2d(
+ in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
+ nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
+ nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
+
+ self.threshold = self._init_thr(in_channels)
+
+ def diff_binarize(self, prob_map, thr_map, k):
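+        # Approximate (differentiable) binarization from DB:
+        # B = 1 / (1 + exp(-k * (P - T))), i.e. a steep sigmoid of (P - T).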
+ return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs (Tensor): Shape (batch_size, hidden_size, h, w).
+
+ Returns:
+            Tensor: A tensor of shape (batch_size, 3, 4h, 4w), whose three
+                channels are the probability map, the threshold map and the
+                approximate binary map, respectively.
+ """
+ prob_map = self.binarize(inputs)
+ thr_map = self.threshold(inputs)
+ binary_map = self.diff_binarize(prob_map, thr_map, k=50)
+ outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
+ return outputs
+
+ def _init_thr(self, inner_channels, bias=False):
+ in_channels = inner_channels
+
+ seq = Sequential(
+ nn.Conv2d(
+ in_channels, inner_channels // 4, 3, padding=1, bias=bias),
+ nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
+ nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
+ return seq
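+
+
+# Output layout sketch (illustrative only, assuming the default DBLoss and
+# DBPostprocessor configs are available): for an input feature map of shape
+# (N, C, H, W) the head returns a (N, 3, 4H, 4W) tensor whose channels are the
+# probability, threshold and approximate binary maps:
+#
+#     head = DBHead(in_channels=256)
+#     out = head(torch.randn(1, 256, 32, 32))   # -> (1, 3, 128, 128)
+#     prob_map, thr_map, binary_map = out[:, 0], out[:, 1], out[:, 2]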
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/head_mixin.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/head_mixin.py
new file mode 100755
index 0000000000000000000000000000000000000000..a25cce4ac064e31b1ba109fc6880abff86e97bd0
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/head_mixin.py
@@ -0,0 +1,91 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+from dbnet.models.builder import HEADS, build_loss, build_postprocessor
+from dbnet.utils import check_argument
+
+
+@HEADS.register_module()
+class HeadMixin:
+    """Base head class for text detection, including loss calculation and
+    postprocessing.
+
+ Args:
+ loss (dict): Config to build loss.
+ postprocessor (dict): Config to build postprocessor.
+ """
+
+ def __init__(self, loss, postprocessor):
+ assert isinstance(loss, dict)
+ assert isinstance(postprocessor, dict)
+
+ self.loss_module = build_loss(loss)
+ self.postprocessor = build_postprocessor(postprocessor)
+
+ def resize_boundary(self, boundaries, scale_factor):
+ """Rescale boundaries via scale_factor.
+
+ Args:
+ boundaries (list[list[float]]): The boundary list. Each boundary
+ has :math:`2k+1` elements with :math:`k>=4`.
+ scale_factor (ndarray): The scale factor of size :math:`(4,)`.
+
+ Returns:
+ list[list[float]]: The scaled boundaries.
+ """
+ assert check_argument.is_2dlist(boundaries)
+ assert isinstance(scale_factor, np.ndarray)
+ assert scale_factor.shape[0] == 4
+
+ for b in boundaries:
+ sz = len(b)
+ check_argument.valid_boundary(b, True)
+ b[:sz -
+ 1] = (np.array(b[:sz - 1]) *
+ (np.tile(scale_factor[:2], int(
+ (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
+ return boundaries
+
+ def get_boundary(self, score_maps, img_metas, rescale):
+ """Compute text boundaries via post processing.
+
+ Args:
+ score_maps (Tensor): The text score map.
+ img_metas (dict): The image meta info.
+ rescale (bool): Rescale boundaries to the original image resolution
+ if true, and keep the score_maps resolution if false.
+
+ Returns:
+ dict: A dict where boundary results are stored in
+ ``boundary_result``.
+ """
+
+ assert check_argument.is_type_list(img_metas, dict)
+ assert isinstance(rescale, bool)
+
+ score_maps = score_maps.squeeze()
+ boundaries = self.postprocessor(score_maps)
+
+ if rescale:
+ boundaries = self.resize_boundary(
+ boundaries,
+ 1.0 / self.downsample_ratio / img_metas[0]['scale_factor'])
+
+ results = dict(
+ boundary_result=boundaries, filename=img_metas[0]['filename'])
+
+ return results
+
+ def loss(self, pred_maps, **kwargs):
+ """Compute the loss for scene text detection.
+
+ Args:
+ pred_maps (Tensor): The input score maps of shape
+ :math:`(NxCxHxW)`.
+
+ Returns:
+ dict: The dict for losses.
+ """
+ losses = self.loss_module(pred_maps, self.downsample_ratio, **kwargs)
+
+ return losses
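+
+
+# Boundary format sketch (illustrative only): each boundary is a flat list of
+# 2k + 1 floats, i.e. k (x, y) vertices followed by a confidence score, e.g.
+#     [x1, y1, x2, y2, x3, y3, x4, y4, score]
+# resize_boundary() rescales only the coordinates and leaves the score intact.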
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..89f0f98e370dbbe16da25450be39040863ef99b6
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dbnet import DBNet
+# from .drrg import DRRG
+# from .fcenet import FCENet
+# from .ocr_mask_rcnn import OCRMaskRCNN
+# from .panet import PANet
+# from .psenet import PSENet
+from .single_stage_text_detector import SingleStageTextDetector
+from .text_detector_mixin import TextDetectorMixin
+# from .textsnake import TextSnake
+
+
+# __all__ = [
+# 'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
+# 'PANet', 'PSENet', 'TextSnake', 'FCENet', 'DRRG'
+# ]
+__all__ = [
+ 'TextDetectorMixin', 'SingleStageTextDetector', 'DBNet'
+ ]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/dbnet.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/dbnet.py
new file mode 100755
index 0000000000000000000000000000000000000000..c598252e21c357fa00124b4688922ad91e521b64
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/dbnet.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dbnet.models.builder import DETECTORS
+from .single_stage_text_detector import SingleStageTextDetector
+from .text_detector_mixin import TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class DBNet(TextDetectorMixin, SingleStageTextDetector):
+ """The class for implementing DBNet text detector: Real-time Scene Text
+ Detection with Differentiable Binarization.
+
+ [https://arxiv.org/abs/1911.08947].
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ show_score=False,
+ init_cfg=None):
+ SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
+ train_cfg, test_cfg, pretrained,
+ init_cfg)
+ TextDetectorMixin.__init__(self, show_score)
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/single_stage_text_detector.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/single_stage_text_detector.py
new file mode 100755
index 0000000000000000000000000000000000000000..8b2b28a8f0eb8c84ff2266af2fd14dd0675d6bf2
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/single_stage_text_detector.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from dbnet.models.builder import DETECTORS
+from dbnet.models.common.detectors import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class SingleStageTextDetector(SingleStageDetector):
+ """The class for implementing single stage text detector."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ SingleStageDetector.__init__(self, backbone, neck, bbox_head,
+ train_cfg, test_cfg, pretrained, init_cfg)
+
+ def forward_train(self, img, img_metas, **kwargs):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys, see
+ :class:`dbnet_det.datasets.pipelines.Collect`.
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ x = self.extract_feat(img)
+ preds = self.bbox_head(x)
+ losses = self.bbox_head.loss(preds, **kwargs)
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x)
+
+ # early return to avoid post processing
+ if torch.onnx.is_in_onnx_export():
+ return outs
+
+ if len(img_metas) > 1:
+ boundaries = [
+ self.bbox_head.get_boundary(*(outs[i].unsqueeze(0)),
+ [img_metas[i]], rescale)
+ for i in range(len(img_metas))
+ ]
+
+ else:
+ boundaries = [
+ self.bbox_head.get_boundary(*outs, img_metas, rescale)
+ ]
+
+ return boundaries
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/text_detector_mixin.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/text_detector_mixin.py
new file mode 100755
index 0000000000000000000000000000000000000000..376517a8df4da1c2e567e8dfd4cd76644d9a7208
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/text_detector_mixin.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import dbnet_cv
+
+from dbnet.core import imshow_pred_boundary
+
+
+class TextDetectorMixin:
+ """Base class for text detector, only to show results.
+
+ Args:
+ show_score (bool): Whether to show text instance score.
+ """
+
+ def __init__(self, show_score):
+ self.show_score = show_score
+
+ def show_result(self,
+ img,
+ result,
+ score_thr=0.5,
+ bbox_color='green',
+ text_color='green',
+ thickness=1,
+ font_scale=0.5,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (dict): The results to draw over `img`.
+ score_thr (float, optional): Minimum score of bboxes to be shown.
+                Default: 0.5.
+ bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+                Default: None.
+ """
+ img = dbnet_cv.imread(img)
+ img = img.copy()
+ boundaries = None
+ labels = None
+ if 'boundary_result' in result.keys():
+ boundaries = result['boundary_result']
+ labels = [0] * len(boundaries)
+
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw bounding boxes
+ if boundaries is not None:
+ imshow_pred_boundary(
+ img,
+ boundaries,
+ labels,
+ score_thr=score_thr,
+ boundary_color=bbox_color,
+ text_color=text_color,
+ thickness=thickness,
+ font_scale=font_scale,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file,
+ show_score=self.show_score)
+
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, '
+ 'result image will be returned')
+ return img
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..14e6b5d1d5d967e2bb429d21578cf7ac02a1ba52
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .db_loss import DBLoss
+# from .drrg_loss import DRRGLoss
+# from .fce_loss import FCELoss
+# from .pan_loss import PANLoss
+# from .pse_loss import PSELoss
+# from .textsnake_loss import TextSnakeLoss
+
+# __all__ = [
+# 'PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss', 'DRRGLoss'
+# ]
+__all__ = ['DBLoss']
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/db_loss.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/db_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..b9c32ef579ae030f568a964fcb630eebeb1511cb
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/db_loss.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from dbnet.models.builder import LOSSES
+from dbnet.models.common.losses.dice_loss import DiceLoss
+
+
+@LOSSES.register_module()
+class DBLoss(nn.Module):
+ """The class for implementing DBNet loss.
+
+ This is partially adapted from https://github.com/MhLiao/DB.
+
+ Args:
+ alpha (float): The binary loss coef.
+ beta (float): The threshold loss coef.
+ reduction (str): The way to reduce the loss.
+        negative_ratio (float): The maximum ratio of negative samples to
+            positive samples kept by the balanced bce loss.
+ eps (float): Epsilon in the threshold loss function.
+ bbce_loss (bool): Whether to use balanced bce for probability loss.
+ If False, dice loss will be used instead.
+ """
+
+ def __init__(self,
+ alpha=1,
+ beta=1,
+ reduction='mean',
+ negative_ratio=3.0,
+ eps=1e-6,
+ bbce_loss=False):
+ super().__init__()
+        assert reduction in ['mean', 'sum'], \
+            "reduction must be in ['mean', 'sum']"
+ self.alpha = alpha
+ self.beta = beta
+ self.reduction = reduction
+ self.negative_ratio = negative_ratio
+ self.eps = eps
+ self.bbce_loss = bbce_loss
+ self.dice_loss = DiceLoss(eps=eps)
+
+ def bitmasks2tensor(self, bitmasks, target_sz):
+ """Convert Bitmasks to tensor.
+
+ Args:
+ bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
+ for one img.
+ target_sz (tuple(int, int)): The target tensor size
+ :math:`(H, W)`.
+
+ Returns:
+ list[Tensor]: The list of kernel tensors. Each element stands for
+ one kernel level.
+ """
+ assert isinstance(bitmasks, list)
+ assert isinstance(target_sz, tuple)
+
+ batch_size = len(bitmasks)
+ num_levels = len(bitmasks[0])
+
+ result_tensors = []
+
+ for level_inx in range(num_levels):
+ kernel = []
+ for batch_inx in range(batch_size):
+ mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
+ mask_sz = mask.shape
+ pad = [
+ 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
+ ]
+ mask = F.pad(mask, pad, mode='constant', value=0)
+ kernel.append(mask)
+ kernel = torch.stack(kernel)
+ result_tensors.append(kernel)
+
+ return result_tensors
+
+ def balance_bce_loss(self, pred, gt, mask):
+
+ positive = (gt * mask)
+ negative = ((1 - gt) * mask)
+ positive_count = int(positive.float().sum())
+ negative_count = min(
+ int(negative.float().sum()),
+ int(positive_count * self.negative_ratio))
+
+ assert gt.max() <= 1 and gt.min() >= 0
+ assert pred.max() <= 1 and pred.min() >= 0
+
+ # ``pred`` is expected to be a probability map (see the asserts above),
+ # so plain binary cross entropy is used rather than the logits variant.
+ loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ positive_loss = loss * positive.float()
+ negative_loss = loss * negative.float()
+
+ negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
+
+ balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
+ positive_count + negative_count + self.eps)
+
+ return balance_loss
+
+ def l1_thr_loss(self, pred, gt, mask):
+ thr_loss = torch.abs((pred - gt) * mask).sum() / (
+ mask.sum() + self.eps)
+ return thr_loss
+
+ def forward(self, preds, downsample_ratio, gt_shrink, gt_shrink_mask,
+ gt_thr, gt_thr_mask):
+ """Compute DBNet loss.
+
+ Args:
+ preds (Tensor): The output tensor with size :math:`(N, 3, H, W)`.
+ downsample_ratio (float): The downsample ratio for the
+ ground truths.
+ gt_shrink (list[BitmapMasks]): The mask list with each element
+ being the shrunk text mask for one img.
+ gt_shrink_mask (list[BitmapMasks]): The effective mask list with
+ each element being the shrunk effective mask for one img.
+ gt_thr (list[BitmapMasks]): The mask list with each element
+ being the threshold text mask for one img.
+ gt_thr_mask (list[BitmapMasks]): The effective mask list with
+ each element being the threshold effective mask for one img.
+
+ Returns:
+ dict: The dict for dbnet losses with "loss_prob", "loss_db" and
+ "loss_thr".
+ """
+ assert isinstance(downsample_ratio, float)
+
+ assert isinstance(gt_shrink, list)
+ assert isinstance(gt_shrink_mask, list)
+ assert isinstance(gt_thr, list)
+ assert isinstance(gt_thr_mask, list)
+
+ pred_prob = preds[:, 0, :, :]
+ pred_thr = preds[:, 1, :, :]
+ pred_db = preds[:, 2, :, :]
+ feature_sz = preds.size()
+
+ # gather the ground truths and align them with the prediction size
+ gt = dict(
+ gt_shrink=gt_shrink,
+ gt_shrink_mask=gt_shrink_mask,
+ gt_thr=gt_thr,
+ gt_thr_mask=gt_thr_mask)
+ for k in gt:
+ gt[k] = [item.rescale(downsample_ratio) for item in gt[k]]
+ gt[k] = self.bitmasks2tensor(gt[k], feature_sz[2:])
+ gt[k] = [item.to(preds.device) for item in gt[k]]
+ gt['gt_shrink'][0] = (gt['gt_shrink'][0] > 0).float()
+ if self.bbce_loss:
+ loss_prob = self.balance_bce_loss(pred_prob, gt['gt_shrink'][0],
+ gt['gt_shrink_mask'][0])
+ else:
+ loss_prob = self.dice_loss(pred_prob, gt['gt_shrink'][0],
+ gt['gt_shrink_mask'][0])
+
+ loss_db = self.dice_loss(pred_db, gt['gt_shrink'][0],
+ gt['gt_shrink_mask'][0])
+ loss_thr = self.l1_thr_loss(pred_thr, gt['gt_thr'][0],
+ gt['gt_thr_mask'][0])
+
+ results = dict(
+ loss_prob=self.alpha * loss_prob,
+ loss_db=loss_db,
+ loss_thr=self.beta * loss_thr)
+
+ return results
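+
+
+if __name__ == '__main__':
+    # Minimal sketch (illustrative only, not part of the training pipeline):
+    # exercise the balanced BCE term with random probability maps. The full
+    # ``forward`` additionally expects BitmapMasks ground truths, and the
+    # ``dbnet`` package and its dependencies are assumed to be importable.
+    criterion = DBLoss(bbce_loss=True)
+    pred = torch.rand(2, 64, 64)                # probability map in [0, 1]
+    gt = (torch.rand(2, 64, 64) > 0.5).float()  # binary shrink map
+    mask = torch.ones_like(gt)                  # effective-region mask
+    print(criterion.balance_bce_loss(pred, gt, mask))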
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..f0be483e7d6dfc4a6de309a40e8cf58cdb629a35
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# from .fpem_ffm import FPEM_FFM
+from .fpn_cat import FPNC
+# from .fpn_unet import FPN_UNet
+# from .fpnf import FPNF
+
+# __all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNet']
+__all__ = ['FPNC']
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/fpn_cat.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/fpn_cat.py
new file mode 100755
index 0000000000000000000000000000000000000000..93285e1c732e574a1ffdc6f3e3701025c973f04e
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/fpn_cat.py
@@ -0,0 +1,271 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dbnet_cv.cnn import ConvModule
+from dbnet_cv.runner import BaseModule, ModuleList, Sequential, auto_fp16
+
+from dbnet.models.builder import NECKS
+
+
+@NECKS.register_module()
+class FPNC(BaseModule):
+ """FPN-like fusion module in Real-time Scene Text Detection with
+ Differentiable Binarization.
+
+ This was partially adapted from https://github.com/MhLiao/DB and
+ https://github.com/WenmuZhou/DBNet.pytorch.
+
+ Args:
+ in_channels (list[int]): A list of numbers of input channels.
+ lateral_channels (int): Number of channels for lateral layers.
+ out_channels (int): Number of output channels.
+ bias_on_lateral (bool): Whether to use bias on lateral convolutional
+ layers.
+ bn_re_on_lateral (bool): Whether to use BatchNorm and ReLU
+ on lateral convolutional layers.
+ bias_on_smooth (bool): Whether to use bias on smoothing layer.
+ bn_re_on_smooth (bool): Whether to use BatchNorm and ReLU on smoothing
+ layer.
+ asf_cfg (dict): Adaptive Scale Fusion module configs. The
+ attention_type can be 'ScaleChannelSpatial'.
+ conv_after_concat (bool): Whether to add a convolution layer after
+ the concatenation of multi-level features.
+ init_cfg (dict or list[dict], optional): Initialization configs.
+ """
+
+ def __init__(self,
+ in_channels,
+ lateral_channels=256,
+ out_channels=64,
+ bias_on_lateral=False,
+ bn_re_on_lateral=False,
+ bias_on_smooth=False,
+ bn_re_on_smooth=False,
+ asf_cfg=None,
+ conv_after_concat=False,
+ init_cfg=[
+ dict(type='Kaiming', layer='Conv'),
+ dict(
+ type='Constant', layer='BatchNorm', val=1., bias=1e-4)
+ ]):
+ super().__init__(init_cfg=init_cfg)
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.lateral_channels = lateral_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.bn_re_on_lateral = bn_re_on_lateral
+ self.bn_re_on_smooth = bn_re_on_smooth
+ self.asf_cfg = asf_cfg
+ self.conv_after_concat = conv_after_concat
+ self.lateral_convs = ModuleList()
+ self.smooth_convs = ModuleList()
+ self.num_outs = self.num_ins
+
+ for i in range(self.num_ins):
+ norm_cfg = None
+ act_cfg = None
+ if self.bn_re_on_lateral:
+ norm_cfg = dict(type='BN')
+ act_cfg = dict(type='ReLU')
+ l_conv = ConvModule(
+ in_channels[i],
+ lateral_channels,
+ 1,
+ bias=bias_on_lateral,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ norm_cfg = None
+ act_cfg = None
+ if self.bn_re_on_smooth:
+ norm_cfg = dict(type='BN')
+ act_cfg = dict(type='ReLU')
+
+ smooth_conv = ConvModule(
+ lateral_channels,
+ out_channels,
+ 3,
+ bias=bias_on_smooth,
+ padding=1,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.smooth_convs.append(smooth_conv)
+
+ if self.asf_cfg is not None:
+ self.asf_conv = ConvModule(
+ out_channels * self.num_outs,
+ out_channels * self.num_outs,
+ 3,
+ padding=1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ inplace=False)
+ if self.asf_cfg['attention_type'] == 'ScaleChannelSpatial':
+ self.asf_attn = ScaleChannelSpatialAttention(
+ self.out_channels * self.num_outs,
+ (self.out_channels * self.num_outs) // 4, self.num_outs)
+ else:
+ raise NotImplementedError
+
+ if self.conv_after_concat:
+ norm_cfg = dict(type='BN')
+ act_cfg = dict(type='ReLU')
+ self.out_conv = ConvModule(
+ out_channels * self.num_outs,
+ out_channels * self.num_outs,
+ 3,
+ padding=1,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ @auto_fp16()
+ def forward(self, inputs):
+ """
+ Args:
+ inputs (list[Tensor]): Each tensor has the shape of
+ :math:`(N, C_i, H_i, W_i)`. It usually expects 4 tensors
+ (C2-C5 features) from ResNet.
+
+ Returns:
+ Tensor: A tensor of shape :math:`(N, C_{out}, H_0, W_0)` where
+ :math:`C_{out}` is ``out_channels * len(in_channels)`` (the smoothed
+ per-level features are concatenated) and :math:`(H_0, W_0)` is the
+ size of the largest input feature map.
+ """
+ assert len(inputs) == len(self.in_channels)
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ used_backbone_levels = len(laterals)
+ # build top-down path
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] = laterals[i - 1] + F.interpolate(
+ laterals[i], size=prev_shape, mode='nearest')
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.smooth_convs[i](laterals[i])
+ for i in range(used_backbone_levels)
+ ]
+
+ for i, out in enumerate(outs):
+ outs[i] = F.interpolate(
+ outs[i], size=outs[0].shape[2:], mode='nearest')
+
+ out = torch.cat(outs, dim=1)
+ if self.asf_cfg is not None:
+ asf_feature = self.asf_conv(out)
+ attention = self.asf_attn(asf_feature)
+ enhanced_feature = []
+ for i, out in enumerate(outs):
+ enhanced_feature.append(attention[:, i:i + 1] * outs[i])
+ out = torch.cat(enhanced_feature, dim=1)
+
+ if self.conv_after_concat:
+ out = self.out_conv(out)
+
+ return out
+
+
+class ScaleChannelSpatialAttention(BaseModule):
+ """Spatial Attention module in Real-Time Scene Text Detection with
+ Differentiable Binarization and Adaptive Scale Fusion.
+
+ This was partially adapted from https://github.com/MhLiao/DB
+
+ Args:
+ in_channels (int): The number of input channels.
+ c_wise_channels (int): Number of channel-wise attention channels.
+ out_channels (int): Number of output channels.
+ init_cfg (dict or list[dict], optional): Initialization configs.
+ """
+
+ def __init__(self,
+ in_channels,
+ c_wise_channels,
+ out_channels,
+ init_cfg=[dict(type='Kaiming', layer='Conv', bias=0)]):
+ super().__init__(init_cfg=init_cfg)
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ # Channel Wise
+ self.channel_wise = Sequential(
+ ConvModule(
+ in_channels,
+ c_wise_channels,
+ 1,
+ bias=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ inplace=False),
+ ConvModule(
+ c_wise_channels,
+ in_channels,
+ 1,
+ bias=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='Sigmoid'),
+ inplace=False))
+ # Spatial Wise
+ self.spatial_wise = Sequential(
+ ConvModule(
+ 1,
+ 1,
+ 3,
+ padding=1,
+ bias=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ inplace=False),
+ ConvModule(
+ 1,
+ 1,
+ 1,
+ bias=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='Sigmoid'),
+ inplace=False))
+ # Attention Wise
+ self.attention_wise = ConvModule(
+ in_channels,
+ out_channels,
+ 1,
+ bias=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='Sigmoid'),
+ inplace=False)
+
+ @auto_fp16()
+ def forward(self, inputs):
+ """
+ Args:
+ inputs (Tensor): A concat FPN feature tensor that has the shape of
+ :math:`(N, C, H, W)`.
+
+ Returns:
+ Tensor: An attention map of shape :math:`(N, C_{out}, H, W)`
+ where :math:`C_{out}` is ``out_channels``.
+ """
+ out = self.avg_pool(inputs)
+ out = self.channel_wise(out)
+ out = out + inputs
+ inputs = torch.mean(out, dim=1, keepdim=True)
+ out = self.spatial_wise(inputs) + out
+ out = self.attention_wise(out)
+
+ return out
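+
+
+if __name__ == '__main__':
+    # Shape-check sketch (illustrative only, assuming ``dbnet_cv`` and the
+    # ``dbnet`` package are importable): four random pyramid features,
+    # roughly what a ResNet-18 backbone produces for a 640x640 input, are
+    # fused into a single map of ``out_channels * 4`` channels at the
+    # resolution of the largest level.
+    neck = FPNC(in_channels=[64, 128, 256, 512])
+    feats = [
+        torch.rand(1, c, 160 // 2**i, 160 // 2**i)
+        for i, c in enumerate([64, 128, 256, 512])
+    ]
+    print(neck(feats).shape)  # expected: torch.Size([1, 256, 160, 160])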
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..dcfc2e69990b841ad2c56f1436a36b07b8fa7c64
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_postprocessor import BasePostprocessor
+from .db_postprocessor import DBPostprocessor
+# from .drrg_postprocessor import DRRGPostprocessor
+# from .fce_postprocessor import FCEPostprocessor
+# from .pan_postprocessor import PANPostprocessor
+# from .pse_postprocessor import PSEPostprocessor
+# from .textsnake_postprocessor import TextSnakePostprocessor
+
+
+# __all__ = [
+# 'BasePostprocessor', 'PSEPostprocessor', 'PANPostprocessor',
+# 'DBPostprocessor', 'DRRGPostprocessor', 'FCEPostprocessor',
+# 'TextSnakePostprocessor'
+# ]
+__all__ = [
+ 'BasePostprocessor',
+ 'DBPostprocessor' ]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/base_postprocessor.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/base_postprocessor.py
new file mode 100755
index 0000000000000000000000000000000000000000..734f87b6d1783fbe7cb8f12a74a6d12d734a30ad
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/base_postprocessor.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+
+class BasePostprocessor:
+
+ def __init__(self, text_repr_type='poly'):
+ assert text_repr_type in ['poly', 'quad'
+ ], f'Invalid text repr type {text_repr_type}'
+
+ self.text_repr_type = text_repr_type
+
+ def is_valid_instance(self, area, confidence, area_thresh,
+ confidence_thresh):
+
+ return bool(area >= area_thresh and confidence > confidence_thresh)
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/db_postprocessor.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/db_postprocessor.py
new file mode 100755
index 0000000000000000000000000000000000000000..3996facf0277dd13f3493d642d631db3caec3b26
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/db_postprocessor.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from dbnet.core import points2boundary
+from dbnet.models.builder import POSTPROCESSOR
+from .base_postprocessor import BasePostprocessor
+from .utils import box_score_fast, unclip
+
+
+@POSTPROCESSOR.register_module()
+class DBPostprocessor(BasePostprocessor):
+ """Decoding predictions of DbNet to instances. This is partially adapted
+ from https://github.com/MhLiao/DB.
+
+ Args:
+ text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
+ mask_thr (float): The mask threshold value for binarization.
+ min_text_score (float): The threshold value for converting binary map
+ to shrink text regions.
+ min_text_width (int): The minimum width of boundary polygon/box
+ predicted.
+ unclip_ratio (float): The unclip ratio for text regions dilation.
+ epsilon_ratio (float): The epsilon ratio for approximation accuracy.
+ max_candidates (int): The maximum candidate number.
+ """
+
+ def __init__(self,
+ text_repr_type='poly',
+ mask_thr=0.3,
+ min_text_score=0.3,
+ min_text_width=5,
+ unclip_ratio=1.5,
+ epsilon_ratio=0.01,
+ max_candidates=3000,
+ **kwargs):
+ super().__init__(text_repr_type)
+ self.mask_thr = mask_thr
+ self.min_text_score = min_text_score
+ self.min_text_width = min_text_width
+ self.unclip_ratio = unclip_ratio
+ self.epsilon_ratio = epsilon_ratio
+ self.max_candidates = max_candidates
+
+ def __call__(self, preds):
+ """
+ Args:
+ preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
+
+ Returns:
+ list[list[float]]: The predicted text boundaries.
+ """
+ assert preds.dim() == 3
+
+ prob_map = preds[0, :, :]
+ text_mask = prob_map > self.mask_thr
+
+ score_map = prob_map.data.cpu().numpy().astype(np.float32)
+ text_mask = text_mask.data.cpu().numpy().astype(np.uint8) # to numpy
+
+ contours, _ = cv2.findContours((text_mask * 255).astype(np.uint8),
+ cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+ boundaries = []
+ for i, poly in enumerate(contours):
+ if i > self.max_candidates:
+ break
+ epsilon = self.epsilon_ratio * cv2.arcLength(poly, True)
+ approx = cv2.approxPolyDP(poly, epsilon, True)
+ points = approx.reshape((-1, 2))
+ if points.shape[0] < 4:
+ continue
+ score = box_score_fast(score_map, points)
+ if score < self.min_text_score:
+ continue
+ poly = unclip(points, unclip_ratio=self.unclip_ratio)
+ if len(poly) == 0 or isinstance(poly[0], list):
+ continue
+ poly = poly.reshape(-1, 2)
+
+ if self.text_repr_type == 'quad':
+ poly = points2boundary(poly, self.text_repr_type, score,
+ self.min_text_width)
+ elif self.text_repr_type == 'poly':
+ poly = poly.flatten().tolist()
+ if score is not None:
+ poly = poly + [score]
+ if len(poly) < 8:
+ poly = None
+
+ if poly is not None:
+ boundaries.append(poly)
+
+ return boundaries
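+
+
+if __name__ == '__main__':
+    # Minimal decoding sketch (illustrative only): a synthetic probability
+    # map with one high-confidence square region decodes into a single text
+    # boundary. Real inputs come from the DBNet head. Because this module
+    # uses relative imports, run it as
+    # ``python -m dbnet.models.textdet.postprocess.db_postprocessor``.
+    import torch
+    dummy_preds = torch.zeros(3, 160, 160)
+    dummy_preds[0, 40:120, 40:120] = 1.0
+    boundaries = DBPostprocessor(text_repr_type='poly')(dummy_preds)
+    print(len(boundaries), 'boundary decoded')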
diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/utils.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..50fbaec1383f0d925fb4a2e198e76570592a9683
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/utils.py
@@ -0,0 +1,482 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import operator
+
+import cv2
+import numpy as np
+import pyclipper
+from numpy.fft import ifft
+from numpy.linalg import norm
+from shapely.geometry import Polygon
+
+from dbnet.core.evaluation.utils import boundary_iou
+
+
+def filter_instance(area, confidence, min_area, min_confidence):
+ return bool(area < min_area or confidence < min_confidence)
+
+
+def box_score_fast(bitmap, _box):
+ h, w = bitmap.shape[:2]
+ box = _box.copy()
+ xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
+ xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
+ ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
+ ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
+
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+ box[:, 0] = box[:, 0] - xmin
+ box[:, 1] = box[:, 1] - ymin
+ cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+
+def unclip(box, unclip_ratio=1.5):
+ poly = Polygon(box)
+ distance = poly.area * unclip_ratio / poly.length
+ offset = pyclipper.PyclipperOffset()
+ offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ expanded = np.array(offset.Execute(distance))
+ return expanded
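+
+# Worked example (illustrative): for a 10 x 10 axis-aligned square and the
+# default unclip_ratio of 1.5, the offset distance is
+# area * ratio / perimeter = 100 * 1.5 / 40 = 3.75, i.e. the box is dilated
+# by roughly 3.75 pixels on every side.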
+
+
+def fill_hole(input_mask):
+ h, w = input_mask.shape
+ canvas = np.zeros((h + 2, w + 2), np.uint8)
+ canvas[1:h + 1, 1:w + 1] = input_mask.copy()
+
+ mask = np.zeros((h + 4, w + 4), np.uint8)
+
+ cv2.floodFill(canvas, mask, (0, 0), 1)
+ canvas = canvas[1:h + 1, 1:w + 1].astype(bool)
+
+ return ~canvas | input_mask
+
+
+def centralize(points_yx,
+ normal_sin,
+ normal_cos,
+ radius,
+ contour_mask,
+ step_ratio=0.03):
+
+ h, w = contour_mask.shape
+ top_yx = bot_yx = points_yx
+ step_flags = np.ones((len(points_yx), 1), dtype=bool)
+ step = step_ratio * radius * np.hstack([normal_sin, normal_cos])
+ while np.any(step_flags):
+ next_yx = np.array(top_yx + step, dtype=np.int32)
+ next_y, next_x = next_yx[:, 0], next_yx[:, 1]
+ step_flags = (next_y >= 0) & (next_y < h) & (next_x > 0) & (
+ next_x < w) & contour_mask[np.clip(next_y, 0, h - 1),
+ np.clip(next_x, 0, w - 1)]
+ top_yx = top_yx + step_flags.reshape((-1, 1)) * step
+ step_flags = np.ones((len(points_yx), 1), dtype=bool)
+ while np.any(step_flags):
+ next_yx = np.array(bot_yx - step, dtype=np.int32)
+ next_y, next_x = next_yx[:, 0], next_yx[:, 1]
+ step_flags = (next_y >= 0) & (next_y < h) & (next_x > 0) & (
+ next_x < w) & contour_mask[np.clip(next_y, 0, h - 1),
+ np.clip(next_x, 0, w - 1)]
+ bot_yx = bot_yx - step_flags.reshape((-1, 1)) * step
+ centers = np.array((top_yx + bot_yx) * 0.5, dtype=np.int32)
+ return centers
+
+
+def merge_disks(disks, disk_overlap_thr):
+ xy = disks[:, 0:2]
+ radius = disks[:, 2]
+ scores = disks[:, 3]
+ order = scores.argsort()[::-1]
+
+ merged_disks = []
+ while order.size > 0:
+ if order.size == 1:
+ merged_disks.append(disks[order])
+ break
+ i = order[0]
+ d = norm(xy[i] - xy[order[1:]], axis=1)
+ ri = radius[i]
+ r = radius[order[1:]]
+ d_thr = (ri + r) * disk_overlap_thr
+
+ merge_inds = np.where(d <= d_thr)[0] + 1
+ if merge_inds.size > 0:
+ merge_order = np.hstack([i, order[merge_inds]])
+ merged_disks.append(np.mean(disks[merge_order], axis=0))
+ else:
+ merged_disks.append(disks[i])
+
+ inds = np.where(d > d_thr)[0] + 1
+ order = order[inds]
+ merged_disks = np.vstack(merged_disks)
+
+ return merged_disks
+
+
+def poly_nms(polygons, threshold):
+ assert isinstance(polygons, list)
+
+ polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
+
+ keep_poly = []
+ index = [i for i in range(polygons.shape[0])]
+
+ while len(index) > 0:
+ keep_poly.append(polygons[index[-1]].tolist())
+ A = polygons[index[-1]][:-1]
+ index = np.delete(index, -1)
+
+ iou_list = np.zeros((len(index), ))
+ for i in range(len(index)):
+ B = polygons[index[i]][:-1]
+
+ iou_list[i] = boundary_iou(A, B, 1)
+ remove_index = np.where(iou_list > threshold)
+ index = np.delete(index, remove_index)
+
+ return keep_poly
+
+
+def fourier2poly(fourier_coeff, num_reconstr_points=50):
+ """ Inverse Fourier transform
+ Args:
+ fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
+ with n and k being candidates number and Fourier degree
+ respectively.
+ num_reconstr_points (int): Number of reconstructed polygon points.
+
+ Returns:
+ ndarray: The reconstructed polygons, shaped (n, 2 * num_reconstr_points).
+ """
+
+ a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex')
+ k = (len(fourier_coeff[0]) - 1) // 2
+
+ a[:, 0:k + 1] = fourier_coeff[:, k:]
+ a[:, -k:] = fourier_coeff[:, :k]
+
+ poly_complex = ifft(a) * num_reconstr_points
+ polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
+ polygon[:, :, 0] = poly_complex.real
+ polygon[:, :, 1] = poly_complex.imag
+ return polygon.astype('int32').reshape((len(fourier_coeff), -1))
+
+
+class Node:
+
+ def __init__(self, ind):
+ self.__ind = ind
+ self.__links = set()
+
+ @property
+ def ind(self):
+ return self.__ind
+
+ @property
+ def links(self):
+ return set(self.__links)
+
+ def add_link(self, link_node):
+ self.__links.add(link_node)
+ link_node.__links.add(self)
+
+
+def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
+ """Propagate edge score information and construct graph. This code was
+ partially adapted from https://github.com/GXYM/DRRG licensed under the MIT
+ license.
+
+ Args:
+ edges (ndarray): The edge array of shape N * 2, each row is a node
+ index pair that makes up an edge in graph.
+ scores (ndarray): The edge score array.
+ text_comps (ndarray): The text components.
+ edge_len_thr (float): The edge length threshold.
+
+ Returns:
+ vertices (list[Node]): The Nodes in graph.
+ score_dict (dict): The edge score dict.
+ """
+ assert edges.ndim == 2
+ assert edges.shape[1] == 2
+ assert edges.shape[0] == scores.shape[0]
+ assert text_comps.ndim == 2
+ assert isinstance(edge_len_thr, float)
+
+ edges = np.sort(edges, axis=1)
+ score_dict = {}
+ for i, edge in enumerate(edges):
+ if text_comps is not None:
+ box1 = text_comps[edge[0], :8].reshape(4, 2)
+ box2 = text_comps[edge[1], :8].reshape(4, 2)
+ center1 = np.mean(box1, axis=0)
+ center2 = np.mean(box2, axis=0)
+ distance = norm(center1 - center2)
+ if distance > edge_len_thr:
+ scores[i] = 0
+ if (edge[0], edge[1]) in score_dict:
+ score_dict[edge[0], edge[1]] = 0.5 * (
+ score_dict[edge[0], edge[1]] + scores[i])
+ else:
+ score_dict[edge[0], edge[1]] = scores[i]
+
+ nodes = np.sort(np.unique(edges.flatten()))
+ mapping = -1 * np.ones((np.max(nodes) + 1), dtype=int)
+ mapping[nodes] = np.arange(nodes.shape[0])
+ order_inds = mapping[edges]
+ vertices = [Node(node) for node in nodes]
+ for ind in order_inds:
+ vertices[ind[0]].add_link(vertices[ind[1]])
+
+ return vertices, score_dict
+
+
+def connected_components(nodes, score_dict, link_thr):
+ """Conventional connected components searching. This code was partially
+ adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
+
+ Args:
+ nodes (list[Node]): The list of Node objects.
+ score_dict (dict): The edge score dict.
+ link_thr (float): The link threshold.
+
+ Returns:
+ clusters (List[list[Node]]): The clustered Node objects.
+ """
+ assert isinstance(nodes, list)
+ assert all([isinstance(node, Node) for node in nodes])
+ assert isinstance(score_dict, dict)
+ assert isinstance(link_thr, float)
+
+ clusters = []
+ nodes = set(nodes)
+ while nodes:
+ node = nodes.pop()
+ cluster = {node}
+ node_queue = [node]
+ while node_queue:
+ node = node_queue.pop(0)
+ neighbors = set([
+ neighbor for neighbor in node.links if
+ score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
+ ])
+ neighbors.difference_update(cluster)
+ nodes.difference_update(neighbors)
+ cluster.update(neighbors)
+ node_queue.extend(neighbors)
+ clusters.append(list(cluster))
+ return clusters
+
+
+def clusters2labels(clusters, num_nodes):
+ """Convert clusters of Node to text component labels. This code was
+ partially adapted from https://github.com/GXYM/DRRG licensed under the MIT
+ license.
+
+ Args:
+ clusters (List[list[Node]]): The clusters of Node objects.
+ num_nodes (int): The total node number of graphs in an image.
+
+ Returns:
+ node_labels (ndarray): The node label array.
+ """
+ assert isinstance(clusters, list)
+ assert all([isinstance(cluster, list) for cluster in clusters])
+ assert all(
+ [isinstance(node, Node) for cluster in clusters for node in cluster])
+ assert isinstance(num_nodes, int)
+
+ node_labels = np.zeros(num_nodes)
+ for cluster_ind, cluster in enumerate(clusters):
+ for node in cluster:
+ node_labels[node.ind] = cluster_ind
+ return node_labels
+
+
+def remove_single(text_comps, comp_pred_labels):
+ """Remove isolated text components. This code was partially adapted from
+ https://github.com/GXYM/DRRG licensed under the MIT license.
+
+ Args:
+ text_comps (ndarray): The text components.
+ comp_pred_labels (ndarray): The clustering labels of text components.
+
+ Returns:
+ filtered_text_comps (ndarray): The text components with isolated ones
+ removed.
+ comp_pred_labels (ndarray): The clustering labels with labels of
+ isolated text components removed.
+ """
+ assert text_comps.ndim == 2
+ assert text_comps.shape[0] == comp_pred_labels.shape[0]
+
+ single_flags = np.zeros_like(comp_pred_labels)
+ pred_labels = np.unique(comp_pred_labels)
+ for label in pred_labels:
+ current_label_flag = (comp_pred_labels == label)
+ if np.sum(current_label_flag) == 1:
+ single_flags[np.where(current_label_flag)[0][0]] = 1
+ keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
+ filtered_text_comps = text_comps[keep_ind, :]
+ filtered_labels = comp_pred_labels[keep_ind]
+
+ return filtered_text_comps, filtered_labels
+
+
+def norm2(point1, point2):
+ return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
+
+
+def min_connect_path(points):
+ """Find the shortest path to traverse all points. This code was partially
+ adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
+
+ Args:
+ points(List[list[int]]): The point sequence [[x0, y0], [x1, y1], ...].
+
+ Returns:
+ shortest_path(List[list[int]]): The shortest index path.
+ """
+ assert isinstance(points, list)
+ assert all([isinstance(point, list) for point in points])
+ assert all([isinstance(coord, int) for point in points for coord in point])
+
+ points_queue = points.copy()
+ shortest_path = []
+ current_edge = [[], []]
+
+ edge_dict0 = {}
+ edge_dict1 = {}
+ current_edge[0] = points_queue[0]
+ current_edge[1] = points_queue[0]
+ points_queue.remove(points_queue[0])
+ while points_queue:
+ for point in points_queue:
+ length0 = norm2(point, current_edge[0])
+ edge_dict0[length0] = [point, current_edge[0]]
+ length1 = norm2(current_edge[1], point)
+ edge_dict1[length1] = [current_edge[1], point]
+ key0 = min(edge_dict0.keys())
+ key1 = min(edge_dict1.keys())
+
+ if key0 <= key1:
+ start = edge_dict0[key0][0]
+ end = edge_dict0[key0][1]
+ shortest_path.insert(0, [points.index(start), points.index(end)])
+ points_queue.remove(start)
+ current_edge[0] = start
+ else:
+ start = edge_dict1[key1][0]
+ end = edge_dict1[key1][1]
+ shortest_path.append([points.index(start), points.index(end)])
+ points_queue.remove(end)
+ current_edge[1] = end
+
+ edge_dict0 = {}
+ edge_dict1 = {}
+
+ shortest_path = functools.reduce(operator.concat, shortest_path)
+ shortest_path = sorted(set(shortest_path), key=shortest_path.index)
+
+ return shortest_path
+
+
+def in_contour(cont, point):
+ x, y = point
+ is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5
+ return is_inner
+
+
+def fix_corner(top_line, bot_line, start_box, end_box):
+ """Add corner points to predicted side lines. This code was partially
+ adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
+
+ Args:
+ top_line (List[list[int]]): The predicted top sidelines of text
+ instance.
+ bot_line (List[list[int]]): The predicted bottom sidelines of text
+ instance.
+ start_box (ndarray): The first text component box.
+ end_box (ndarray): The last text component box.
+
+ Returns:
+ top_line (List[list[int]]): The top sidelines with corner point added.
+ bot_line (List[list[int]]): The bottom sidelines with corner point
+ added.
+ """
+ assert isinstance(top_line, list)
+ assert all(isinstance(point, list) for point in top_line)
+ assert isinstance(bot_line, list)
+ assert all(isinstance(point, list) for point in bot_line)
+ assert start_box.shape == end_box.shape == (4, 2)
+
+ contour = np.array(top_line + bot_line[::-1])
+ start_left_mid = (start_box[0] + start_box[3]) / 2
+ start_right_mid = (start_box[1] + start_box[2]) / 2
+ end_left_mid = (end_box[0] + end_box[3]) / 2
+ end_right_mid = (end_box[1] + end_box[2]) / 2
+ if not in_contour(contour, start_left_mid):
+ top_line.insert(0, start_box[0].tolist())
+ bot_line.insert(0, start_box[3].tolist())
+ elif not in_contour(contour, start_right_mid):
+ top_line.insert(0, start_box[1].tolist())
+ bot_line.insert(0, start_box[2].tolist())
+ if not in_contour(contour, end_left_mid):
+ top_line.append(end_box[0].tolist())
+ bot_line.append(end_box[3].tolist())
+ elif not in_contour(contour, end_right_mid):
+ top_line.append(end_box[1].tolist())
+ bot_line.append(end_box[2].tolist())
+ return top_line, bot_line
+
+
+def comps2boundaries(text_comps, comp_pred_labels):
+ """Construct text instance boundaries from clustered text components. This
+ code was partially adapted from https://github.com/GXYM/DRRG licensed under
+ the MIT license.
+
+ Args:
+ text_comps (ndarray): The text components.
+ comp_pred_labels (ndarray): The clustering labels of text components.
+
+ Returns:
+ boundaries (List[list[float]]): The predicted boundaries of text
+ instances.
+ """
+ assert text_comps.ndim == 2
+ assert len(text_comps) == len(comp_pred_labels)
+ boundaries = []
+ if len(text_comps) < 1:
+ return boundaries
+ for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
+ cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
+ text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
+ (-1, 4, 2)).astype(np.int32)
+ score = np.mean(text_comps[cluster_comp_inds, -1])
+
+ if text_comp_boxes.shape[0] < 1:
+ continue
+
+ elif text_comp_boxes.shape[0] > 1:
+ centers = np.mean(
+ text_comp_boxes, axis=1).astype(np.int32).tolist()
+ shortest_path = min_connect_path(centers)
+ text_comp_boxes = text_comp_boxes[shortest_path]
+ top_line = np.mean(
+ text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
+ bot_line = np.mean(
+ text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
+ top_line, bot_line = fix_corner(top_line, bot_line,
+ text_comp_boxes[0],
+ text_comp_boxes[-1])
+ boundary_points = top_line + bot_line[::-1]
+
+ else:
+ top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
+ bot_line = text_comp_boxes[0, 2:4, :].astype(np.int32).tolist()
+ boundary_points = top_line + bot_line
+
+ boundary = [p for coord in boundary_points for p in coord] + [score]
+ boundaries.append(boundary)
+
+ return boundaries
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..66dbd294f30e40a09ba9df16d472c9c8bfda4c47
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dbnet_cv.utils import Registry, build_from_cfg
+
+from .box_util import (bezier_to_polygon, is_on_same_line, sort_points,
+ stitch_boxes_into_lines)
+from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
+ is_type_list, valid_boundary)
+from .collect_env import collect_env
+# from .data_convert_util import convert_annotations
+from .fileio import list_from_file, list_to_file
+# from .img_util import drop_orientation, is_not_png
+from .lmdb_util import recog2lmdb
+from .logger import get_root_logger
+from .model import revert_sync_batchnorm
+from .setup_env import setup_multi_processes
+from .string_util import StringStrip
+
+__all__ = [
+ 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
+ 'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
+ 'valid_boundary', 'list_to_file', 'list_from_file',
+ 'is_on_same_line', 'stitch_boxes_into_lines', 'StringStrip',
+ 'revert_sync_batchnorm', 'bezier_to_polygon', 'sort_points',
+ 'setup_multi_processes', 'recog2lmdb'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py b/cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..c0ef44e5df5b85efe69521fe28e3a34be5e58de0
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py
@@ -0,0 +1,199 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import numpy as np
+
+from dbnet.utils.check_argument import is_2dlist, is_type_list
+
+
+def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8):
+ """Check if two boxes are on the same line by their y-axis coordinates.
+
+ Two boxes are on the same line if they overlap vertically, and the length
+ of the overlapping line segment is greater than min_y_overlap_ratio * the
+ height of either of the boxes.
+
+ Args:
+ box_a (list): The first bounding box to be checked.
+ box_b (list): The second bounding box to be checked.
+ min_y_overlap_ratio (float): The minimum vertical overlapping ratio
+ allowed for boxes in the same line.
+
+ Returns:
+ bool: The flag indicating whether the two boxes are on the same line.
+ """
+ a_y_min = np.min(box_a[1::2])
+ b_y_min = np.min(box_b[1::2])
+ a_y_max = np.max(box_a[1::2])
+ b_y_max = np.max(box_b[1::2])
+
+ # Make sure that box a is always the box above the other
+ if a_y_min > b_y_min:
+ a_y_min, b_y_min = b_y_min, a_y_min
+ a_y_max, b_y_max = b_y_max, a_y_max
+
+ if b_y_min <= a_y_max:
+ if min_y_overlap_ratio is not None:
+ sorted_y = sorted([b_y_min, b_y_max, a_y_max])
+ overlap = sorted_y[1] - sorted_y[0]
+ min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio
+ min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio
+ return overlap >= min_a_overlap or \
+ overlap >= min_b_overlap
+ else:
+ return True
+ return False
+
+
+def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.8):
+ """Stitch fragmented boxes of words into lines.
+
+ Note: part of its logic is inspired by @Johndirr
+ (https://github.com/faustomorales/keras-ocr/issues/22)
+
+ Args:
+ boxes (list): List of ocr results to be stitched
+ max_x_dist (int): The maximum horizontal distance between the closest
+ edges of neighboring boxes in the same line
+ min_y_overlap_ratio (float): The minimum vertical overlapping ratio
+ allowed for any pairs of neighboring boxes in the same line
+
+ Returns:
+ merged_boxes(list[dict]): List of merged boxes and texts
+ """
+
+ if len(boxes) <= 1:
+ return boxes
+
+ merged_boxes = []
+
+ # sort groups based on the x_min coordinate of boxes
+ x_sorted_boxes = sorted(boxes, key=lambda x: np.min(x['box'][::2]))
+ # store indexes of boxes which are already parts of other lines
+ skip_idxs = set()
+
+ i = 0
+ # locate lines of boxes starting from the leftmost one
+ for i in range(len(x_sorted_boxes)):
+ if i in skip_idxs:
+ continue
+ # the rightmost box in the current line
+ rightmost_box_idx = i
+ line = [rightmost_box_idx]
+ for j in range(i + 1, len(x_sorted_boxes)):
+ if j in skip_idxs:
+ continue
+ if is_on_same_line(x_sorted_boxes[rightmost_box_idx]['box'],
+ x_sorted_boxes[j]['box'], min_y_overlap_ratio):
+ line.append(j)
+ skip_idxs.add(j)
+ rightmost_box_idx = j
+
+ # split the line into sub-lines if the distance between two
+ # neighboring boxes is greater than max_x_dist
+ lines = []
+ line_idx = 0
+ lines.append([line[0]])
+ for k in range(1, len(line)):
+ curr_box = x_sorted_boxes[line[k]]
+ prev_box = x_sorted_boxes[line[k - 1]]
+ dist = np.min(curr_box['box'][::2]) - np.max(prev_box['box'][::2])
+ if dist > max_x_dist:
+ line_idx += 1
+ lines.append([])
+ lines[line_idx].append(line[k])
+
+ # Get merged boxes
+ for box_group in lines:
+ merged_box = {}
+ merged_box['text'] = ' '.join(
+ [x_sorted_boxes[idx]['text'] for idx in box_group])
+ x_min, y_min = float('inf'), float('inf')
+ x_max, y_max = float('-inf'), float('-inf')
+ for idx in box_group:
+ x_max = max(np.max(x_sorted_boxes[idx]['box'][::2]), x_max)
+ x_min = min(np.min(x_sorted_boxes[idx]['box'][::2]), x_min)
+ y_max = max(np.max(x_sorted_boxes[idx]['box'][1::2]), y_max)
+ y_min = min(np.min(x_sorted_boxes[idx]['box'][1::2]), y_min)
+ merged_box['box'] = [
+ x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max
+ ]
+ merged_boxes.append(merged_box)
+
+ return merged_boxes
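+
+# Usage sketch (illustrative only): two word boxes that lie on the same line
+# and are less than ``max_x_dist`` apart are merged into one line-level box:
+#
+#   boxes = [
+#       dict(box=[0, 0, 30, 0, 30, 10, 0, 10], text='Hello'),
+#       dict(box=[35, 0, 70, 0, 70, 10, 35, 10], text='World'),
+#   ]
+#   stitch_boxes_into_lines(boxes, max_x_dist=10)
+#   # -> [{'text': 'Hello World', 'box': [0, 0, 70, 0, 70, 10, 0, 10]}]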
+
+
+def bezier_to_polygon(bezier_points, num_sample=20):
+ """Sample points from the boundary of a polygon enclosed by two Bezier
+ curves, which are controlled by ``bezier_points``.
+
+ Args:
+ bezier_points (ndarray): A :math:`(2, 4, 2)` array of 8 Bezier points
+ or its equivalent. The first 4 points control the curve at one
+ side and the last four control the other side.
+ num_sample (int): The number of sample points on each Bezier curve.
+
+ Returns:
+ list[ndarray]: A list of 2*num_sample points representing the polygon
+ extracted from Bezier curves.
+
+ Warning:
+ The points are not guaranteed to be ordered. Please use
+ :func:`dbnet.utils.sort_points` to sort points if necessary.
+ """
+ assert num_sample > 0
+
+ bezier_points = np.asarray(bezier_points)
+ assert np.prod(
+ bezier_points.shape) == 16, 'Need 8 Bezier control points to continue!'
+
+ bezier = bezier_points.reshape(2, 4, 2).transpose(0, 2, 1).reshape(4, 4)
+ u = np.linspace(0, 1, num_sample)
+
+ points = np.outer((1 - u) ** 3, bezier[:, 0]) \
+ + np.outer(3 * u * ((1 - u) ** 2), bezier[:, 1]) \
+ + np.outer(3 * (u ** 2) * (1 - u), bezier[:, 2]) \
+ + np.outer(u ** 3, bezier[:, 3])
+
+ # Convert points to polygon
+ points = np.concatenate((points[:, :2], points[:, 2:]), axis=0)
+ return points.tolist()
+
+
+def sort_points(points):
+ """Sort arbitory points in clockwise order. Reference:
+ https://stackoverflow.com/a/6989383.
+
+ Args:
+ points (list[ndarray] or ndarray or list[list]): A list of unsorted
+ boundary points.
+
+ Returns:
+ list[ndarray]: A list of points sorted in clockwise order.
+ """
+
+ assert is_type_list(points, np.ndarray) or isinstance(points, np.ndarray) \
+ or is_2dlist(points)
+
+ points = np.array(points)
+ center = np.mean(points, axis=0)
+
+ def cmp(a, b):
+ oa = a - center
+ ob = b - center
+
+ # Some corner cases
+ if oa[0] >= 0 and ob[0] < 0:
+ return 1
+ if oa[0] < 0 and ob[0] >= 0:
+ return -1
+
+ prod = np.cross(oa, ob)
+ if prod > 0:
+ return 1
+ if prod < 0:
+ return -1
+
+ # a, b are on the same line from the center
+ return 1 if (oa**2).sum() < (ob**2).sum() else -1
+
+ return sorted(points, key=functools.cmp_to_key(cmp))
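+
+
+if __name__ == '__main__':
+    # Quick sketch (illustrative only, assuming the ``dbnet`` package is
+    # importable): the corners of a unit square given in arbitrary order are
+    # re-ordered around their centroid as described in the docstring.
+    print(sort_points([[0, 0], [1, 1], [1, 0], [0, 1]]))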
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/check_argument.py b/cv/ocr/dbnet/pytorch/dbnet/utils/check_argument.py
new file mode 100755
index 0000000000000000000000000000000000000000..34cbe8dc2658d725c328eb5cd98652633a22aa24
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/check_argument.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+
+def is_3dlist(x):
+ """check x is 3d-list([[[1], []]]) or 2d empty list([[], []]) or 1d empty
+ list([]).
+
+ Notice:
+ The reason that it contains 1d or 2d empty list is because
+ some arguments from gt annotation file or model prediction
+ may be empty, but usually, it should be 3d-list.
+ """
+ if not isinstance(x, list):
+ return False
+ if len(x) == 0:
+ return True
+ for sub_x in x:
+ if not is_2dlist(sub_x):
+ return False
+
+ return True
+
+
+def is_2dlist(x):
+ """check x is 2d-list([[1], []]) or 1d empty list([]).
+
+ Notice:
+ The reason that it contains 1d empty list is because
+ some arguments from gt annotation file or model prediction
+ may be empty, but usually, it should be 2d-list.
+ """
+ if not isinstance(x, list):
+ return False
+ if len(x) == 0:
+ return True
+
+ return all(isinstance(item, list) for item in x)
+
+
+def is_type_list(x, type):
+
+ if not isinstance(x, list):
+ return False
+
+ return all(isinstance(item, type) for item in x)
+
+
+def is_none_or_type(x, type):
+
+ return isinstance(x, type) or x is None
+
+
+def equal_len(*argv):
+ assert len(argv) > 0
+
+ num_arg = len(argv[0])
+ for arg in argv:
+ if len(arg) != num_arg:
+ return False
+ return True
+
+
+def valid_boundary(x, with_score=True):
+ num = len(x)
+ if num < 8:
+ return False
+ if num % 2 == 0 and (not with_score):
+ return True
+ if num % 2 == 1 and with_score:
+ return True
+
+ return False
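+
+
+if __name__ == '__main__':
+    # Quick sanity sketch (illustrative only): a quadrilateral boundary has
+    # 8 coordinates, plus an optional trailing confidence score.
+    quad = [0, 0, 10, 0, 10, 10, 0, 10]
+    print(valid_boundary(quad, with_score=False))         # True
+    print(valid_boundary(quad + [0.9], with_score=True))  # True
+    print(valid_boundary(quad, with_score=True))          # False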
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/collect_env.py b/cv/ocr/dbnet/pytorch/dbnet/utils/collect_env.py
new file mode 100755
index 0000000000000000000000000000000000000000..ed6bc0cca661b3ca39df27496ac487b7e57e8f28
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/collect_env.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dbnet_cv.utils import collect_env as collect_base_env
+from dbnet_cv.utils import get_git_hash
+
+import dbnet
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = collect_base_env()
+ # env_info['MMOCR'] = mmocr.__version__ + '+' + get_git_hash()[:7]
+ return env_info
+
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print(f'{name}: {val}')
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py b/cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py
new file mode 100755
index 0000000000000000000000000000000000000000..cf07dd0553cc16221c87656c69d46f74d57da001
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+import dbnet_cv
+
+
+def list_to_file(filename, lines):
+ """Write a list of strings to a text file.
+
+ Args:
+ filename (str): The output filename. It will be created/overwritten.
+ lines (list(str)): Data to be written.
+ """
+ dbnet_cv.mkdir_or_exist(os.path.dirname(filename))
+ with open(filename, 'w', encoding='utf-8') as fw:
+ for line in lines:
+ fw.write(f'{line}\n')
+
+
+def list_from_file(filename, encoding='utf-8'):
+ """Load a text file and parse the content as a list of strings. The
+ trailing "\\r" and "\\n" of each line will be removed.
+
+ Note:
+ This will be replaced by dbnet_cv's version after it supports encoding.
+
+ Args:
+ filename (str): Filename.
+ encoding (str): Encoding used to open the file. Default utf-8.
+
+ Returns:
+ list[str]: A list of strings.
+ """
+ item_list = []
+ with open(filename, 'r', encoding=encoding) as f:
+ for line in f:
+ item_list.append(line.rstrip('\n\r'))
+ return item_list
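+
+
+if __name__ == '__main__':
+    # Round-trip sketch (illustrative only; the temporary path is made up):
+    # write a small list and read it back, trailing newlines are stripped.
+    demo_path = '/tmp/dbnet_fileio_demo.txt'
+    list_to_file(demo_path, ['hello', 'world'])
+    print(list_from_file(demo_path))  # ['hello', 'world']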
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py b/cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..618c836fe8e611d373fbb0dcd1ee5ffe4eeaf957
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+
+import cv2
+import lmdb
+import numpy as np
+
+from dbnet.utils import list_from_file
+
+
+def check_image_is_valid(imageBin):
+ if imageBin is None:
+ return False
+ imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
+ img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
+ imgH, imgW = img.shape[0], img.shape[1]
+ if imgH * imgW == 0:
+ return False
+ return True
+
+
+def parse_line(line, format):
+ if format == 'txt':
+ img_name, text = line.split(' ')
+ else:
+ line = json.loads(line)
+ img_name = line['filename']
+ text = line['text']
+ return img_name, text
+
+
+def write_cache(env, cache):
+ with env.begin(write=True) as txn:
+ cursor = txn.cursor()
+ cursor.putmulti(cache, dupdata=False, overwrite=True)
+
+
+def recog2lmdb(img_root,
+ label_path,
+ output,
+ label_format='txt',
+ label_only=False,
+ batch_size=1000,
+ encoding='utf-8',
+ lmdb_map_size=1099511627776,
+ verify=True):
+ """Create text recognition dataset to LMDB format.
+
+ Args:
+ img_root (str): Path to images.
+ label_path (str): Path to label file.
+ output (str): LMDB output path.
+ label_format (str): Format of the label file, either txt or jsonl.
+ label_only (bool): Only convert label to lmdb format.
+ batch_size (int): Number of files written to the cache each time.
+ encoding (str): Label encoding method.
+ lmdb_map_size (int): Maximum size database may grow to.
+ verify (bool): If True, check the validity of
+ every image. Defaults to True.
+
+ E.g.
+ This function supports MMOCR's recognition data format and the label file
+ can be txt or jsonl, as follows:
+
+ ├──img_root
+ | |—— img1.jpg
+ | |—— img2.jpg
+ | |—— ...
+ |——label.txt (or label.jsonl)
+
+ label.txt: img1.jpg HELLO
+ img2.jpg WORLD
+ ...
+
+ label.jsonl: {'filename':'img1.jpg', 'text':'HELLO'}
+ {'filename':'img2.jpg', 'text':'WORLD'}
+ ...
+ """
+ # check label format
+ assert osp.basename(label_path).split('.')[-1] == label_format
+ # create lmdb env
+ os.makedirs(output, exist_ok=True)
+ env = lmdb.open(output, map_size=lmdb_map_size)
+ # load label file
+ anno_list = list_from_file(label_path, encoding=encoding)
+ cache = []
+ # index start from 1
+ cnt = 1
+ n_samples = len(anno_list)
+ for anno in anno_list:
+ label_key = 'label-%09d'.encode(encoding) % cnt
+ img_name, text = parse_line(anno, label_format)
+ if label_only:
+ # convert only labels to lmdb
+ line = json.dumps(
+ dict(filename=img_name, text=text), ensure_ascii=False)
+ cache.append((label_key, line.encode(encoding)))
+ else:
+ # convert both images and labels to lmdb
+ img_path = osp.join(img_root, img_name)
+ if not osp.exists(img_path):
+ print('%s does not exist' % img_path)
+ continue
+ with open(img_path, 'rb') as f:
+ image_bin = f.read()
+ if verify:
+ try:
+ if not check_image_is_valid(image_bin):
+ print('%s is not a valid image' % img_path)
+ continue
+ except Exception:
+ print('error occurred at ', img_name)
+ image_key = 'image-%09d'.encode(encoding) % cnt
+ cache.append((image_key, image_bin))
+ cache.append((label_key, text.encode(encoding)))
+
+ if cnt % batch_size == 0:
+ write_cache(env, cache)
+ cache = []
+ print('Written %d / %d' % (cnt, n_samples))
+ cnt += 1
+ n_samples = cnt - 1
+ cache.append(
+ ('num-samples'.encode(encoding), str(n_samples).encode(encoding)))
+ write_cache(env, cache)
+ print('Created lmdb dataset with %d samples' % n_samples)
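+
+
+# Usage sketch (illustrative only; all paths below are placeholders):
+#
+#   recog2lmdb(
+#       img_root='data/recog/imgs',
+#       label_path='data/recog/label.txt',
+#       output='data/recog/lmdb',
+#       label_format='txt',
+#       batch_size=1000)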
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/logger.py b/cv/ocr/dbnet/pytorch/dbnet/utils/logger.py
new file mode 100755
index 0000000000000000000000000000000000000000..779ae99b01096e64e89052a906cae27d30b5c859
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/logger.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+from dbnet_cv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ """Use `get_logger` method in dbnet_cv to get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added. The name of the root logger is the top-level package name,
+ e.g., "mmpose".
+
+ Args:
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+ return get_logger(__name__.split('.')[0], log_file, log_level)
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/model.py b/cv/ocr/dbnet/pytorch/dbnet/utils/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..4a126006b69c70d7780a310de46c0c2e0a0495ba
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/model.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+ """A general BatchNorm layer without input dimension check.
+
+ Reproduced from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+ The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
+ is `_check_input_dim` that is designed for tensor sanity checks.
+ The check has been bypassed in this class for the convenience of converting
+ SyncBatchNorm.
+ """
+
+ def _check_input_dim(self, input):
+ return
+
+
+def revert_sync_batchnorm(module):
+ """Helper function to convert all `SyncBatchNorm` layers in the model to
+ `BatchNormXd` layers.
+
+ Adapted from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+ Args:
+ module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+ Returns:
+ module_output: The converted module with `BatchNormXd` layers.
+ """
+ module_output = module
+ if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
+ module_output = _BatchNormXd(module.num_features, module.eps,
+ module.momentum, module.affine,
+ module.track_running_stats)
+ if module.affine:
+ with torch.no_grad():
+ module_output.weight = module.weight
+ module_output.bias = module.bias
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = module.num_batches_tracked
+ module_output.training = module.training
+ if hasattr(module, 'qconfig'):
+ module_output.qconfig = module.qconfig
+ for name, child in module.named_children():
+ module_output.add_module(name, revert_sync_batchnorm(child))
+ del module
+ return module_output
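+
+
+if __name__ == '__main__':
+    # Minimal sketch (illustrative only): SyncBatchNorm layers normally
+    # cannot run a CPU forward pass, but after reverting them to
+    # ``_BatchNormXd`` the model works again.
+    model = torch.nn.Sequential(
+        torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
+    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+    model = revert_sync_batchnorm(model)
+    print(model(torch.rand(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 30, 30])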
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py b/cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py
new file mode 100755
index 0000000000000000000000000000000000000000..a09bb15de6d0ff52a1c6d4bb23a40ac3506504e3
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py
@@ -0,0 +1,878 @@
+#!/usr/bin/env python
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+import warnings
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+
+import dbnet_cv
+import numpy as np
+import torch
+from dbnet_cv.image.misc import tensor2imgs
+from dbnet_cv.runner import load_checkpoint
+from dbnet_cv.utils.config import Config
+from PIL import Image
+
+try:
+ import tesserocr
+except ImportError:
+ tesserocr = None
+
+from dbnet.apis import init_detector
+from dbnet.apis.inference import model_inference
+from dbnet.core.visualize import det_recog_show_result
+from dbnet.datasets.kie_dataset import KIEDataset
+from dbnet.datasets.pipelines.crop import crop_img
+from dbnet.models import build_detector
+from dbnet.models.textdet.detectors import TextDetectorMixin
+from dbnet.models.textrecog.recognizer import BaseRecognizer
+from dbnet.utils import is_type_list
+from dbnet.utils.box_util import stitch_boxes_into_lines
+from dbnet.utils.fileio import list_from_file
+from dbnet.utils.model import revert_sync_batchnorm
+
+
+# Parse CLI arguments
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ 'img', type=str, help='Input image file or folder path.')
+ parser.add_argument(
+ '--output',
+ type=str,
+ default='',
+ help='Output file/folder name for visualization')
+ parser.add_argument(
+ '--det',
+ type=str,
+ default='PANet_IC15',
+ help='Pretrained text detection algorithm')
+ parser.add_argument(
+ '--det-config',
+ type=str,
+ default='',
+ help='Path to the custom config file of the selected det model. It '
+ 'overrides the settings in det')
+ parser.add_argument(
+ '--det-ckpt',
+ type=str,
+ default='',
+ help='Path to the custom checkpoint file of the selected det model. '
+ 'It overrides the settings in det')
+ parser.add_argument(
+ '--recog',
+ type=str,
+ default='SEG',
+ help='Pretrained text recognition algorithm')
+ parser.add_argument(
+ '--recog-config',
+ type=str,
+ default='',
+ help='Path to the custom config file of the selected recog model. It '
+ 'overrides the settings in recog')
+ parser.add_argument(
+ '--recog-ckpt',
+ type=str,
+ default='',
+ help='Path to the custom checkpoint file of the selected recog model. '
+ 'It overrides the settings in recog')
+ parser.add_argument(
+ '--kie',
+ type=str,
+ default='',
+ help='Pretrained key information extraction algorithm')
+ parser.add_argument(
+ '--kie-config',
+ type=str,
+ default='',
+ help='Path to the custom config file of the selected kie model. It '
+ 'overrides the settings in kie')
+ parser.add_argument(
+ '--kie-ckpt',
+ type=str,
+ default='',
+ help='Path to the custom checkpoint file of the selected kie model. '
+ 'It overrides the settings in kie')
+ parser.add_argument(
+ '--config-dir',
+ type=str,
+ default=os.path.join(str(Path.cwd()), 'configs/'),
+ help='Path to the config directory where all the config files '
+ 'are located. Defaults to "configs/"')
+ parser.add_argument(
+ '--batch-mode',
+ action='store_true',
+ help='Whether to use batch mode for inference')
+ parser.add_argument(
+ '--recog-batch-size',
+ type=int,
+ default=0,
+ help='Batch size for text recognition')
+ parser.add_argument(
+ '--det-batch-size',
+ type=int,
+ default=0,
+ help='Batch size for text detection')
+ parser.add_argument(
+ '--single-batch-size',
+ type=int,
+ default=0,
+ help='Batch size for separate det/recog inference')
+ parser.add_argument(
+ '--device', default=None, help='Device used for inference.')
+ parser.add_argument(
+ '--export',
+ type=str,
+ default='',
+ help='Folder where the results of each image are exported')
+ parser.add_argument(
+ '--export-format',
+ type=str,
+ default='json',
+ help='Format of the exported result file(s)')
+ parser.add_argument(
+ '--details',
+ action='store_true',
+ help='Whether to include the text box coordinates and confidence values'
+ )
+ parser.add_argument(
+ '--imshow',
+ action='store_true',
+ help='Whether to show the image with OpenCV.')
+ parser.add_argument(
+ '--print-result',
+ action='store_true',
+ help='Prints the recognised text')
+ parser.add_argument(
+ '--merge', action='store_true', help='Merge neighboring boxes')
+ parser.add_argument(
+ '--merge-xdist',
+ type=float,
+ default=20,
+ help='The maximum x-axis distance to merge boxes')
+ args = parser.parse_args()
+ if args.det == 'None':
+ args.det = None
+ if args.recog == 'None':
+ args.recog = None
+ # Warnings
+ if args.merge and not (args.det and args.recog):
+ warnings.warn(
+ 'Box merging will not work if the script is not'
+ ' running in detection + recognition mode.', UserWarning)
+ if not os.path.samefile(args.config_dir, os.path.join(str(
+ Path.cwd()))) and (args.det_config != ''
+ or args.recog_config != ''):
+ warnings.warn(
+ 'config_dir will be overridden by det-config or recog-config.',
+ UserWarning)
+ return args
+
+
+class MMOCR:
+
+ def __init__(self,
+ det='PANet_IC15',
+ det_config='',
+ det_ckpt='',
+ recog='SEG',
+ recog_config='',
+ recog_ckpt='',
+ kie='',
+ kie_config='',
+ kie_ckpt='',
+ config_dir=os.path.join(str(Path.cwd()), 'configs/'),
+ device=None,
+ **kwargs):
+
+ textdet_models = {
+ 'DB_r18': {
+ 'config':
+ 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
+ 'ckpt':
+ 'dbnet/'
+ 'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
+ },
+ 'DB_r50': {
+ 'config':
+ 'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
+ 'ckpt':
+ 'dbnet/'
+ 'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth'
+ },
+ 'DBPP_r50': {
+ 'config':
+ 'dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py',
+ 'ckpt':
+ 'dbnet/'
+ 'dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth'
+ },
+ 'DRRG': {
+ 'config':
+ 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
+ 'ckpt':
+ 'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth'
+ },
+ 'FCE_IC15': {
+ 'config':
+ 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
+ 'ckpt':
+ 'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth'
+ },
+ 'FCE_CTW_DCNv2': {
+ 'config':
+ 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
+ 'ckpt':
+ 'fcenet/' +
+ 'fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth'
+ },
+ 'MaskRCNN_CTW': {
+ 'config':
+ 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
+ 'ckpt':
+ 'maskrcnn/'
+ 'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
+ },
+ 'MaskRCNN_IC15': {
+ 'config':
+ 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
+ 'ckpt':
+ 'maskrcnn/'
+ 'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
+ },
+ 'MaskRCNN_IC17': {
+ 'config':
+ 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
+ 'ckpt':
+ 'maskrcnn/'
+ 'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
+ },
+ 'PANet_CTW': {
+ 'config':
+ 'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
+ 'ckpt':
+ 'panet/'
+ 'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
+ },
+ 'PANet_IC15': {
+ 'config':
+ 'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
+ 'ckpt':
+ 'panet/'
+ 'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
+ },
+ 'PS_CTW': {
+ 'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
+ 'ckpt':
+ 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
+ },
+ 'PS_IC15': {
+ 'config':
+ 'psenet/psenet_r50_fpnf_600e_icdar2015.py',
+ 'ckpt':
+ 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
+ },
+ 'TextSnake': {
+ 'config':
+ 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
+ 'ckpt':
+ 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
+ },
+ 'Tesseract': {}
+ }
+
+ textrecog_models = {
+ 'CRNN': {
+ 'config': 'crnn/crnn_academic_dataset.py',
+ 'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
+ },
+ 'SAR': {
+ 'config': 'sar/sar_r31_parallel_decoder_academic.py',
+ 'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
+ },
+ 'SAR_CN': {
+ 'config':
+ 'sar/sar_r31_parallel_decoder_chinese.py',
+ 'ckpt':
+ 'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth'
+ },
+ 'NRTR_1/16-1/8': {
+ 'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
+ 'ckpt':
+ 'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
+ },
+ 'NRTR_1/8-1/4': {
+ 'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
+ 'ckpt':
+ 'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
+ },
+ 'RobustScanner': {
+ 'config': 'robust_scanner/robustscanner_r31_academic.py',
+ 'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth'
+ },
+ 'SATRN': {
+ 'config': 'satrn/satrn_academic.py',
+ 'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth'
+ },
+ 'SATRN_sm': {
+ 'config': 'satrn/satrn_small.py',
+ 'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
+ },
+ 'ABINet': {
+ 'config': 'abinet/abinet_academic.py',
+ 'ckpt': 'abinet/abinet_academic-f718abf6.pth'
+ },
+ 'ABINet_Vision': {
+ 'config': 'abinet/abinet_vision_only_academic.py',
+ 'ckpt': 'abinet/abinet_vision_only_academic-e6b9ea89.pth'
+ },
+ 'SEG': {
+ 'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
+ 'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
+ },
+ 'CRNN_TPS': {
+ 'config': 'tps/crnn_tps_academic_dataset.py',
+ 'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
+ },
+ 'Tesseract': {},
+ 'MASTER': {
+ 'config': 'master/master_r31_12e_ST_MJ_SA.py',
+ 'ckpt': 'master/master_r31_12e_ST_MJ_SA-787edd36.pth'
+ }
+ }
+
+ kie_models = {
+ 'SDMGR': {
+ 'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
+ 'ckpt':
+ 'sdmgr/sdmgr_unet16_60e_wildreceipt_20220706-57c220a6.pth'
+ }
+ }
+
+ self.td = det
+ self.tr = recog
+ self.kie = kie
+ self.device = device
+ if self.device is None:
+ self.device = torch.device(
+ 'cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Check if the det/recog model choice is valid
+ if self.td and self.td not in textdet_models:
+ raise ValueError(self.td,
+                             'is not a supported text detection algorithm')
+ elif self.tr and self.tr not in textrecog_models:
+ raise ValueError(self.tr,
+ 'is not a supported text recognition algorithm')
+ elif self.kie:
+ if self.kie not in kie_models:
+ raise ValueError(
+ self.kie, 'is not a supported key information extraction'
+ ' algorithm')
+ elif not (self.td and self.tr):
+ raise NotImplementedError(
+ self.kie, 'has to run together'
+ ' with text detection and recognition algorithms.')
+
+ self.detect_model = None
+ if self.td and self.td == 'Tesseract':
+ if tesserocr is None:
+ raise ImportError('Please install tesserocr first. '
+ 'Check out the installation guide at '
+ 'https://github.com/sirfz/tesserocr')
+ self.detect_model = 'Tesseract_det'
+ elif self.td:
+ # Build detection model
+ if not det_config:
+ det_config = os.path.join(config_dir, 'textdet/',
+ textdet_models[self.td]['config'])
+ if not det_ckpt:
+ det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
+ textdet_models[self.td]['ckpt']
+
+ self.detect_model = init_detector(
+ det_config, det_ckpt, device=self.device)
+ self.detect_model = revert_sync_batchnorm(self.detect_model)
+
+ self.recog_model = None
+ if self.tr and self.tr == 'Tesseract':
+ if tesserocr is None:
+ raise ImportError('Please install tesserocr first. '
+ 'Check out the installation guide at '
+ 'https://github.com/sirfz/tesserocr')
+ self.recog_model = 'Tesseract_recog'
+ elif self.tr:
+ # Build recognition model
+ if not recog_config:
+ recog_config = os.path.join(
+ config_dir, 'textrecog/',
+ textrecog_models[self.tr]['config'])
+ if not recog_ckpt:
+ recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
+ 'textrecog/' + textrecog_models[self.tr]['ckpt']
+
+ self.recog_model = init_detector(
+ recog_config, recog_ckpt, device=self.device)
+ self.recog_model = revert_sync_batchnorm(self.recog_model)
+
+ self.kie_model = None
+ if self.kie:
+ # Build key information extraction model
+ if not kie_config:
+ kie_config = os.path.join(config_dir, 'kie/',
+ kie_models[self.kie]['config'])
+ if not kie_ckpt:
+ kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
+ 'kie/' + kie_models[self.kie]['ckpt']
+
+ kie_cfg = Config.fromfile(kie_config)
+ self.kie_model = build_detector(
+ kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
+ self.kie_model = revert_sync_batchnorm(self.kie_model)
+ self.kie_model.cfg = kie_cfg
+ load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
+
+ # Attribute check
+ for model in list(filter(None, [self.recog_model, self.detect_model])):
+ if hasattr(model, 'module'):
+ model = model.module
+
+ @staticmethod
+ def get_tesserocr_api():
+ """Get tesserocr api depending on different platform."""
+ import subprocess
+ import sys
+
+ if sys.platform == 'linux':
+ api = tesserocr.PyTessBaseAPI()
+ elif sys.platform == 'win32':
+ try:
+ p = subprocess.Popen(
+ 'where tesseract', stdout=subprocess.PIPE, shell=True)
+ s = p.communicate()[0].decode('utf-8').split('\\')
+ path = s[:-1] + ['tessdata']
+ tessdata_path = '/'.join(path)
+ api = tesserocr.PyTessBaseAPI(path=tessdata_path)
+ except RuntimeError:
+ raise RuntimeError(
+ 'Please install tesseract first.\n Check out the'
+ ' installation guide at'
+ ' https://github.com/UB-Mannheim/tesseract/wiki')
+ else:
+ raise NotImplementedError
+ return api
+
+ def tesseract_det_inference(self, imgs, **kwargs):
+ """Inference image(s) with the tesseract detector.
+
+ Args:
+ imgs (ndarray or list[ndarray]): image(s) to inference.
+
+ Returns:
+            dict or list[dict]: Predicted results.
+ """
+ is_batch = True
+ if isinstance(imgs, np.ndarray):
+ is_batch = False
+ imgs = [imgs]
+ assert is_type_list(imgs, np.ndarray)
+ api = self.get_tesserocr_api()
+
+ # Get detection result using tesseract
+ results = []
+ for img in imgs:
+ image = Image.fromarray(img)
+ api.SetImage(image)
+ boxes = api.GetComponentImages(tesserocr.RIL.TEXTLINE, True)
+ boundaries = []
+ for _, box, _, _ in boxes:
+ min_x = box['x']
+ min_y = box['y']
+ max_x = box['x'] + box['w']
+ max_y = box['y'] + box['h']
+ boundary = [
+ min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y, 1.0
+ ]
+ boundaries.append(boundary)
+ results.append({'boundary_result': boundaries})
+
+ # close tesserocr api
+ api.End()
+
+ if not is_batch:
+ return results[0]
+ else:
+ return results
+
+ def tesseract_recog_inference(self, imgs, **kwargs):
+ """Inference image(s) with the tesseract recognizer.
+
+ Args:
+ imgs (ndarray or list[ndarray]): image(s) to inference.
+
+ Returns:
+            dict or list[dict]: Predicted results.
+ """
+ is_batch = True
+ if isinstance(imgs, np.ndarray):
+ is_batch = False
+ imgs = [imgs]
+ assert is_type_list(imgs, np.ndarray)
+ api = self.get_tesserocr_api()
+
+ results = []
+ for img in imgs:
+ image = Image.fromarray(img)
+ api.SetImage(image)
+ api.SetRectangle(0, 0, img.shape[1], img.shape[0])
+            # Remove leading and trailing spaces from the Tesseract output
+ text = api.GetUTF8Text().strip()
+ conf = api.MeanTextConf() / 100
+ results.append({'text': text, 'score': conf})
+
+ # close tesserocr api
+ api.End()
+
+ if not is_batch:
+ return results[0]
+ else:
+ return results
+
+ def readtext(self,
+ img,
+ output=None,
+ details=False,
+ export=None,
+ export_format='json',
+ batch_mode=False,
+ recog_batch_size=0,
+ det_batch_size=0,
+ single_batch_size=0,
+ imshow=False,
+ print_result=False,
+ merge=False,
+ merge_xdist=20,
+ **kwargs):
+ args = locals().copy()
+ [args.pop(x, None) for x in ['kwargs', 'self']]
+ args = Namespace(**args)
+
+ # Input and output arguments processing
+ self._args_processing(args)
+ self.args = args
+
+ pp_result = None
+
+ # Send args and models to the MMOCR model inference API
+ # and call post-processing functions for the output
+ if self.detect_model and self.recog_model:
+ det_recog_result = self.det_recog_kie_inference(
+ self.detect_model, self.recog_model, kie_model=self.kie_model)
+ pp_result = self.det_recog_pp(det_recog_result)
+ else:
+ for model in list(
+ filter(None, [self.recog_model, self.detect_model])):
+ result = self.single_inference(model, args.arrays,
+ args.batch_mode,
+ args.single_batch_size)
+ pp_result = self.single_pp(result, model)
+
+ return pp_result
+
+ # Post processing function for end2end ocr
+ def det_recog_pp(self, result):
+ final_results = []
+ args = self.args
+ for arr, output, export, det_recog_result in zip(
+ args.arrays, args.output, args.export, result):
+ if output or args.imshow:
+ if self.kie_model:
+ res_img = det_recog_show_result(arr, det_recog_result)
+ else:
+ res_img = det_recog_show_result(
+ arr, det_recog_result, out_file=output)
+ if args.imshow and not self.kie_model:
+ dbnet_cv.imshow(res_img, 'inference results')
+ if not args.details:
+ simple_res = {}
+ simple_res['filename'] = det_recog_result['filename']
+ simple_res['text'] = [
+ x['text'] for x in det_recog_result['result']
+ ]
+ final_result = simple_res
+ else:
+ final_result = det_recog_result
+ if export:
+ dbnet_cv.dump(final_result, export, indent=4)
+ if args.print_result:
+ print(final_result, end='\n\n')
+ final_results.append(final_result)
+ return final_results
+
+ # Post processing function for separate det/recog inference
+ def single_pp(self, result, model):
+ for arr, output, export, res in zip(self.args.arrays, self.args.output,
+ self.args.export, result):
+ if export:
+ dbnet_cv.dump(res, export, indent=4)
+ if output or self.args.imshow:
+ if model == 'Tesseract_det':
+ res_img = TextDetectorMixin(show_score=False).show_result(
+ arr, res, out_file=output)
+ elif model == 'Tesseract_recog':
+ res_img = BaseRecognizer.show_result(
+ arr, res, out_file=output)
+ else:
+ res_img = model.show_result(arr, res, out_file=output)
+ if self.args.imshow:
+ dbnet_cv.imshow(res_img, 'inference results')
+ if self.args.print_result:
+ print(res, end='\n\n')
+ return result
+
+ def generate_kie_labels(self, result, boxes, class_list):
+ idx_to_cls = {}
+ if class_list is not None:
+ for line in list_from_file(class_list):
+ class_idx, class_label = line.strip().split()
+ idx_to_cls[class_idx] = class_label
+
+ max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
+ node_pred_label = max_idx.numpy().tolist()
+ node_pred_score = max_value.numpy().tolist()
+ labels = []
+ for i in range(len(boxes)):
+ pred_label = str(node_pred_label[i])
+ if pred_label in idx_to_cls:
+ pred_label = idx_to_cls[pred_label]
+ pred_score = node_pred_score[i]
+ labels.append((pred_label, pred_score))
+ return labels
+
+ def visualize_kie_output(self,
+ model,
+ data,
+ result,
+ out_file=None,
+ show=False):
+ """Visualizes KIE output."""
+ img_tensor = data['img'].data
+ img_meta = data['img_metas'].data
+ gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
+ if img_tensor.dtype == torch.uint8:
+ # The img tensor is the raw input not being normalized
+ # (For SDMGR non-visual)
+ img = img_tensor.cpu().numpy().transpose(1, 2, 0)
+ else:
+ img = tensor2imgs(
+ img_tensor.unsqueeze(0), **img_meta.get('img_norm_cfg', {}))[0]
+ h, w, _ = img_meta.get('img_shape', img.shape)
+ img_show = img[:h, :w, :]
+ model.show_result(
+ img_show, result, gt_bboxes, show=show, out_file=out_file)
+
+ # End2end ocr inference pipeline
+ def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
+ end2end_res = []
+ # Find bounding boxes in the images (text detection)
+ det_result = self.single_inference(det_model, self.args.arrays,
+ self.args.batch_mode,
+ self.args.det_batch_size)
+ bboxes_list = [res['boundary_result'] for res in det_result]
+
+ if kie_model:
+ kie_dataset = KIEDataset(
+ dict_file=kie_model.cfg.data.test.dict_file)
+
+ # For each bounding box, the image is cropped and
+ # sent to the recognition model either one by one
+ # or all together depending on the batch_mode
+ for filename, arr, bboxes, out_file in zip(self.args.filenames,
+ self.args.arrays,
+ bboxes_list,
+ self.args.output):
+ img_e2e_res = {}
+ img_e2e_res['filename'] = filename
+ img_e2e_res['result'] = []
+ box_imgs = []
+ for bbox in bboxes:
+ box_res = {}
+ box_res['box'] = [round(x) for x in bbox[:-1]]
+ box_res['box_score'] = float(bbox[-1])
+ box = bbox[:8]
+ if len(bbox) > 9:
+ min_x = min(bbox[0:-1:2])
+ min_y = min(bbox[1:-1:2])
+ max_x = max(bbox[0:-1:2])
+ max_y = max(bbox[1:-1:2])
+ box = [
+ min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
+ ]
+ box_img = crop_img(arr, box)
+ if self.args.batch_mode:
+ box_imgs.append(box_img)
+ else:
+ if recog_model == 'Tesseract_recog':
+ recog_result = self.single_inference(
+ recog_model, box_img, batch_mode=True)
+ else:
+ recog_result = model_inference(recog_model, box_img)
+ text = recog_result['text']
+ text_score = recog_result['score']
+ if isinstance(text_score, list):
+ text_score = sum(text_score) / max(1, len(text))
+ box_res['text'] = text
+ box_res['text_score'] = text_score
+ img_e2e_res['result'].append(box_res)
+
+ if self.args.batch_mode:
+ recog_results = self.single_inference(
+ recog_model, box_imgs, True, self.args.recog_batch_size)
+ for i, recog_result in enumerate(recog_results):
+ text = recog_result['text']
+ text_score = recog_result['score']
+ if isinstance(text_score, (list, tuple)):
+ text_score = sum(text_score) / max(1, len(text))
+ img_e2e_res['result'][i]['text'] = text
+ img_e2e_res['result'][i]['text_score'] = text_score
+
+ if self.args.merge:
+ img_e2e_res['result'] = stitch_boxes_into_lines(
+ img_e2e_res['result'], self.args.merge_xdist, 0.5)
+
+ if kie_model:
+ annotations = copy.deepcopy(img_e2e_res['result'])
+ # Customized for kie_dataset, which
+ # assumes that boxes are represented by only 4 points
+ for i, ann in enumerate(annotations):
+ min_x = min(ann['box'][::2])
+ min_y = min(ann['box'][1::2])
+ max_x = max(ann['box'][::2])
+ max_y = max(ann['box'][1::2])
+ annotations[i]['box'] = [
+ min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
+ ]
+ ann_info = kie_dataset._parse_anno_info(annotations)
+ ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
+ ann_info['bboxes'])
+ ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
+ ann_info['bboxes'])
+ kie_result, data = model_inference(
+ kie_model,
+ arr,
+ ann=ann_info,
+ return_data=True,
+ batch_mode=self.args.batch_mode)
+ # visualize KIE results
+ self.visualize_kie_output(
+ kie_model,
+ data,
+ kie_result,
+ out_file=out_file,
+ show=self.args.imshow)
+ gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
+ labels = self.generate_kie_labels(kie_result, gt_bboxes,
+ kie_model.class_list)
+ for i in range(len(gt_bboxes)):
+ img_e2e_res['result'][i]['label'] = labels[i][0]
+ img_e2e_res['result'][i]['label_score'] = labels[i][1]
+
+ end2end_res.append(img_e2e_res)
+ return end2end_res
+
+ # Separate det/recog inference pipeline
+ def single_inference(self, model, arrays, batch_mode, batch_size=0):
+
+ def inference(m, a, **kwargs):
+ if model == 'Tesseract_det':
+ return self.tesseract_det_inference(a)
+ elif model == 'Tesseract_recog':
+ return self.tesseract_recog_inference(a)
+ else:
+ return model_inference(m, a, **kwargs)
+
+ result = []
+ if batch_mode:
+ if batch_size == 0:
+ result = inference(model, arrays, batch_mode=True)
+ else:
+ n = batch_size
+ arr_chunks = [
+ arrays[i:i + n] for i in range(0, len(arrays), n)
+ ]
+ for chunk in arr_chunks:
+ result.extend(inference(model, chunk, batch_mode=True))
+ else:
+ for arr in arrays:
+ result.append(inference(model, arr, batch_mode=False))
+ return result
+
+ # Arguments pre-processing function
+ def _args_processing(self, args):
+ # Check if the input is a list/tuple that
+ # contains only np arrays or strings
+ if isinstance(args.img, (list, tuple)):
+ img_list = args.img
+ if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
+ raise AssertionError('Images must be strings or numpy arrays')
+
+ # Create a list of the images
+ if isinstance(args.img, str):
+ img_path = Path(args.img)
+ if img_path.is_dir():
+ img_list = [str(x) for x in img_path.glob('*')]
+ else:
+ img_list = [str(img_path)]
+ elif isinstance(args.img, np.ndarray):
+ img_list = [args.img]
+
+ # Read all image(s) in advance to reduce wasted time
+ # re-reading the images for visualization output
+ args.arrays = [dbnet_cv.imread(x) for x in img_list]
+
+ # Create a list of filenames (used for output images and result files)
+ if isinstance(img_list[0], str):
+ args.filenames = [str(Path(x).stem) for x in img_list]
+ else:
+ args.filenames = [str(x) for x in range(len(img_list))]
+
+ # If given an output argument, create a list of output image filenames
+ num_res = len(img_list)
+ if args.output:
+ output_path = Path(args.output)
+ if output_path.is_dir():
+ args.output = [
+ str(output_path / f'out_{x}.png') for x in args.filenames
+ ]
+ else:
+ args.output = [str(args.output)]
+ if args.batch_mode:
+ raise AssertionError('Output of multiple images inference'
+ ' must be a directory')
+ else:
+ args.output = [None] * num_res
+
+ # If given an export argument, create a list of
+ # result filenames for each image
+ if args.export:
+ export_path = Path(args.export)
+ args.export = [
+ str(export_path / f'out_{x}.{args.export_format}')
+ for x in args.filenames
+ ]
+ else:
+ args.export = [None] * num_res
+
+ return args
+
+
+# Create an inference pipeline with parsed arguments
+def main():
+ args = parse_args()
+ ocr = MMOCR(**vars(args))
+ ocr.readtext(**vars(args))
+
+
+if __name__ == '__main__':
+ main()
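+
+# Python API sketch (illustrative; the image path is a placeholder): the same
+# pipeline can also be driven programmatically instead of through the CLI:
+#
+#   ocr = MMOCR(det='DB_r18', recog='CRNN')
+#   results = ocr.readtext('path/to/image.jpg', print_result=True)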
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py b/cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py
new file mode 100755
index 0000000000000000000000000000000000000000..21def2f0809153a5f755af2431f7e702db625e5c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import platform
+import warnings
+
+import cv2
+import torch.multiprocessing as mp
+
+
+def setup_multi_processes(cfg):
+ """Setup multi-processing environment variables."""
+ # set multi-process start method as `fork` to speed up the training
+ if platform.system() != 'Windows':
+ mp_start_method = cfg.get('mp_start_method', 'fork')
+ current_method = mp.get_start_method(allow_none=True)
+ if current_method is not None and current_method != mp_start_method:
+ warnings.warn(
+ f'Multi-processing start method `{mp_start_method}` is '
+                f'different from the previous setting `{current_method}`. '
+                f'It will be forcibly set to `{mp_start_method}`. You can '
+                f'change this behavior by changing `mp_start_method` in '
+                f'your config.')
+ mp.set_start_method(mp_start_method, force=True)
+
+ # disable opencv multithreading to avoid system being overloaded
+ opencv_num_threads = cfg.get('opencv_num_threads', 0)
+ cv2.setNumThreads(opencv_num_threads)
+
+ # setup OMP threads
+    # This code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
+ if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
+ omp_num_threads = 1
+ warnings.warn(
+            f'Setting OMP_NUM_THREADS environment variable for each process '
+            f'to {omp_num_threads} by default, to avoid your system being '
+            f'overloaded. Please further tune the variable for optimal '
+            f'performance in your application as needed.')
+ os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
+
+ # setup MKL threads
+ if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
+ mkl_num_threads = 1
+ warnings.warn(
+            f'Setting MKL_NUM_THREADS environment variable for each process '
+            f'to {mkl_num_threads} by default, to avoid your system being '
+            f'overloaded. Please further tune the variable for optimal '
+            f'performance in your application as needed.')
+ os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
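+
+
+if __name__ == '__main__':
+    # Minimal smoke test (illustrative only): real callers pass the
+    # mmcv-style Config used elsewhere in this repo; here a tiny stand-in
+    # object provides the two attributes the function actually reads.
+    from types import SimpleNamespace
+
+    _options = {'mp_start_method': 'fork', 'opencv_num_threads': 0}
+    _cfg = SimpleNamespace(get=_options.get,
+                           data=SimpleNamespace(workers_per_gpu=2))
+    setup_multi_processes(_cfg)
+    print('OMP_NUM_THREADS =', os.environ.get('OMP_NUM_THREADS'))
+    print('MKL_NUM_THREADS =', os.environ.get('MKL_NUM_THREADS'))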
diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/string_util.py b/cv/ocr/dbnet/pytorch/dbnet/utils/string_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..5a8946ee6969074ebad50747758ec919d611e933
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/utils/string_util.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+class StringStrip:
+ """Removing the leading and/or the trailing characters based on the string
+ argument passed.
+
+ Args:
+ strip (bool): Whether remove characters from both left and right of
+ the string. Default: True.
+        strip_pos (str): Which side to strip from, can be one of
+            ('both', 'left', 'right'). Default: 'both'.
+ strip_str (str|None): A string specifying the set of characters
+ to be removed from the left and right part of the string.
+ If None, all leading and trailing whitespaces
+ are removed from the string. Default: None.
+ """
+
+ def __init__(self, strip=True, strip_pos='both', strip_str=None):
+ assert isinstance(strip, bool)
+ assert strip_pos in ('both', 'left', 'right')
+ assert strip_str is None or isinstance(strip_str, str)
+
+ self.strip = strip
+ self.strip_pos = strip_pos
+ self.strip_str = strip_str
+
+ def __call__(self, in_str):
+
+ if not self.strip:
+ return in_str
+
+ if self.strip_pos == 'left':
+ return in_str.lstrip(self.strip_str)
+ elif self.strip_pos == 'right':
+ return in_str.rstrip(self.strip_str)
+ else:
+ return in_str.strip(self.strip_str)
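+
+
+if __name__ == '__main__':
+    # Usage sketch: strip only trailing newline characters.
+    strip_op = StringStrip(strip_pos='right', strip_str='\n')
+    assert strip_op('DBNet detects text\n') == 'DBNet detects text'
+    # With the defaults, surrounding whitespace is removed from both sides.
+    assert StringStrip()('  hello  ') == 'hello'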
diff --git a/cv/ocr/dbnet/pytorch/dbnet/version.py b/cv/ocr/dbnet/pytorch/dbnet/version.py
new file mode 100755
index 0000000000000000000000000000000000000000..3626dcd19fc26ea9d2a6f7b4c78aae3c51a0a856
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet/version.py
@@ -0,0 +1,4 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+
+__version__ = '0.6.0'
+short_version = __version__
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..65b6016ceb237276f13baab4fba2ab2a8c266426
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+# from .arraymisc import *
+from .fileio import *
+from .image import *
+from .utils import *
+from .version import *
+# from .video import *
+# from .visualization import *
+
+# The following modules are not imported to this level, so dbnet_cv may be used
+# without PyTorch.
+# - runner
+# - parallel
+# - op
+# - device
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..d2e8f0990337f09f53c7454aff0e4b72b76d7b97
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# from .alexnet import AlexNet
+# yapf: disable
+# from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+# PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
+# ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+# ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+# DepthwiseSeparableConvModule, GeneralizedAttention,
+# HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+# NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+# build_activation_layer, build_conv_layer,
+# build_norm_layer, build_padding_layer, build_plugin_layer,
+# build_upsample_layer, conv_ws_2d, is_norm)
+from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+                     PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS, Conv2d,
+                     Conv3d, ConvModule, ConvTranspose2d, ConvTranspose3d,
+                     Linear, MaxPool2d, MaxPool3d, build_activation_layer,
+                     build_conv_layer, build_norm_layer, build_padding_layer,
+                     build_plugin_layer, build_upsample_layer, is_norm)
+from .builder import MODELS, build_model_from_cfg
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ constant_init, fuse_conv_bn, get_model_complexity_info,
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
+# from .vgg import VGG, make_vgg_layer
+
+# __all__ = [
+# 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+# 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+# 'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+# 'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+# 'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+# 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+# 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+# 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+# 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+# 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+# 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+# 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+# 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+# 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+# 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+# ]
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..d1e17d38c9c54a084b995ce69ab27ed993788c49
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .activation import build_activation_layer
+
+from .conv import build_conv_layer
+from .conv_module import ConvModule
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .drop import Dropout, DropPath
+from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+ Linear, MaxPool2d, MaxPool3d)
+
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/activation.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/activation.py
new file mode 100755
index 0000000000000000000000000000000000000000..d791a04c0fdfff6aabb1ddddf3b3ec1786960402
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/activation.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from dbnet_cv.utils import TORCH_VERSION, build_from_cfg, digit_version
+from .registry import ACTIVATION_LAYERS
+
+for module in [
+ nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+ nn.Sigmoid, nn.Tanh
+]:
+ ACTIVATION_LAYERS.register_module(module=module)
+
+
+@ACTIVATION_LAYERS.register_module(name='Clip')
+@ACTIVATION_LAYERS.register_module()
+class Clamp(nn.Module):
+ """Clamp activation layer.
+
+    This activation function clamps the feature map values within
+    :math:`[min, max]`. More details can be found in ``torch.clamp()``.
+
+    Args:
+        min (Number, optional): Lower bound of the range to be clamped to.
+            Defaults to -1.
+        max (Number, optional): Upper bound of the range to be clamped to.
+            Defaults to 1.
+ """
+
+ def __init__(self, min=-1., max=1.):
+ super().__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: Clamped tensor.
+ """
+ return torch.clamp(x, min=self.min, max=self.max)
+
+
+class GELU(nn.Module):
+ r"""Applies the Gaussian Error Linear Units function:
+
+    .. math::
+        \text{GELU}(x) = x * \Phi(x)
+
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for
+    Gaussian Distribution.
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ Examples::
+
+ >>> m = nn.GELU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ def forward(self, input):
+ return F.gelu(input)
+
+
+if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.4')):
+ ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+ ACTIVATION_LAYERS.register_module(module=nn.GELU)
+
+
+def build_activation_layer(cfg):
+ """Build activation layer.
+
+ Args:
+ cfg (dict): The activation layer config, which should contain:
+
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an activation layer.
+
+ Returns:
+ nn.Module: Created activation layer.
+ """
+ return build_from_cfg(cfg, ACTIVATION_LAYERS)
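+
+
+if __name__ == '__main__':
+    # Usage sketch: build activation layers from config dicts via the
+    # registry populated above.
+    relu = build_activation_layer(dict(type='ReLU', inplace=True))
+    clamp = build_activation_layer(dict(type='Clamp', min=0., max=6.))
+    out = clamp(torch.randn(2, 3))
+    assert isinstance(relu, nn.ReLU)
+    assert float(out.min()) >= 0. and float(out.max()) <= 6.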
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py
new file mode 100755
index 0000000000000000000000000000000000000000..f6c35fd70b3c6ae0812432e363adbc5711e71b18
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
+from .registry import CONV_LAYERS
+
+CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+
+
+def build_conv_layer(cfg, *args, **kwargs):
+ """Build convolution layer.
+
+ Args:
+ cfg (None or dict): The conv layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an conv layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding conv layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding conv layer.
+
+ Returns:
+ nn.Module: Created conv layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Conv2d')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in CONV_LAYERS:
+ raise KeyError(f'Unrecognized layer type {layer_type}')
+ else:
+ conv_layer = CONV_LAYERS.get(layer_type)
+
+ layer = conv_layer(*args, **kwargs, **cfg_)
+
+ return layer
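+
+
+if __name__ == '__main__':
+    # Usage sketch: cfg=None falls back to a plain Conv2d; positional and
+    # keyword args are forwarded to the layer's __init__.
+    conv = build_conv_layer(None, 3, 16, kernel_size=3, padding=1)
+    assert isinstance(conv, nn.Conv2d)
+    assert build_conv_layer(dict(type='Conv1d'), 8, 8, 3).kernel_size == (3, )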
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv2d_adaptive_padding.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv2d_adaptive_padding.py
new file mode 100755
index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv2d_adaptive_padding.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .registry import CONV_LAYERS
+
+
+@CONV_LAYERS.register_module()
+class Conv2dAdaptivePadding(nn.Conv2d):
+ """Implementation of 2D convolution in tensorflow with `padding` as "same",
+ which applies padding to input (if needed) so that input image gets fully
+ covered by filter and stride you specified. For stride 1, this will ensure
+ that output image size is same as input. For stride of 2, output dimensions
+ will be half, for example.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the
+ output. Default: ``True``
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True):
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+ dilation, groups, bias)
+
+ def forward(self, x):
+ img_h, img_w = x.size()[-2:]
+ kernel_h, kernel_w = self.weight.size()[-2:]
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(img_h / stride_h)
+ output_w = math.ceil(img_w / stride_w)
+ pad_h = (
+ max((output_h - 1) * self.stride[0] +
+ (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
+ pad_w = (
+ max((output_w - 1) * self.stride[1] +
+ (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
+ ])
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
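+
+
+if __name__ == '__main__':
+    # Usage sketch: the output spatial size is ceil(input / stride), as with
+    # TensorFlow "same" padding, regardless of the input size parity.
+    import torch
+
+    conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
+    out = conv(torch.randn(1, 3, 15, 15))
+    assert out.shape[-2:] == (8, 8)  # ceil(15 / 2) == 8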
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv_module.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv_module.py
new file mode 100755
index 0000000000000000000000000000000000000000..5c27f6e39bb8cce8abc49d7283fc5941d24ec204
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv_module.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+
+from dbnet_cv.utils import _BatchNorm, _InstanceNorm
+from ..utils import constant_init, kaiming_init
+from .activation import build_activation_layer
+from .conv import build_conv_layer
+from .norm import build_norm_layer
+from .padding import build_padding_layer
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+
+ Besides, we add some additional features in this module.
+ 1. Automatically set `bias` of the conv layer.
+ 2. Spectral norm is supported.
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+ supports zero and circular padding, and we add "reflect" padding mode.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+
+ _abbr_ = 'conv_block'
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias='auto',
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ inplace=True,
+ with_spectral_norm=False,
+ padding_mode='zeros',
+ order=('conv', 'norm', 'act')):
+ super().__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ assert act_cfg is None or isinstance(act_cfg, dict)
+ official_padding_mode = ['zeros', 'circular']
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == {'conv', 'norm', 'act'}
+
+ self.with_norm = norm_cfg is not None
+ self.with_activation = act_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = not self.with_norm
+ self.with_bias = bias
+
+ if self.with_explicit_padding:
+ pad_cfg = dict(type=padding_mode)
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index('norm') > order.index('conv'):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+ self.add_module(self.norm_name, norm)
+ if self.with_bias:
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+ warnings.warn(
+ 'Unnecessary conv bias before batch/instance norm')
+ else:
+ self.norm_name = None
+
+ # build activation layer
+ if self.with_activation:
+ act_cfg_ = act_cfg.copy()
+ # nn.Tanh has no 'inplace' argument
+ if act_cfg_['type'] not in [
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
+ ]:
+ act_cfg_.setdefault('inplace', inplace)
+ self.activate = build_activation_layer(act_cfg_)
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, 'init_weights'):
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ a = self.act_cfg.get('negative_slope', 0.01)
+ else:
+ nonlinearity = 'relu'
+ a = 0
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+
+ def forward(self, x, activate=True, norm=True):
+ for layer in self.order:
+ if layer == 'conv':
+ if self.with_explicit_padding:
+ x = self.padding_layer(x)
+ x = self.conv(x)
+ elif layer == 'norm' and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == 'act' and activate and self.with_activation:
+ x = self.activate(x)
+ return x
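+
+
+if __name__ == '__main__':
+    # Usage sketch: a conv + BN + ReLU block built from config dicts. The
+    # conv bias is disabled automatically because a norm layer follows it.
+    import torch
+
+    block = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))
+    assert not block.with_bias and block.norm is not None
+    out = block(torch.randn(2, 3, 32, 32))
+    assert out.shape == (2, 16, 32, 32)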
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py
new file mode 100755
index 0000000000000000000000000000000000000000..ff29bac8d1c64346843efafc93da34d1eba8273a
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from dbnet_cv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+
+
+def drop_path(x: torch.Tensor,
+ drop_prob: float = 0.,
+ training: bool = False) -> torch.Tensor:
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ # handle tensors with different dimensions, not just 4D tensors.
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ output = x.div(keep_prob) * random_tensor.floor()
+ return output
+
+
+@DROPOUT_LAYERS.register_module()
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+ Args:
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+ """
+
+ def __init__(self, drop_prob: float = 0.1):
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+
+@DROPOUT_LAYERS.register_module()
+class Dropout(nn.Dropout):
+ """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
+ ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
+ ``DropPath``
+ Args:
+ drop_prob (float): Probability of the elements to be
+ zeroed. Default: 0.5.
+ inplace (bool): Do the operation inplace or not. Default: False.
+ """
+
+ def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
+ super().__init__(p=drop_prob, inplace=inplace)
+
+
+def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
+ """Builder for drop out layers."""
+ return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hsigmoid.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hsigmoid.py
new file mode 100755
index 0000000000000000000000000000000000000000..1854517c5342d97b64254a646745cf4671b71551
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hsigmoid.py
@@ -0,0 +1,48 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSigmoid(nn.Module):
+ """Hard Sigmoid Module. Apply the hard sigmoid function:
+ Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+ Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)
+ Note:
+ In MMCV v1.4.4, we modified the default value of args to align with
+ PyTorch official.
+ Args:
+ bias (float): Bias of the input feature map. Default: 3.0.
+ divisor (float): Divisor of the input feature map. Default: 6.0.
+ min_value (float): Lower bound value. Default: 0.0.
+ max_value (float): Upper bound value. Default: 1.0.
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self,
+ bias: float = 3.0,
+ divisor: float = 6.0,
+ min_value: float = 0.0,
+ max_value: float = 1.0):
+ super().__init__()
+ warnings.warn(
+ 'In MMCV v1.4.4, we modified the default value of args to align '
+ 'with PyTorch official. Previous Implementation: '
+ 'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
+ 'Current Implementation: '
+ 'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
+ self.bias = bias
+ self.divisor = divisor
+ assert self.divisor != 0
+ self.min_value = min_value
+ self.max_value = max_value
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = (x + self.bias) / self.divisor
+
+ return x.clamp_(self.min_value, self.max_value)
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hswish.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hswish.py
new file mode 100755
index 0000000000000000000000000000000000000000..1d3ee70e2e9fd6cecdaa7fc59328197660f012f8
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hswish.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from dbnet_cv.utils import TORCH_VERSION, digit_version
+from .registry import ACTIVATION_LAYERS
+
+
+class HSwish(nn.Module):
+ """Hard Swish Module.
+ This module applies the hard swish function:
+ .. math::
+ Hswish(x) = x * ReLU6(x + 3) / 6
+ Args:
+ inplace (bool): can optionally do the operation in-place.
+ Default: False.
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self, inplace: bool = False):
+ super().__init__()
+ self.act = nn.ReLU6(inplace)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x * self.act(x + 3) / 6
+
+
+if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.7')):
+ # Hardswish is not supported when PyTorch version < 1.6.
+ # And Hardswish in PyTorch 1.6 does not support inplace.
+ ACTIVATION_LAYERS.register_module(module=HSwish)
+else:
+ ACTIVATION_LAYERS.register_module(module=nn.Hardswish, name='HSwish')
\ No newline at end of file
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py
new file mode 100755
index 0000000000000000000000000000000000000000..9516d812c8a015b3f859453e2e30dde2ec1a47e9
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+
+import torch.nn as nn
+
+from dbnet_cv.utils import is_tuple_of
+from dbnet_cv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+
+
+def infer_abbr(class_type):
+ """Infer abbreviation from the class name.
+
+ When we build a norm layer with `build_norm_layer()`, we want to preserve
+    the norm type in variable names, e.g., self.bn1, self.gn. This method will
+ infer the abbreviation to map class types to abbreviations.
+
+ Rule 1: If the class has the property "_abbr_", return the property.
+ Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
+ InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
+ "in" respectively.
+ Rule 3: If the class name contains "batch", "group", "layer" or "instance",
+ the abbreviation of this layer will be "bn", "gn", "ln" and "in"
+ respectively.
+    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
+
+ Args:
+ class_type (type): The norm layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_
+ if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
+ return 'in'
+ elif issubclass(class_type, _BatchNorm):
+ return 'bn'
+ elif issubclass(class_type, nn.GroupNorm):
+ return 'gn'
+ elif issubclass(class_type, nn.LayerNorm):
+ return 'ln'
+ else:
+ class_name = class_type.__name__.lower()
+ if 'batch' in class_name:
+ return 'bn'
+ elif 'group' in class_name:
+ return 'gn'
+ elif 'layer' in class_name:
+ return 'ln'
+ elif 'instance' in class_name:
+ return 'in'
+ else:
+ return 'norm_layer'
+
+
+def build_norm_layer(cfg, num_features, postfix=''):
+ """Build normalization layer.
+
+ Args:
+ cfg (dict): The norm layer config, which should contain:
+
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a norm layer.
+            - requires_grad (bool, optional): Whether the layer parameters
+              require gradient updates. Default: True.
+ num_features (int): Number of input channels.
+ postfix (int | str): The postfix to be appended into norm abbreviation
+ to create named layer.
+
+ Returns:
+ tuple[str, nn.Module]: The first element is the layer name consisting
+ of abbreviation and postfix, e.g., bn1, gn. The second element is the
+ created norm layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in NORM_LAYERS:
+ raise KeyError(f'Unrecognized norm type {layer_type}')
+
+ norm_layer = NORM_LAYERS.get(layer_type)
+ abbr = infer_abbr(norm_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ requires_grad = cfg_.pop('requires_grad', True)
+ cfg_.setdefault('eps', 1e-5)
+ if layer_type != 'GN':
+ layer = norm_layer(num_features, **cfg_)
+ if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+ layer._specify_ddp_gpu_num(1)
+ else:
+ assert 'num_groups' in cfg_
+ layer = norm_layer(num_channels=num_features, **cfg_)
+
+ for param in layer.parameters():
+ param.requires_grad = requires_grad
+
+ return name, layer
+
+
+def is_norm(layer, exclude=None):
+ """Check if a layer is a normalization layer.
+
+ Args:
+ layer (nn.Module): The layer to be checked.
+ exclude (type | tuple[type]): Types to be excluded.
+
+ Returns:
+ bool: Whether the layer is a norm layer.
+ """
+ if exclude is not None:
+ if not isinstance(exclude, tuple):
+ exclude = (exclude, )
+ if not is_tuple_of(exclude, type):
+ raise TypeError(
+ f'"exclude" must be either None or type or a tuple of types, '
+ f'but got {type(exclude)}: {exclude}')
+
+ if exclude and isinstance(layer, exclude):
+ return False
+
+ all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+ return isinstance(layer, all_norm_bases)
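+
+
+if __name__ == '__main__':
+    # Usage sketch: the returned name combines the inferred abbreviation
+    # with the postfix; `requires_grad` toggles the layer parameters.
+    name, bn = build_norm_layer(dict(type='BN', requires_grad=False), 64,
+                                postfix=1)
+    assert name == 'bn1' and isinstance(bn, nn.BatchNorm2d)
+    assert not next(bn.parameters()).requires_grad
+    assert is_norm(bn) and not is_norm(bn, exclude=_BatchNorm)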
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/padding.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/padding.py
new file mode 100755
index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/padding.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import PADDING_LAYERS
+
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+
+
+def build_padding_layer(cfg, *args, **kwargs):
+ """Build padding layer.
+
+ Args:
+ cfg (None or dict): The padding layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a padding layer.
+
+ Returns:
+ nn.Module: Created padding layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+
+ cfg_ = cfg.copy()
+ padding_type = cfg_.pop('type')
+ if padding_type not in PADDING_LAYERS:
+ raise KeyError(f'Unrecognized padding type {padding_type}.')
+ else:
+ padding_layer = PADDING_LAYERS.get(padding_type)
+
+ layer = padding_layer(*args, **kwargs, **cfg_)
+
+ return layer
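+
+
+if __name__ == '__main__':
+    # Usage sketch: extra positional args are forwarded to the padding
+    # layer, here the per-side padding size of ReflectionPad2d.
+    pad = build_padding_layer(dict(type='reflect'), 2)
+    assert isinstance(pad, nn.ReflectionPad2d) and pad.padding == (2, 2, 2, 2)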
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/plugin.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/plugin.py
new file mode 100755
index 0000000000000000000000000000000000000000..6aa13f439ba6d55db6cf45c0e94b7b895c82d0ec
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/plugin.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import platform
+
+from .registry import PLUGIN_LAYERS
+
+if platform.system() == 'Windows':
+ import regex as re # type: ignore
+else:
+ import re # type: ignore
+
+
+def infer_abbr(class_type):
+ """Infer abbreviation from the class name.
+
+ This method will infer the abbreviation to map class types to
+ abbreviations.
+
+    Rule 1: If the class has the property "_abbr_", return the property.
+ Rule 2: Otherwise, the abbreviation falls back to snake case of class
+ name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
+
+ Args:
+ class_type (type): The norm layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+
+ def camel2snack(word):
+ """Convert camel case word into snack case.
+
+ Modified from `inflection lib
+ `_.
+
+ Example::
+
+ >>> camel2snack("FancyBlock")
+ 'fancy_block'
+ """
+
+ word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
+ word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
+ word = word.replace('-', '_')
+ return word.lower()
+
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_
+ else:
+ return camel2snack(class_type.__name__)
+
+
+def build_plugin_layer(cfg, postfix='', **kwargs):
+ """Build plugin layer.
+
+ Args:
+ cfg (None or dict): cfg should contain:
+
+ - type (str): identify plugin layer type.
+ - layer args: args needed to instantiate a plugin layer.
+        postfix (int, str): appended to the abbreviation to create the
+            named layer. Default: ''.
+
+ Returns:
+ tuple[str, nn.Module]: The first one is the concatenation of
+ abbreviation and postfix. The second is the created plugin layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in PLUGIN_LAYERS:
+ raise KeyError(f'Unrecognized plugin type {layer_type}')
+
+ plugin_layer = PLUGIN_LAYERS.get(layer_type)
+ abbr = infer_abbr(plugin_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ layer = plugin_layer(**kwargs, **cfg_)
+
+ return name, layer
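+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module. A registered
+    # plugin would be built as, e.g.,
+    #     name, layer = build_plugin_layer(dict(type='SomePlugin', ...), 1)
+    # where 'SomePlugin' is a hypothetical entry in PLUGIN_LAYERS and ``name``
+    # is its inferred abbreviation plus the postfix. ``infer_abbr`` itself can
+    # be exercised with any class:
+    class FancyBlock:
+        pass
+
+    print(infer_abbr(FancyBlock))  # -> 'fancy_block'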
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/registry.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/registry.py
new file mode 100755
index 0000000000000000000000000000000000000000..b56fbe1816fefb3bf6d70e480419e36d05883850
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/registry.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dbnet_cv.utils import Registry
+
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
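+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module, assuming the
+    # Registry mirrors MMCV's decorator interface: third-party layers are
+    # registered by name and later looked up by the corresponding builders.
+    import torch.nn as nn
+
+    @ACTIVATION_LAYERS.register_module(name='DemoReLU')
+    class DemoReLU(nn.ReLU):
+        """A hypothetical activation registered only for this demo."""
+
+    print('DemoReLU' in ACTIVATION_LAYERS)  # -> True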
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/upsample.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/upsample.py
new file mode 100755
index 0000000000000000000000000000000000000000..15e4febde52fde324c5576ffae36a044e45ae9de
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/upsample.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import xavier_init
+from .registry import UPSAMPLE_LAYERS
+
+UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+
+
+@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+class PixelShufflePack(nn.Module):
+ """Pixel Shuffle upsample layer.
+
+ This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
+ achieve a simple upsampling with pixel shuffle.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Upsample ratio.
+ upsample_kernel (int): Kernel size of the conv layer to expand the
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, scale_factor,
+ upsample_kernel):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+ self.upsample_kernel = upsample_kernel
+ self.upsample_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels * scale_factor * scale_factor,
+ self.upsample_kernel,
+ padding=(self.upsample_kernel - 1) // 2)
+ self.init_weights()
+
+ def init_weights(self):
+ xavier_init(self.upsample_conv, distribution='uniform')
+
+ def forward(self, x):
+ x = self.upsample_conv(x)
+ x = F.pixel_shuffle(x, self.scale_factor)
+ return x
+
+
+def build_upsample_layer(cfg, *args, **kwargs):
+ """Build upsample layer.
+
+ Args:
+ cfg (dict): The upsample layer config, which should contain:
+
+ - type (str): Layer type.
+ - scale_factor (int): Upsample ratio, which is not applicable to
+ deconv.
+        - layer args: Args needed to instantiate an upsample layer.
+        args (argument list): Arguments passed to the ``__init__``
+            method of the corresponding upsample layer.
+        kwargs (keyword arguments): Keyword arguments passed to the
+            ``__init__`` method of the corresponding upsample layer.
+
+ Returns:
+ nn.Module: Created upsample layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ raise KeyError(
+ f'the cfg dict must contain the key "type", but got {cfg}')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in UPSAMPLE_LAYERS:
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
+ else:
+ upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+ if upsample is nn.Upsample:
+ cfg_['mode'] = layer_type
+ layer = upsample(*args, **kwargs, **cfg_)
+ return layer
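+
+
+if __name__ == '__main__':
+    # Illustrative usage sketch, not part of the upstream module.
+    import torch
+
+    # 'bilinear' maps to ``nn.Upsample``; ``mode`` is filled in from the type.
+    up = build_upsample_layer(dict(type='bilinear', scale_factor=2))
+    # 'pixel_shuffle' maps to the PixelShufflePack defined above.
+    ps = build_upsample_layer(
+        dict(type='pixel_shuffle', in_channels=8, out_channels=8,
+             scale_factor=2, upsample_kernel=3))
+    x = torch.randn(1, 8, 16, 16)
+    print(tuple(up(x).shape), tuple(ps(x).shape))  # both (1, 8, 32, 32)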
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/wrappers.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/wrappers.py
new file mode 100755
index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/wrappers.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501
+
+Wrap some nn modules to support empty tensor input. Currently, these wrappers
+are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+heads are trained on only positive RoIs.
+"""
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair, _triple
+
+from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+
+if torch.__version__ == 'parrots':
+ TORCH_VERSION = torch.__version__
+else:
+ # torch.__version__ could be 1.3.1+cu92, we only need the first two
+ # for comparison
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+def obsolete_torch_version(torch_version, version_threshold):
+ return torch_version == 'parrots' or torch_version <= version_threshold
+
+
+class NewEmptyTensorOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, new_shape):
+ ctx.shape = x.shape
+ return x.new_empty(new_shape)
+
+ @staticmethod
+ def backward(ctx, grad):
+ shape = ctx.shape
+ return NewEmptyTensorOp.apply(grad, shape), None
+
+
+@CONV_LAYERS.register_module('Conv', force=True)
+class Conv2d(nn.Conv2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv')
+@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+class ConvTranspose2d(nn.ConvTranspose2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool2d(nn.MaxPool2d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
+ _pair(self.padding), _pair(self.stride),
+ _pair(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool3d(nn.MaxPool3d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+ _triple(self.padding),
+ _triple(self.stride),
+ _triple(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class Linear(torch.nn.Linear):
+
+ def forward(self, x):
+ # empty tensor forward of Linear layer is supported in Pytorch 1.6
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+ out_shape = [x.shape[0], self.out_features]
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
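+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module: with these
+    # wrappers, a batch containing zero RoIs still produces a correctly
+    # shaped empty output instead of raising an error.
+    conv = Conv2d(3, 8, kernel_size=3, padding=1)
+    out = conv(torch.zeros(0, 3, 16, 16))
+    print(tuple(out.shape))  # (0, 8, 16, 16)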
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py
new file mode 100755
index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..runner import Sequential
+from ..utils import Registry, build_from_cfg
+
+
+def build_model_from_cfg(cfg, registry, default_args=None):
+ """Build a PyTorch model from config dict(s). Different from
+ ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+
+ Args:
+        cfg (dict, list[dict]): The config of modules. It is either a config
+            dict or a list of config dicts. If cfg is a list, the built
+            modules will be wrapped with ``nn.Sequential``.
+ registry (:obj:`Registry`): A registry the module belongs to.
+ default_args (dict, optional): Default arguments to build the module.
+ Defaults to None.
+
+ Returns:
+ nn.Module: A built nn module.
+ """
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+MODELS = Registry('model', build_func=build_model_from_cfg)
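+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module: a list config is
+    # built into a ``Sequential``. 'DemoLinear' is a hypothetical model
+    # registered only for this demo (assuming the MMCV-style
+    # ``register_module`` decorator).
+    import torch.nn as nn
+
+    @MODELS.register_module(name='DemoLinear')
+    class DemoLinear(nn.Linear):
+        pass
+
+    model = build_model_from_cfg(
+        [dict(type='DemoLinear', in_features=4, out_features=8),
+         dict(type='DemoLinear', in_features=8, out_features=2)], MODELS)
+    print(type(model).__name__, len(model))  # Sequential 2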
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py
new file mode 100755
index 0000000000000000000000000000000000000000..fb29e6256280b671acfbf73fd9a01f079749b260
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py
@@ -0,0 +1,322 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import Optional, Sequence, Tuple, Union
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from torch import Tensor
+
+from .utils import constant_init, kaiming_init
+
+
+def conv3x3(in_planes: int,
+ out_planes: int,
+ stride: int = 1,
+ dilation: int = 1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ downsample: Optional[nn.Module] = None,
+ style: str = 'pytorch',
+ with_cp: bool = False):
+ super().__init__()
+ assert style in ['pytorch', 'caffe']
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ assert not with_cp
+
+ def forward(self, x: Tensor) -> Tensor:
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ downsample: Optional[nn.Module] = None,
+ style: str = 'pytorch',
+ with_cp: bool = False):
+ """Bottleneck block.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super().__init__()
+ assert style in ['pytorch', 'caffe']
+ if style == 'pytorch':
+ conv1_stride = 1
+ conv2_stride = stride
+ else:
+ conv1_stride = stride
+ conv2_stride = 1
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ def forward(self, x: Tensor) -> Tensor:
+
+ def _inner_forward(x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def make_res_layer(block: nn.Module,
+ inplanes: int,
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilation: int = 1,
+ style: str = 'pytorch',
+ with_cp: bool = False) -> nn.Module:
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ dilation,
+ downsample,
+ style=style,
+ with_cp=with_cp))
+ inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
+
+ return nn.Sequential(*layers)
+
+
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth: int,
+ num_stages: int = 4,
+ strides: Sequence[int] = (1, 2, 2, 2),
+ dilations: Sequence[int] = (1, 1, 1, 1),
+ out_indices: Sequence[int] = (0, 1, 2, 3),
+ style: str = 'pytorch',
+ frozen_stages: int = -1,
+ bn_eval: bool = True,
+ bn_frozen: bool = False,
+ with_cp: bool = False):
+ super().__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ assert num_stages >= 1 and num_stages <= 4
+ block, stage_blocks = self.arch_settings[depth]
+ stage_blocks = stage_blocks[:num_stages] # type: ignore
+ assert len(strides) == len(dilations) == num_stages
+ assert max(out_indices) < num_stages
+
+ self.out_indices = out_indices
+ self.style = style
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+ self.with_cp = with_cp
+
+ self.inplanes: int = 64
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ planes = 64 * 2**i
+ res_layer = make_res_layer(
+ block,
+ self.inplanes,
+ planes,
+ num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ with_cp=with_cp)
+ self.inplanes = planes * block.expansion # type: ignore
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self.feat_dim = block.expansion * 64 * 2**( # type: ignore
+ len(stage_blocks) - 1)
+
+ def init_weights(self, pretrained: Optional[str] = None) -> None:
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode: bool = True) -> None:
+ super().train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ if mode and self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for param in self.bn1.parameters():
+ param.requires_grad = False
+ self.bn1.eval()
+ self.bn1.weight.requires_grad = False
+ self.bn1.bias.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ mod = getattr(self, f'layer{i}')
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
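+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module: a ResNet-18
+    # backbone returning the feature maps of all four stages.
+    import torch
+
+    backbone = ResNet(depth=18, out_indices=(0, 1, 2, 3))
+    backbone.init_weights()
+    backbone.eval()
+    with torch.no_grad():
+        feats = backbone(torch.randn(1, 3, 64, 64))
+    print([tuple(f.shape) for f in feats])
+    # [(1, 64, 16, 16), (1, 128, 8, 8), (1, 256, 4, 4), (1, 512, 2, 2)]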
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_counter import get_model_complexity_info
+from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+ KaimingInit, NormalInit, PretrainedInit,
+ TruncNormalInit, UniformInit, XavierInit,
+ bias_init_with_prob, caffe2_xavier_init,
+ constant_init, initialize, kaiming_init, normal_init,
+ trunc_normal_init, uniform_init, xavier_init)
+
+__all__ = [
+ 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+ 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+ 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'revert_sync_batchnorm'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/flops_counter.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/flops_counter.py
new file mode 100755
index 0000000000000000000000000000000000000000..b8024ad3d14a3c9c52ae0fff88099fb0473b2d26
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/flops_counter.py
@@ -0,0 +1,603 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+import warnings
+from functools import partial
+from typing import Any, Callable, Dict, Optional, TextIO, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import dbnet_cv
+
+
+def get_model_complexity_info(model: nn.Module,
+ input_shape: tuple,
+ print_per_layer_stat: bool = True,
+ as_strings: bool = True,
+ input_constructor: Optional[Callable] = None,
+ flush: bool = False,
+ ost: TextIO = sys.stdout) -> tuple:
+ """Get complexity information of a model.
+
+ This method can calculate FLOPs and parameter counts of a model with
+ corresponding input shape. It can also print complexity information for
+ each layer in a model.
+
+ Supported layers are listed as below:
+ - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+ - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``,
+ ``nn.LeakyReLU``, ``nn.ReLU6``.
+ - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+ ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+ ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+ ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+ ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+ - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+ ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+ ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+ - Linear: ``nn.Linear``.
+ - Deconvolution: ``nn.ConvTranspose2d``.
+ - Upsample: ``nn.Upsample``.
+
+ Args:
+ model (nn.Module): The model for complexity calculation.
+ input_shape (tuple): Input shape used for calculation.
+ print_per_layer_stat (bool): Whether to print complexity information
+ for each layer in a model. Default: True.
+ as_strings (bool): Output FLOPs and params counts in a string form.
+ Default: True.
+        input_constructor (None | callable): If specified, it takes a callable
+            method that generates the input. Otherwise, a random tensor with
+            the given input shape will be generated to calculate FLOPs.
+            Default: None.
+ flush (bool): same as that in :func:`print`. Default: False.
+ ost (stream): same as ``file`` param in :func:`print`.
+ Default: sys.stdout.
+
+ Returns:
+        tuple[float | str]: If ``as_strings`` is set to True, it will return
+            FLOPs and parameter counts in a string format. Otherwise, it will
+            return them as floats.
+ """
+ assert type(input_shape) is tuple
+ assert len(input_shape) >= 1
+ assert isinstance(model, nn.Module)
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval()
+ flops_model.start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_shape)
+ _ = flops_model(**input)
+ else:
+ try:
+ batch = torch.ones(()).new_empty(
+ (1, *input_shape),
+ dtype=next(flops_model.parameters()).dtype,
+ device=next(flops_model.parameters()).device)
+ except StopIteration:
+ # Avoid StopIteration for models which have no parameters,
+            # like `nn.ReLU()`, `nn.AvgPool2d`, etc.
+ batch = torch.ones(()).new_empty((1, *input_shape))
+
+ _ = flops_model(batch)
+
+ flops_count, params_count = flops_model.compute_average_flops_cost()
+ if print_per_layer_stat:
+ print_model_with_flops(
+ flops_model, flops_count, params_count, ost=ost, flush=flush)
+ flops_model.stop_flops_count()
+
+ if as_strings:
+ return flops_to_string(flops_count), params_to_string(params_count)
+
+ return flops_count, params_count
+
+
+def flops_to_string(flops: float,
+ units: Optional[str] = 'GFLOPs',
+ precision: int = 2) -> str:
+ """Convert FLOPs number into a string.
+
+    Note that here we count one multiply-add as a single FLOP.
+
+ Args:
+ flops (float): FLOPs number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+ 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
+ choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted FLOPs number with units.
+
+ Examples:
+ >>> flops_to_string(1e9)
+ '1.0 GFLOPs'
+ >>> flops_to_string(2e5, 'MFLOPs')
+ '0.2 MFLOPs'
+ >>> flops_to_string(3e-9, None)
+ '3e-09 FLOPs'
+ """
+ if units is None:
+ if flops // 10**9 > 0:
+ return str(round(flops / 10.**9, precision)) + ' GFLOPs'
+ elif flops // 10**6 > 0:
+ return str(round(flops / 10.**6, precision)) + ' MFLOPs'
+ elif flops // 10**3 > 0:
+ return str(round(flops / 10.**3, precision)) + ' KFLOPs'
+ else:
+ return str(flops) + ' FLOPs'
+ else:
+ if units == 'GFLOPs':
+ return str(round(flops / 10.**9, precision)) + ' ' + units
+ elif units == 'MFLOPs':
+ return str(round(flops / 10.**6, precision)) + ' ' + units
+ elif units == 'KFLOPs':
+ return str(round(flops / 10.**3, precision)) + ' ' + units
+ else:
+ return str(flops) + ' FLOPs'
+
+
+def params_to_string(num_params: float,
+ units: Optional[str] = None,
+ precision: int = 2) -> str:
+ """Convert parameter number into a string.
+
+ Args:
+ num_params (float): Parameter number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'M',
+ 'K' and ''. If set to None, it will automatically choose the most
+ suitable unit for Parameter number. Default: None.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted parameter number with units.
+
+ Examples:
+ >>> params_to_string(1e9)
+ '1000.0 M'
+ >>> params_to_string(2e5)
+ '200.0 k'
+ >>> params_to_string(3e-9)
+ '3e-09'
+ """
+ if units is None:
+ if num_params // 10**6 > 0:
+ return str(round(num_params / 10**6, precision)) + ' M'
+ elif num_params // 10**3:
+ return str(round(num_params / 10**3, precision)) + ' k'
+ else:
+ return str(num_params)
+ else:
+ if units == 'M':
+ return str(round(num_params / 10.**6, precision)) + ' ' + units
+ elif units == 'K':
+ return str(round(num_params / 10.**3, precision)) + ' ' + units
+ else:
+ return str(num_params)
+
+
+def print_model_with_flops(model: nn.Module,
+ total_flops: float,
+ total_params: float,
+ units: Optional[str] = 'GFLOPs',
+ precision: int = 3,
+ ost: TextIO = sys.stdout,
+ flush: bool = False) -> None:
+ """Print a model with FLOPs for each layer.
+
+ Args:
+ model (nn.Module): The model to be printed.
+ total_flops (float): Total FLOPs of the model.
+ total_params (float): Total parameter counts of the model.
+ units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 3.
+ ost (stream): same as `file` param in :func:`print`.
+ Default: sys.stdout.
+ flush (bool): same as that in :func:`print`. Default: False.
+
+ Example:
+ >>> class ExampleModel(nn.Module):
+
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.conv1 = nn.Conv2d(3, 8, 3)
+ >>> self.conv2 = nn.Conv2d(8, 256, 3)
+ >>> self.conv3 = nn.Conv2d(256, 8, 3)
+ >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ >>> self.flatten = nn.Flatten()
+ >>> self.fc = nn.Linear(8, 1)
+
+ >>> def forward(self, x):
+ >>> x = self.conv1(x)
+ >>> x = self.conv2(x)
+ >>> x = self.conv3(x)
+ >>> x = self.avg_pool(x)
+ >>> x = self.flatten(x)
+ >>> x = self.fc(x)
+ >>> return x
+
+ >>> model = ExampleModel()
+ >>> x = (3, 16, 16)
+ to print the complexity information state for each layer, you can use
+ >>> get_model_complexity_info(model, x)
+ or directly use
+ >>> print_model_with_flops(model, 4579784.0, 37361)
+ ExampleModel(
+ 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+ (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501
+ (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+ (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+ (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+ (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+ (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+ )
+ """
+
+ def accumulate_params(self):
+ if is_supported_instance(self):
+ return self.__params__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_params()
+ return sum
+
+ def accumulate_flops(self):
+ if is_supported_instance(self):
+ return self.__flops__ / model.__batch_counter__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_flops()
+ return sum
+
+ def flops_repr(self):
+ accumulated_num_params = self.accumulate_params()
+ accumulated_flops_cost = self.accumulate_flops()
+ return ', '.join([
+ params_to_string(
+ accumulated_num_params, units='M', precision=precision),
+ f'{accumulated_num_params / total_params:.3%} Params',
+ flops_to_string(
+ accumulated_flops_cost, units=units, precision=precision),
+ f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
+ self.original_extra_repr()
+ ])
+
+ def add_extra_repr(m):
+ m.accumulate_flops = accumulate_flops.__get__(m)
+ m.accumulate_params = accumulate_params.__get__(m)
+ flops_extra_repr = flops_repr.__get__(m)
+ if m.extra_repr != flops_extra_repr:
+ m.original_extra_repr = m.extra_repr
+ m.extra_repr = flops_extra_repr
+ assert m.extra_repr != m.original_extra_repr
+
+ def del_extra_repr(m):
+ if hasattr(m, 'original_extra_repr'):
+ m.extra_repr = m.original_extra_repr
+ del m.original_extra_repr
+ if hasattr(m, 'accumulate_flops'):
+ del m.accumulate_flops
+
+ model.apply(add_extra_repr)
+ print(model, file=ost, flush=flush)
+ model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model: nn.Module) -> float:
+ """Calculate parameter number of a model.
+
+ Args:
+ model (nn.module): The model for parameter number calculation.
+
+ Returns:
+ float: Parameter number of the model.
+ """
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return num_params
+
+
+def add_flops_counting_methods(net_main_module: nn.Module) -> nn.Module:
+ # adding additional methods to the existing module object,
+ # this is done this way so that each function has access to self object
+ net_main_module.start_flops_count = start_flops_count.__get__( # type: ignore # noqa E501
+ net_main_module)
+ net_main_module.stop_flops_count = stop_flops_count.__get__( # type: ignore # noqa E501
+ net_main_module)
+ net_main_module.reset_flops_count = reset_flops_count.__get__( # type: ignore # noqa E501
+ net_main_module)
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # type: ignore # noqa E501
+ net_main_module)
+
+ net_main_module.reset_flops_count()
+
+ return net_main_module
+
+
+def compute_average_flops_cost(self) -> Tuple[float, float]:
+ """Compute average FLOPs cost.
+
+ A method to compute average FLOPs cost, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+
+ Returns:
+        tuple[float, float]: Current mean flops consumption per image and the
+            total number of parameters.
+ """
+ batches_count = self.__batch_counter__
+ flops_sum = 0
+ for module in self.modules():
+ if is_supported_instance(module):
+ flops_sum += module.__flops__
+ params_sum = get_model_parameters_number(self)
+ return flops_sum / batches_count, params_sum
+
+
+def start_flops_count(self) -> None:
+ """Activate the computation of mean flops consumption per image.
+
+    A method to activate the computation of mean flops consumption per image,
+    which will be available after ``add_flops_counting_methods()`` is called
+    on a desired net object. It should be called before running the network.
+ """
+ add_batch_counter_hook_function(self)
+
+ def add_flops_counter_hook_function(module: nn.Module) -> None:
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ return
+
+ else:
+ handle = module.register_forward_hook(
+ get_modules_mapping()[type(module)])
+
+ module.__flops_handle__ = handle
+
+ self.apply(partial(add_flops_counter_hook_function))
+
+
+def stop_flops_count(self) -> None:
+ """Stop computing the mean flops consumption per image.
+
+ A method to stop computing the mean flops consumption per image, which will
+ be available after ``add_flops_counting_methods()`` is called on a desired
+    net object. It can be called at any time to pause the computation.
+ """
+ remove_batch_counter_hook_function(self)
+ self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self) -> None:
+ """Reset statistics computed so far.
+
+    A method to reset computed statistics, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+ """
+ add_batch_counter_variables_or_reset(self)
+ self.apply(add_flops_counter_variable_or_reset)
+
+
+# ---- Internal functions
+def empty_flops_counter_hook(module: nn.Module, input: tuple,
+ output: Any) -> None:
+ module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ output_size = output[0]
+ batch_size = output_size.shape[0]
+ output_elements_count = batch_size
+ for val in output_size.shape[1:]:
+ output_elements_count *= val
+ module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ active_elements_count = output.numel()
+ module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ output_last_dim = output.shape[
+ -1] # pytorch checks dimensions, so here we don't care much
+ module.__flops__ += int(np.prod(input[0].shape) * output_last_dim)
+
+
+def pool_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ module.__flops__ += int(np.prod(input[0].shape))
+
+
+def norm_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ batch_flops = np.prod(input[0].shape)
+ if (getattr(module, 'affine', False)
+ or getattr(module, 'elementwise_affine', False)):
+ batch_flops *= 2
+ module.__flops__ += int(batch_flops)
+
+
+def deconv_flops_counter_hook(conv_module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ # Can have multiple inputs, getting the first one
+ batch_size = input[0].shape[0]
+ input_height, input_width = input[0].shape[2:]
+
+ kernel_height, kernel_width = conv_module.kernel_size
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = (
+ kernel_height * kernel_width * in_channels * filters_per_channel)
+
+ active_elements_count = batch_size * input_height * input_width
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+ bias_flops = 0
+ if conv_module.bias is not None:
+ output_height, output_width = output.shape[2:]
+ bias_flops = out_channels * batch_size * output_height * output_width
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ # Can have multiple inputs, getting the first one
+ batch_size = input[0].shape[0]
+ output_dims = list(output.shape[2:])
+
+ kernel_dims = list(conv_module.kernel_size)
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = int(
+ np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+ active_elements_count = batch_size * int(np.prod(output_dims))
+
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+
+ bias_flops = 0
+
+ if conv_module.bias is not None:
+
+ bias_flops = out_channels * active_elements_count
+
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def batch_counter_hook(module: nn.Module, input: tuple, output: Any) -> None:
+ batch_size = 1
+ if len(input) > 0:
+ # Can have multiple inputs, getting the first one
+ batch_size = len(input[0])
+ else:
+ warnings.warn('No positional inputs found for a module, '
+ 'assuming batch size is 1.')
+ module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module: nn.Module) -> None:
+
+ module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module: nn.Module) -> None:
+ if hasattr(module, '__batch_counter_handle__'):
+ return
+
+ handle = module.register_forward_hook(batch_counter_hook)
+ module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module: nn.Module) -> None:
+ if hasattr(module, '__batch_counter_handle__'):
+ module.__batch_counter_handle__.remove()
+ del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module: nn.Module) -> None:
+ if is_supported_instance(module):
+ if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+            warnings.warn('variables __flops__ or __params__ are already '
+                          'defined for the module ' + type(module).__name__ +
+                          '. ptflops can affect your code!')
+ module.__flops__ = 0
+ module.__params__ = get_model_parameters_number(module)
+
+
+def is_supported_instance(module: nn.Module) -> bool:
+ if type(module) in get_modules_mapping():
+ return True
+ return False
+
+
+def remove_flops_counter_hook_function(module: nn.Module) -> None:
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ module.__flops_handle__.remove()
+ del module.__flops_handle__
+
+
+def get_modules_mapping() -> Dict:
+ return {
+ # convolutions
+ nn.Conv1d: conv_flops_counter_hook,
+ nn.Conv2d: conv_flops_counter_hook,
+ dbnet_cv.cnn.bricks.Conv2d: conv_flops_counter_hook,
+ nn.Conv3d: conv_flops_counter_hook,
+ dbnet_cv.cnn.bricks.Conv3d: conv_flops_counter_hook,
+ # activations
+ nn.ReLU: relu_flops_counter_hook,
+ nn.PReLU: relu_flops_counter_hook,
+ nn.ELU: relu_flops_counter_hook,
+ nn.LeakyReLU: relu_flops_counter_hook,
+ nn.ReLU6: relu_flops_counter_hook,
+ # poolings
+ nn.MaxPool1d: pool_flops_counter_hook,
+ nn.AvgPool1d: pool_flops_counter_hook,
+ nn.AvgPool2d: pool_flops_counter_hook,
+ nn.MaxPool2d: pool_flops_counter_hook,
+ dbnet_cv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
+ nn.MaxPool3d: pool_flops_counter_hook,
+ dbnet_cv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
+ nn.AvgPool3d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+ # normalizations
+ nn.BatchNorm1d: norm_flops_counter_hook,
+ nn.BatchNorm2d: norm_flops_counter_hook,
+ nn.BatchNorm3d: norm_flops_counter_hook,
+ nn.GroupNorm: norm_flops_counter_hook,
+ nn.InstanceNorm1d: norm_flops_counter_hook,
+ nn.InstanceNorm2d: norm_flops_counter_hook,
+ nn.InstanceNorm3d: norm_flops_counter_hook,
+ nn.LayerNorm: norm_flops_counter_hook,
+ # FC
+ nn.Linear: linear_flops_counter_hook,
+ dbnet_cv.cnn.bricks.Linear: linear_flops_counter_hook,
+ # Upscale
+ nn.Upsample: upsample_flops_counter_hook,
+ # Deconvolution
+ nn.ConvTranspose2d: deconv_flops_counter_hook,
+ dbnet_cv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
+ }
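+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module: for a single 3x3
+    # conv with 3 input and 8 output channels on a padded 16x16 input,
+    # ``conv_flops_counter_hook`` counts
+    #     (3 * 3 * 3 * 8) * (16 * 16) + 8 * (16 * 16) = 57344
+    # operations (the second term is the bias), and the layer has
+    # 3 * 3 * 3 * 8 + 8 = 224 parameters.
+    flops, params = get_model_complexity_info(
+        nn.Conv2d(3, 8, 3, padding=1), (3, 16, 16),
+        as_strings=False, print_per_layer_stat=False)
+    print(int(flops), int(params))  # 57344 224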
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/fuse_conv_bn.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/fuse_conv_bn.py
new file mode 100755
index 0000000000000000000000000000000000000000..6ccaab3bf1eb3ce615bad910d6dc45a467bb1fe4
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/fuse_conv_bn.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
+ """Fuse conv and bn into one module.
+
+ Args:
+ conv (nn.Module): Conv to be fused.
+ bn (nn.Module): BN to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ conv_w = conv.weight
+ conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+ bn.running_mean)
+
+ factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+ conv.weight = nn.Parameter(conv_w *
+ factor.reshape([conv.out_channels, 1, 1, 1]))
+ conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+ return conv
+
+
+def fuse_conv_bn(module: nn.Module) -> nn.Module:
+ """Recursively fuse conv and bn in a module.
+
+    During inference, batch norm layers no longer update their statistics;
+    only the per-channel running mean and variance are used. This makes it
+    possible to fuse them into the preceding conv layers to save computation
+    and simplify the network structure.
+
+ Args:
+ module (nn.Module): Module to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ last_conv = None
+ last_conv_name = None
+
+ for name, child in module.named_children():
+ if isinstance(child,
+ (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+ if last_conv is None: # only fuse BN that is after Conv
+ continue
+ fused_conv = _fuse_conv_bn(last_conv, child)
+ module._modules[last_conv_name] = fused_conv
+ # To reduce changes, set BN as Identity instead of deleting it.
+ module._modules[name] = nn.Identity()
+ last_conv = None
+ elif isinstance(child, nn.Conv2d):
+ last_conv = child
+ last_conv_name = name
+ else:
+ fuse_conv_bn(child)
+ return module
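+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module: after fusing, the
+    # BN is replaced by ``nn.Identity`` and the eval-mode outputs match.
+    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
+    x = torch.randn(1, 3, 16, 16)
+    with torch.no_grad():
+        before = model(x)
+        fused = fuse_conv_bn(model)
+        after = fused(x)
+    print(type(fused[1]).__name__, torch.allclose(before, after, atol=1e-5))
+    # Identity True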
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/sync_bn.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/sync_bn.py
new file mode 100755
index 0000000000000000000000000000000000000000..47e36588a89a10212fefa9304003a98e6967e666
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/sync_bn.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+import dbnet_cv
+
+
+class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
+ """A general BatchNorm layer without input dimension check.
+
+ Reproduced from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc. is
+    `_check_input_dim`, which is designed for tensor sanity checks.
+    The check has been bypassed in this class for the convenience of
+    converting SyncBatchNorm.
+ """
+
+ def _check_input_dim(self, input: torch.Tensor):
+ return
+
+
+def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
+ """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `dbnet_cv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+ `BatchNormXd` layers.
+
+ Adapted from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+ Args:
+ module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+ Returns:
+ module_output: The converted module with `BatchNormXd` layers.
+ """
+ module_output = module
+ module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+ if hasattr(dbnet_cv, 'ops'):
+ module_checklist.append(dbnet_cv.ops.SyncBatchNorm)
+ if isinstance(module, tuple(module_checklist)):
+ module_output = _BatchNormXd(module.num_features, module.eps,
+ module.momentum, module.affine,
+ module.track_running_stats)
+ if module.affine:
+ # no_grad() may not be needed here but
+ # just to be consistent with `convert_sync_batchnorm()`
+ with torch.no_grad():
+ module_output.weight = module.weight
+ module_output.bias = module.bias
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = module.num_batches_tracked
+ module_output.training = module.training
+ # qconfig exists in quantized models
+ if hasattr(module, 'qconfig'):
+ module_output.qconfig = module.qconfig
+ for name, child in module.named_children():
+ module_output.add_module(name, revert_sync_batchnorm(child))
+ del module
+ return module_output
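+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the upstream module: SyncBN layers need
+    # a distributed process group at run time, so reverting them to a plain
+    # BatchNorm makes single-device inference of exported models possible.
+    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
+    reverted = revert_sync_batchnorm(model)
+    print(type(reverted[1]).__name__)  # _BatchNormXd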
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/weight_init.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/weight_init.py
new file mode 100755
index 0000000000000000000000000000000000000000..8ed99109e73aeef2e618ed14f0dab33bf446faaa
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/weight_init.py
@@ -0,0 +1,708 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from dbnet_cv.utils import Registry, build_from_cfg, get_logger, print_log
+
+INITIALIZERS = Registry('initializer')
+
+
+def update_init_info(module: nn.Module, init_info: str) -> None:
+ """Update the `_params_init_info` in the module if the value of parameters
+ are changed.
+
+ Args:
+ module (obj:`nn.Module`): The module of PyTorch with a user-defined
+ attribute `_params_init_info` which records the initialization
+ information.
+ init_info (str): The string that describes the initialization.
+ """
+ assert hasattr(
+ module,
+ '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+ for name, param in module.named_parameters():
+
+ assert param in module._params_init_info, (
+ f'Find a new :obj:`Parameter` '
+ f'named `{name}` during executing the '
+ f'`init_weights` of '
+ f'`{module.__class__.__name__}`. '
+ f'Please do not add or '
+ f'replace parameters during executing '
+ f'the `init_weights`. ')
+
+ # The parameter has been changed during executing the
+ # `init_weights` of module
+ mean_value = param.data.mean()
+ if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+ module._params_init_info[param]['init_info'] = init_info
+ module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module: nn.Module,
+ gain: float = 1,
+ bias: float = 0,
+ distribution: str = 'normal') -> None:
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ trunc_normal_(module.weight, mean, std, a, b) # type: ignore
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias) # type: ignore
+
+
+def uniform_init(module: nn.Module,
+ a: float = 0,
+ b: float = 1,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module: nn.Module,
+ a: float = 0,
+ mode: str = 'fan_out',
+ nonlinearity: str = 'relu',
+ bias: float = 0,
+ distribution: str = 'normal') -> None:
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.kaiming_uniform_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module: nn.Module, bias: float = 0) -> None:
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ kaiming_init(
+ module,
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=bias,
+ distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob: float) -> float:
+    """Initialize conv/fc bias value according to a given probability value."""
+ bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+ return bias_init
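+
+
+# Worked example (illustrative, not part of the upstream module): for a
+# focal-loss style prior probability of 0.01,
+#     bias_init_with_prob(0.01) = -log((1 - 0.01) / 0.01) ~= -4.595,
+# so a sigmoid over the initialized logits starts close to the desired prior.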
+
+
+def _get_bases_name(m: nn.Module) -> List[str]:
+ return [b.__name__ for b in m.__class__.__bases__]
+
+
+class BaseInit:
+
+ def __init__(self,
+ *,
+ bias: float = 0,
+ bias_prob: Optional[float] = None,
+ layer: Union[str, List, None] = None):
+ self.wholemodule = False
+ if not isinstance(bias, (int, float)):
+ raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+ if bias_prob is not None:
+ if not isinstance(bias_prob, float):
+ raise TypeError(f'bias_prob type must be float, \
+ but got {type(bias_prob)}')
+
+ if layer is not None:
+ if not isinstance(layer, (str, list)):
+ raise TypeError(f'layer must be a str or a list of str, \
+ but got a {type(layer)}')
+ else:
+ layer = []
+
+ if bias_prob is not None:
+ self.bias = bias_init_with_prob(bias_prob)
+ else:
+ self.bias = bias
+ self.layer = [layer] if isinstance(layer, str) else layer
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+ """Initialize module parameters with constant values.
+
+ Args:
+ val (int | float): the value to fill the weights in the module with
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, val: Union[int, float], **kwargs):
+ super().__init__(**kwargs)
+ self.val = val
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ constant_init(m, self.val, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ constant_init(m, self.val, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+ return info
+
+
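+# Illustrative usage sketch (not part of the upstream module): applied to a
+# model, ``ConstantInit(val=0., layer='Conv2d')(model)`` zero-fills the weight
+# and bias of every ``nn.Conv2d`` in the model while leaving other layer types
+# untouched; the initializers below follow the same call pattern.
+
+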
+@INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+    r"""Initialize module parameters with values according to the method
+    described in *Understanding the difficulty of training deep feedforward
+    neural networks* - Glorot, X. & Bengio, Y. (2010).
+
+ Args:
+ gain (int | float): an optional scaling factor. Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): distribution either be ``'normal'``
+ or ``'uniform'``. Defaults to ``'normal'``.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ gain: float = 1,
+ distribution: str = 'normal',
+ **kwargs):
+ super().__init__(**kwargs)
+ self.gain = gain
+ self.distribution = distribution
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ xavier_init(m, self.gain, self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ xavier_init(m, self.gain, self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: gain={self.gain}, ' \
+ f'distribution={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+ Args:
+        mean (int | float): the mean of the normal distribution. Defaults to 0.
+ std (int | float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self, mean: float = 0, std: float = 1, **kwargs):
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ normal_init(m, self.mean, self.std, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ normal_init(m, self.mean, self.std, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: mean={self.mean},' \
+ f' std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+ outside :math:`[a, b]`.
+
+ Args:
+ mean (float): the mean of the normal distribution. Defaults to 0.
+ std (float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ a (float): The minimum cutoff value.
+        b (float): The maximum cutoff value.
+ bias (float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+ self.a = a
+ self.b = b
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
+ f' mean={self.mean}, std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+ r"""Initialize module parameters with values drawn from the uniform
+ distribution :math:`\mathcal{U}(a, b)`.
+
+ Args:
+ a (int | float): the lower bound of the uniform distribution.
+ Defaults to 0.
+ b (int | float): the upper bound of the uniform distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, a: float = 0., b: float = 1., **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.b = b
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ uniform_init(m, self.a, self.b, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ uniform_init(m, self.a, self.b, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: a={self.a},' \
+ f' b={self.b}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+ r"""Initialize module parameters with the values according to the method
+ described in `Delving deep into rectifiers: Surpassing human-level
+ performance on ImageNet classification - He, K. et al. (2015).
+ <https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
+
+ Args:
+ a (int | float): the negative slope of the rectifier used after this
+ layer (only used with ``'leaky_relu'``). Defaults to 0.
+ mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+ ``'fan_in'`` preserves the magnitude of the variance of the weights
+ in the forward pass. Choosing ``'fan_out'`` preserves the
+ magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+ nonlinearity (str): the non-linear function (`nn.functional` name),
+ recommended to use only with ``'relu'`` or ``'leaky_relu'``.
+ Defaults to 'relu'.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): the distribution, either ``'normal'`` or
+ ``'uniform'``. Defaults to ``'normal'``.
+ layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ a: float = 0,
+ mode: str = 'fan_out',
+ nonlinearity: str = 'relu',
+ distribution: str = 'normal',
+ **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.mode = mode
+ self.nonlinearity = nonlinearity
+ self.distribution = distribution
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
+ f'nonlinearity={self.nonlinearity}, ' \
+ f'distribution={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ def __init__(self, **kwargs):
+ super().__init__(
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ distribution='uniform',
+ **kwargs)
+
+ def __call__(self, module: nn.Module) -> None:
+ super().__call__(module)
+
+
+@INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit:
+ """Initialize module by loading a pretrained model.
+
+ Args:
+ checkpoint (str): the checkpoint file of the pretrained model to be
+ loaded.
+ prefix (str, optional): the prefix of a sub-module in the pretrained
+ model. It is for loading a part of the pretrained model to
+ initialize. For example, if we would like to only load the
+ backbone of a detector model, we can set ``prefix='backbone.'``.
+ Defaults to None.
+ map_location (str): map tensors into proper locations.
+ """
+
+ def __init__(self,
+ checkpoint: str,
+ prefix: Optional[str] = None,
+ map_location: Optional[str] = None):
+ self.checkpoint = checkpoint
+ self.prefix = prefix
+ self.map_location = map_location
+
+ def __call__(self, module: nn.Module) -> None:
+ from dbnet_cv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict)
+ logger = get_logger('dbnet_cv')
+ if self.prefix is None:
+ print_log(f'load model from: {self.checkpoint}', logger=logger)
+ load_checkpoint(
+ module,
+ self.checkpoint,
+ map_location=self.map_location,
+ strict=False,
+ logger=logger)
+ else:
+ print_log(
+ f'load {self.prefix} in model from: {self.checkpoint}',
+ logger=logger)
+ state_dict = _load_checkpoint_with_prefix(
+ self.prefix, self.checkpoint, map_location=self.map_location)
+ load_state_dict(module, state_dict, strict=False, logger=logger)
+
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: load from {self.checkpoint}'
+ return info
+
+
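+ # Illustrative sketch (the checkpoint path and ``model.backbone`` below are
+ # placeholders): ``PretrainedInit`` can also be called directly instead of
+ # going through ``initialize``, e.g.
+ #
+ #     >>> init = PretrainedInit('work_dirs/dbnet_r50.pth', prefix='backbone.')
+ #     >>> init(model.backbone)  # load only the ``backbone.`` weights
+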
+def _initialize(module: nn.Module,
+ cfg: Dict,
+ wholemodule: bool = False) -> None:
+ func = build_from_cfg(cfg, INITIALIZERS)
+ # wholemodule flag is for override mode, there is no layer key in override
+ # and initializer will give init values for the whole module with the name
+ # in override.
+ func.wholemodule = wholemodule
+ func(module)
+
+
+def _initialize_override(module: nn.Module, override: Union[Dict, List],
+ cfg: Dict) -> None:
+ if not isinstance(override, (dict, list)):
+ raise TypeError(f'override must be a dict or a list of dict, \
+ but got {type(override)}')
+
+ override = [override] if isinstance(override, dict) else override
+
+ for override_ in override:
+
+ cp_override = copy.deepcopy(override_)
+ name = cp_override.pop('name', None)
+ if name is None:
+ raise ValueError('`override` must contain the key "name",'
+ f' but got {cp_override}')
+ # if override only has name key, it means use args in init_cfg
+ if not cp_override:
+ cp_override.update(cfg)
+ # if override has name key and other args except type key, it will
+ # raise error
+ elif 'type' not in cp_override.keys():
+ raise ValueError(
+ f'`override` need "type" key, but got {cp_override}')
+
+ if hasattr(module, name):
+ _initialize(getattr(module, name), cp_override, wholemodule=True)
+ else:
+ raise RuntimeError(f'module did not have attribute {name}, '
+ f'but init_cfg is {cp_override}.')
+
+
+def initialize(module: nn.Module, init_cfg: Union[Dict, List[dict]]) -> None:
+ r"""Initialize a module.
+
+ Args:
+ module (``torch.nn.Module``): the module to be initialized.
+ init_cfg (dict | list[dict]): initialization configuration dict to
+ define initializer. OpenMMLab has implemented 8 initializers
+ including ``Constant``, ``Xavier``, ``Normal``, ``TruncNormal``,
+ ``Uniform``, ``Kaiming``, ``Caffe2Xavier`` and ``Pretrained``.
+
+ Example:
+ >>> module = nn.Linear(2, 3, bias=True)
+ >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
+ >>> initialize(module, init_cfg)
+
+ >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
+ >>> # define key ``'layer'`` for initializing layer with different
+ >>> # configuration
+ >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+ dict(type='Constant', layer='Linear', val=2)]
+ >>> initialize(module, init_cfg)
+
+ >>> # define key``'override'`` to initialize some specific part in
+ >>> # module
+ >>> class FooNet(nn.Module):
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.feat = nn.Conv2d(3, 16, 3)
+ >>> self.reg = nn.Conv2d(16, 10, 3)
+ >>> self.cls = nn.Conv2d(16, 5, 3)
+ >>> model = FooNet()
+ >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+ >>> override=dict(type='Constant', name='reg', val=3, bias=4))
+ >>> initialize(model, init_cfg)
+
+ >>> model = ResNet(depth=50)
+ >>> # Initialize weights with the pretrained model.
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint='torchvision://resnet50')
+ >>> initialize(model, init_cfg)
+
+ >>> # Initialize weights of a sub-module with the specific part of
+ >>> # a pretrained model by using "prefix".
+ >>> url = 'http://download.openmmlab.com/dbnet_detection/v2.0/retinanet/'\
+ >>> 'retinanet_r50_fpn_1x_coco/'\
+ >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint=url, prefix='backbone.')
+ """
+ if not isinstance(init_cfg, (dict, list)):
+ raise TypeError(f'init_cfg must be a dict or a list of dict, \
+ but got {type(init_cfg)}')
+
+ if isinstance(init_cfg, dict):
+ init_cfg = [init_cfg]
+
+ for cfg in init_cfg:
+ # should deeply copy the original config because cfg may be used by
+ # other modules, e.g., one init_cfg shared by multiple bottleneck
+ # blocks, the expected cfg will be changed after pop and will change
+ # the initialization behavior of other modules
+ cp_cfg = copy.deepcopy(cfg)
+ override = cp_cfg.pop('override', None)
+ _initialize(module, cp_cfg)
+
+ if override is not None:
+ cp_cfg.pop('layer', None)
+ _initialize_override(module, override, cp_cfg)
+ else:
+ # All attributes in module have the same initialization.
+ pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+ b: float) -> Tensor:
+ # Method based on
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ # Modified from
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [lower, upper], then translate
+ # to [2lower-1, 2upper-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+ mean: float = 0.,
+ std: float = 1.,
+ a: float = -2.,
+ b: float = 2.) -> Tensor:
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Modified from
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+ Args:
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+ mean (float): the mean of the normal distribution.
+ std (float): the standard deviation of the normal distribution.
+ a (float): the minimum cutoff value.
+ b (float): the maximum cutoff value.
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
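+
+
+ # A minimal usage sketch (illustrative, not part of the original module).
+ # Note that ``a`` and ``b`` are absolute cutoff values, not multiples of
+ # ``std``, as the final ``clamp_`` above shows:
+ #
+ #     >>> w = torch.empty(3, 5)
+ #     >>> trunc_normal_(w, mean=0., std=0.02, a=-0.04, b=0.04)
+ #     >>> assert w.min() >= -0.04 and w.max() <= 0.04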
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/engine/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+ single_gpu_test)
+
+__all__ = [
+ 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+ 'single_gpu_test'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py
new file mode 100755
index 0000000000000000000000000000000000000000..99b68971d8c9d4f6dfe7b01a337e50f34bc70844
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+import dbnet_cv
+from dbnet_cv.runner import get_dist_info
+
+
+def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list:
+ """Test model with a single gpu.
+
+ This method tests model with a single gpu and displays test progress bar.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (DataLoader): PyTorch data loader.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = dbnet_cv.ProgressBar(len(dataset))
+ for data in data_loader:
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+ # Assume result has the same length as batch_size
+ # refer to https://github.com/open-mmlab/dbnet_cv/issues/985
+ batch_size = len(result)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
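+ # Illustrative sketch (``model`` and ``dataset`` are placeholders): the model
+ # must accept ``return_loss=False`` plus the unpacked data batch, matching
+ # the call in the loop above.
+ #
+ #     >>> loader = DataLoader(dataset, batch_size=1, shuffle=False)
+ #     >>> outputs = single_gpu_test(model, loader)  # one result per sample
+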
+def multi_gpu_test(model: nn.Module,
+ data_loader: DataLoader,
+ tmpdir: Optional[str] = None,
+ gpu_collect: bool = False) -> Optional[list]:
+ """Test model with multiple gpus.
+
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting
+ ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
+ communication for results collection. On cpu mode it saves the results on
+ different gpus to ``tmpdir`` and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (DataLoader): PyTorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = dbnet_cv.ProgressBar(len(dataset))
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+ if rank == 0:
+ batch_size = len(result)
+ batch_size_all = batch_size * world_size
+ if batch_size_all + prog_bar.completed > len(dataset):
+ batch_size_all = len(dataset) - prog_bar.completed
+ for _ in range(batch_size_all):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ result_from_ranks = collect_results_gpu(results, len(dataset))
+ else:
+ result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir)
+ return result_from_ranks
+
+
+def collect_results_cpu(result_part: list,
+ size: int,
+ tmpdir: Optional[str] = None) -> Optional[list]:
+ """Collect results under cpu mode.
+
+ On cpu mode, this function will save the results on different gpus to
+ ``tmpdir`` and collect them by the rank 0 worker.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to length of
+ the results.
+ tmpdir (str | None): Temporary directory in which to store the
+ collected results. If set to None, a random temporary directory
+ will be created for it.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ dbnet_cv.mkdir_or_exist('.dist_test')
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ dbnet_cv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ part_file = osp.join(tmpdir, f'part_{rank}.pkl') # type: ignore
+ dbnet_cv.dump(result_part, part_file)
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore
+ part_result = dbnet_cv.load(part_file)
+ # When data is severely insufficient, an empty part_result
+ # on a certain gpu could make the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir) # type: ignore
+ return ordered_results
+
+
+def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
+ """Collect results under gpu mode.
+
+ On gpu mode, this function will encode results to gpu tensors and use gpu
+ communication for results collection.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to length of
+ the results.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+ # When data is severely insufficient, an empty part_result
+ # on a certain gpu could make the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
+ else:
+ return None
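+
+
+ # Illustrative sketch (a distributed run with ``model`` and ``loader`` is
+ # assumed): ``gpu_collect=True`` gathers results via ``dist.all_gather`` on
+ # CUDA tensors in ``collect_results_gpu``; otherwise the parts are pickled
+ # to ``tmpdir`` and merged by rank 0 in ``collect_results_cpu``.
+ #
+ #     >>> outputs = multi_gpu_test(model, loader, tmpdir='.dist_test_tmp',
+ #     ...                          gpu_collect=False)
+ #     >>> if outputs is not None:  # only rank 0 receives the merged list
+ #     ...     dbnet_cv.dump(outputs, 'results.pkl')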
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .file_client import BaseStorageBackend, FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file
+
+__all__ = [
+ 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
+ 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
+ 'list_from_file', 'dict_from_file'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/file_client.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/file_client.py
new file mode 100755
index 0000000000000000000000000000000000000000..1f99eb54c6e233f2cc7084541a90d183d749c212
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/file_client.py
@@ -0,0 +1,1173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import os
+import os.path as osp
+import re
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Generator, Iterator, Optional, Tuple, Union
+from urllib.request import urlopen
+
+import dbnet_cv
+from dbnet_cv.utils.misc import has_method
+from dbnet_cv.utils.path import is_filepath
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ # a flag to indicate whether the backend can create a symlink for a file
+ _allow_symlink = False
+
+ @property
+ def name(self):
+ return self.__class__.__name__
+
+ @property
+ def allow_symlink(self):
+ return self._allow_symlink
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class CephBackend(BaseStorageBackend):
+ """Ceph storage backend (for internal use).
+
+ Args:
+ path_mapping (dict|None): path mapping dict from local path to Petrel
+ path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
+ will be replaced by ``dst``. Default: None.
+
+ .. warning::
+ :class:`dbnet_cv.fileio.file_client.CephBackend` will be deprecated,
+ please use :class:`dbnet_cv.fileio.file_client.PetrelBackend` instead.
+ """
+
+ def __init__(self, path_mapping=None):
+ try:
+ import ceph
+ except ImportError:
+ raise ImportError('Please install ceph to enable CephBackend.')
+
+ warnings.warn(
+ 'CephBackend will be deprecated, please use PetrelBackend instead',
+ DeprecationWarning)
+ self._client = ceph.S3Client()
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class PetrelBackend(BaseStorageBackend):
+ """Petrel storage backend (for internal use).
+
+ PetrelBackend supports reading and writing data to multiple clusters.
+ If the file path contains the cluster name, PetrelBackend will read data
+ from specified cluster or write data to it. Otherwise, PetrelBackend will
+ access the default cluster.
+
+ Args:
+ path_mapping (dict, optional): Path mapping dict from local path to
+ Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
+ ``filepath`` will be replaced by ``dst``. Default: None.
+ enable_mc (bool, optional): Whether to enable memcached support.
+ Default: True.
+
+ Examples:
+ >>> filepath1 = 's3://path/of/file'
+ >>> filepath2 = 'cluster-name:s3://path/of/file'
+ >>> client = PetrelBackend()
+ >>> client.get(filepath1) # get data from default cluster
+ >>> client.get(filepath2) # get data from 'cluster-name' cluster
+ """
+
+ def __init__(self,
+ path_mapping: Optional[dict] = None,
+ enable_mc: bool = True):
+ try:
+ from petrel_client import client
+ except ImportError:
+ raise ImportError('Please install petrel_client to enable '
+ 'PetrelBackend.')
+
+ self._client = client.Client(enable_mc=enable_mc)
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def _map_path(self, filepath: Union[str, Path]) -> str:
+ """Map ``filepath`` to a string path whose prefix will be replaced by
+ :attr:`self.path_mapping`.
+
+ Args:
+ filepath (str): Path to be mapped.
+ """
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ return filepath
+
+ def _format_path(self, filepath: str) -> str:
+ """Convert a ``filepath`` to standard format of petrel oss.
+
+ If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
+ environment, the ``filepath`` will be in the format of
+ 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
+ above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
+
+ Args:
+ filepath (str): Path to be formatted.
+ """
+ return re.sub(r'\\+', '/', filepath)
+
+ def get(self, filepath: Union[str, Path]) -> memoryview:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ memoryview: A memory view of expected bytes object to avoid
+ copying. The memoryview object can be converted to bytes by
+ ``value_buf.tobytes()``.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return str(self.get(filepath), encoding=encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (bytes): Data to be saved.
+ filepath (str or Path): Path to write data.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.put(filepath, obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to encode the ``obj``.
+ Default: 'utf-8'.
+ """
+ self.put(bytes(obj, encoding=encoding), filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ if not has_method(self._client, 'delete'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `delete` method, please use a higher version or dev'
+ ' branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.delete(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ if not (has_method(self._client, 'contains')
+ and has_method(self._client, 'isdir')):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `contains` and `isdir` methods, please use a higher '
+ 'version or dev branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath) or self._client.isdir(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ if not has_method(self._client, 'isdir'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `isdir` method, please use a higher version or dev'
+ ' branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ if not has_method(self._client, 'contains'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `contains` method, please use a higher version or '
+ 'dev branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result after concatenation.
+ """
+ filepath = self._format_path(self._map_path(filepath))
+ if filepath.endswith('/'):
+ filepath = filepath[:-1]
+ formatted_paths = [filepath]
+ for path in filepaths:
+ formatted_paths.append(self._format_path(self._map_path(path)))
+ return '/'.join(formatted_paths)
+
+ @contextmanager
+ def get_local_path(
+ self,
+ filepath: Union[str,
+ Path]) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath`` and return a temporary path.
+
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with the ``with`` statement, and when exiting from the
+ ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str | Path): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = PetrelBackend()
+ >>> # After exiting from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('s3://path/of/your/file') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one temporary path.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ assert self.isfile(filepath)
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ Petrel has no concept of directories but it simulates the directory
+ hierarchy in the filesystem through public prefixes. In addition,
+ if the returned path ends with '/', it means the path is a public
+ prefix which is a logical directory.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+ In addition, the returned path of a directory will not contain the
+ suffix '/' which is consistent with other backends.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if not has_method(self._client, 'list'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `list` method, please use a higher version or dev'
+ ' branch instead.')
+
+ dir_path = self._map_path(dir_path)
+ dir_path = self._format_path(dir_path)
+ if list_dir and suffix is not None:
+ raise TypeError(
+ '`list_dir` should be False when `suffix` is not None')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ # Petrel's simulated directory hierarchy assumes that directory paths
+ # should end with `/`
+ if not dir_path.endswith('/'):
+ dir_path += '/'
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for path in self._client.list(dir_path):
+ # the `self.isdir` is not used here to determine whether path
+ # is a directory, because `self.isdir` relies on
+ # `self._client.list`
+ if path.endswith('/'): # a directory path
+ next_dir_path = self.join_path(dir_path, path)
+ if list_dir:
+ # get the relative path and exclude the last
+ # character '/'
+ rel_dir = next_dir_path[len(root):-1]
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(next_dir_path, list_dir,
+ list_file, suffix,
+ recursive)
+ else: # a file path
+ absolute_path = self.join_path(dir_path, path)
+ rel_path = absolute_path[len(root):]
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError(
+ 'Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+ self.client_cfg)
+ # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_path (str): Lmdb database path.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_path (str): Lmdb database path.
+ """
+
+ def __init__(self,
+ db_path,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ **kwargs):
+ try:
+ import lmdb # NOQA
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ self.db_path = str(db_path)
+ self.readonly = readonly
+ self.lock = lock
+ self.readahead = readahead
+ self.kwargs = kwargs
+ self._client = None
+
+ def get(self, filepath):
+ """Get values according to the filepath.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ """
+ if self._client is None:
+ self._client = self._get_client()
+
+ with self._client.begin(write=False) as txn:
+ value_buf = txn.get(str(filepath).encode('utf-8'))
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+ def _get_client(self):
+ import lmdb
+
+ return lmdb.open(
+ self.db_path,
+ readonly=self.readonly,
+ lock=self.lock,
+ readahead=self.readahead,
+ **self.kwargs)
+
+ def __del__(self):
+ # the lmdb environment is opened lazily, so it may still be None here
+ if self._client is not None:
+ self._client.close()
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ _allow_symlink = True
+
+ def get(self, filepath: Union[str, Path]) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ with open(filepath, encoding=encoding) as f:
+ value_buf = f.read()
+ return value_buf
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ dbnet_cv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'wb') as f:
+ f.write(obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ dbnet_cv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'w', encoding=encoding) as f:
+ f.write(obj)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ os.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return osp.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return osp.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return osp.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return osp.join(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(
+ self,
+ filepath: Union[str,
+ Path]) -> Generator[Union[str, Path], None, None]:
+ """Only for unified API and do nothing."""
+ yield filepath
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if list_dir and suffix is not None:
+ raise TypeError('`suffix` should be None when `list_dir` is True')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+ elif osp.isdir(entry.path):
+ if list_dir:
+ rel_dir = osp.relpath(entry.path, root)
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(entry.path, list_dir,
+ list_file, suffix,
+ recursive)
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
+
+
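+ # Illustrative sketch (the directory path is a placeholder): list all .jpg
+ # files under a directory recursively with the disk backend defined above.
+ #
+ #     >>> backend = HardDiskBackend()
+ #     >>> for p in backend.list_dir_or_file('data/imgs', list_dir=False,
+ #     ...                                   suffix='.jpg', recursive=True):
+ #     ...     print(p)  # paths relative to 'data/imgs'
+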
+class HTTPBackend(BaseStorageBackend):
+ """HTTP and HTTPS storage bachend."""
+
+ def get(self, filepath):
+ value_buf = urlopen(filepath).read()
+ return value_buf
+
+ def get_text(self, filepath, encoding='utf-8'):
+ value_buf = urlopen(filepath).read()
+ return value_buf.decode(encoding)
+
+ @contextmanager
+ def get_local_path(
+ self, filepath: str) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath``.
+
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with the ``with`` statement, and when exiting from the
+ ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = HTTPBackend()
+ >>> # After exiting from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('http://path/of/your/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+
+class FileClient:
+ """A general file client to access files in different backends.
+
+ The client loads a file or text in a specified backend from its path
+ and returns it as a binary or text file. There are two ways to choose a
+ backend: the name of the backend or the prefix of the path. Although both
+ can be used to choose a storage backend, ``backend`` has a higher priority,
+ that is, if both are set, the storage backend will be chosen by the
+ ``backend`` argument. If both are ``None``, the disk backend will be chosen.
+ Note that other backend accessors can also be registered with a given name,
+ prefixes, and backend class. In addition, the singleton pattern is used to
+ avoid repeated object creation: if the arguments are the same, the same
+ object will be returned.
+
+ Args:
+ backend (str, optional): The storage backend type. Options are "disk",
+ "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+ prefix (str, optional): The prefix of the registered storage backend.
+ Options are "s3", "http", "https". Default: None.
+
+ Examples:
+ >>> # only set backend
+ >>> file_client = FileClient(backend='petrel')
+ >>> # only set prefix
+ >>> file_client = FileClient(prefix='s3')
+ >>> # set both backend and prefix but use backend to choose client
+ >>> file_client = FileClient(backend='petrel', prefix='s3')
+ >>> # if the arguments are the same, the same object is returned
+ >>> file_client1 = FileClient(backend='petrel')
+ >>> file_client1 is file_client
+ True
+
+ Attributes:
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'ceph': CephBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ 'petrel': PetrelBackend,
+ 'http': HTTPBackend,
+ }
+
+ _prefix_to_backends = {
+ 's3': PetrelBackend,
+ 'http': HTTPBackend,
+ 'https': HTTPBackend,
+ }
+
+ _instances: dict = {}
+
+ client: Any
+
+ def __new__(cls, backend=None, prefix=None, **kwargs):
+ if backend is None and prefix is None:
+ backend = 'disk'
+ if backend is not None and backend not in cls._backends:
+ raise ValueError(
+ f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(cls._backends.keys())}')
+ if prefix is not None and prefix not in cls._prefix_to_backends:
+ raise ValueError(
+ f'prefix {prefix} is not supported. Currently supported ones '
+ f'are {list(cls._prefix_to_backends.keys())}')
+
+ # concatenate the arguments to a unique key for determining whether
+ # objects with the same arguments were created
+ arg_key = f'{backend}:{prefix}'
+ for key, value in kwargs.items():
+ arg_key += f':{key}:{value}'
+
+ if arg_key in cls._instances:
+ _instance = cls._instances[arg_key]
+ else:
+ # create a new object and put it to _instance
+ _instance = super().__new__(cls)
+ if backend is not None:
+ _instance.client = cls._backends[backend](**kwargs)
+ else:
+ _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+ cls._instances[arg_key] = _instance
+
+ return _instance
+
+ @property
+ def name(self):
+ return self.client.name
+
+ @property
+ def allow_symlink(self):
+ return self.client.allow_symlink
+
+ @staticmethod
+ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+ """Parse the prefix of a uri.
+
+ Args:
+ uri (str | Path): Uri to be parsed that contains the file prefix.
+
+ Examples:
+ >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
+ 's3'
+
+ Returns:
+ str | None: Return the prefix of uri if the uri contains '://' else
+ ``None``.
+ """
+ assert is_filepath(uri)
+ uri = str(uri)
+ if '://' not in uri:
+ return None
+ else:
+ prefix, _ = uri.split('://')
+ # In the case of PetrelBackend, the prefix may contain the cluster
+ # name like clusterName:s3
+ if ':' in prefix:
+ _, prefix = prefix.split(':')
+ return prefix
+
+ @classmethod
+ def infer_client(cls,
+ file_client_args: Optional[dict] = None,
+ uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+ """Infer a suitable file client based on the URI and arguments.
+
+ Args:
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. Default: None.
+ uri (str | Path, optional): Uri to be parsed that contains the file
+ prefix. Default: None.
+
+ Examples:
+ >>> uri = 's3://path/of/your/file'
+ >>> file_client = FileClient.infer_client(uri=uri)
+ >>> file_client_args = {'backend': 'petrel'}
+ >>> file_client = FileClient.infer_client(file_client_args)
+
+ Returns:
+ FileClient: Instantiated FileClient object.
+ """
+ assert file_client_args is not None or uri is not None
+ if file_client_args is None:
+ file_prefix = cls.parse_uri_prefix(uri) # type: ignore
+ return cls(prefix=file_prefix)
+ else:
+ return cls(**file_client_args)
+
+ @classmethod
+ def _register_backend(cls, name, backend, force=False, prefixes=None):
+ if not isinstance(name, str):
+ raise TypeError('the backend name should be a string, '
+ f'but got {type(name)}')
+ if not inspect.isclass(backend):
+ raise TypeError(
+ f'backend should be a class but got {type(backend)}')
+ if not issubclass(backend, BaseStorageBackend):
+ raise TypeError(
+ f'backend {backend} is not a subclass of BaseStorageBackend')
+ if not force and name in cls._backends:
+ raise KeyError(
+ f'{name} is already registered as a storage backend, '
+ 'add "force=True" if you want to override it')
+
+ if name in cls._backends and force:
+ for arg_key, instance in list(cls._instances.items()):
+ if isinstance(instance.client, cls._backends[name]):
+ cls._instances.pop(arg_key)
+ cls._backends[name] = backend
+
+ if prefixes is not None:
+ if isinstance(prefixes, str):
+ prefixes = [prefixes]
+ else:
+ assert isinstance(prefixes, (list, tuple))
+ for prefix in prefixes:
+ if prefix not in cls._prefix_to_backends:
+ cls._prefix_to_backends[prefix] = backend
+ elif (prefix in cls._prefix_to_backends) and force:
+ overridden_backend = cls._prefix_to_backends[prefix]
+ if isinstance(overridden_backend, list):
+ overridden_backend = tuple(overridden_backend)
+ for arg_key, instance in list(cls._instances.items()):
+ if isinstance(instance.client, overridden_backend):
+ cls._instances.pop(arg_key)
+ cls._prefix_to_backends[prefix] = backend
+ else:
+ raise KeyError(
+ f'{prefix} is already registered as a storage backend,'
+ ' add "force=True" if you want to override it')
+
+ @classmethod
+ def register_backend(cls, name, backend=None, force=False, prefixes=None):
+ """Register a backend to FileClient.
+
+ This method can be used as a normal class method or a decorator.
+
+ .. code-block:: python
+
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ FileClient.register_backend('new', NewBackend)
+
+ or
+
+ .. code-block:: python
+
+ @FileClient.register_backend('new')
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ Args:
+ name (str): The name of the registered backend.
+ backend (class, optional): The backend class to be registered,
+ which must be a subclass of :class:`BaseStorageBackend`.
+ When this method is used as a decorator, backend is None.
+ Defaults to None.
+ force (bool, optional): Whether to override the backend if the name
+ has already been registered. Defaults to False.
+ prefixes (str or list[str] or tuple[str], optional): The prefixes
+ of the registered storage backend. Default: None.
+ `New in version 1.3.15.`
+ """
+ if backend is not None:
+ cls._register_backend(
+ name, backend, force=force, prefixes=prefixes)
+ return
+
+ def _register(backend_cls):
+ cls._register_backend(
+ name, backend_cls, force=force, prefixes=prefixes)
+ return backend_cls
+
+ return _register
+
+ def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Note:
+ There are two types of return values for ``get``, one is ``bytes``
+ and the other is ``memoryview``. The advantage of using memoryview
+ is that you can avoid copying, and if you want to convert it to
+ ``bytes``, you can use ``.tobytes()``.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes | memoryview: Expected bytes object or a memory view of the
+ bytes object.
+ """
+ return self.client.get(filepath)
+
+ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return self.client.get_text(filepath, encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` should create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ self.client.put(obj, filepath)
+
+ def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` should create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ self.client.put_text(obj, filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str, Path): Path to be removed.
+ """
+ self.client.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return self.client.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return self.client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return self.client.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return self.client.join_path(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(
+ self,
+ filepath: Union[str,
+ Path]) -> Generator[Union[str, Path], None, None]:
+ """Download data from ``filepath`` and write the data to local path.
+
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with the ``with`` statement, and when exiting from the
+ ``with`` statement, the temporary path will be released.
+
+ Note:
+ If the ``filepath`` is a local path, just return itself.
+
+ .. warning::
+ ``get_local_path`` is an experimental interface that may change in
+ the future.
+
+ Args:
+ filepath (str or Path): Path to be read data.
+
+ Examples:
+ >>> file_client = FileClient(prefix='s3')
+ >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one path.
+ """
+ with self.client.get_local_path(str(filepath)) as local_path:
+ yield local_path
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
+ suffix, recursive)
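+
+
+ # Illustrative sketch (``MyBackend`` and the 'my://' prefix are made up):
+ # registering a backend with a prefix lets ``infer_client`` pick it from a
+ # URI, and identical constructor arguments reuse the same singleton.
+ #
+ #     >>> @FileClient.register_backend('my', prefixes='my')
+ #     ... class MyBackend(BaseStorageBackend):
+ #     ...     def get(self, filepath):
+ #     ...         return b'data'
+ #     ...     def get_text(self, filepath, encoding='utf-8'):
+ #     ...         return 'data'
+ #     >>> client = FileClient.infer_client(uri='my://bucket/file.txt')
+ #     >>> client is FileClient(prefix='my')
+ #     True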
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+
+__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/base.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/base.py
new file mode 100755
index 0000000000000000000000000000000000000000..0c9cc15b67cbf7d320c2b9c6cbd441a5d5adf235
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseFileHandler(metaclass=ABCMeta):
+ # `str_like` is a flag indicating whether the file object type is
+ # str-like or bytes-like. Pickle only processes bytes-like objects
+ # while json only processes str-like objects. If it is str-like,
+ # `StringIO` will be used to process the buffer.
+ str_like = True
+
+ @abstractmethod
+ def load_from_fileobj(self, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_str(self, obj, **kwargs):
+ pass
+
+ def load_from_path(self, filepath: str, mode: str = 'r', **kwargs):
+ with open(filepath, mode) as f:
+ return self.load_from_fileobj(f, **kwargs)
+
+ def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs):
+ with open(filepath, mode) as f:
+ self.dump_to_fileobj(obj, f, **kwargs)
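+
+
+# The class below is an illustrative sketch, not part of the original
+# module: a minimal concrete handler for a hypothetical plain-text "format",
+# showing the three abstract methods every subclass has to implement.
+class _TxtHandlerExample(BaseFileHandler):
+
+    def load_from_fileobj(self, file, **kwargs):
+        return file.read()
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        file.write(str(obj))
+
+    def dump_to_str(self, obj, **kwargs):
+        return str(obj)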
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/json_handler.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/json_handler.py
new file mode 100755
index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/json_handler.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+import numpy as np
+
+from .base import BaseFileHandler
+
+
+def set_default(obj):
+ """Set default json values for non-serializable values.
+
+ It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
+ It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
+ etc.) into plain Python built-in number types.
+ """
+ if isinstance(obj, (set, range)):
+ return list(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ raise TypeError(f'{type(obj)} is unsupported for json dump')
+
+
+class JsonHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file):
+ return json.load(file)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('default', set_default)
+ json.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('default', set_default)
+ return json.dumps(obj, **kwargs)
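+
+
+if __name__ == '__main__':
+    # Illustrative sketch only, not part of the original module: thanks to
+    # ``set_default``, numpy arrays and scalars are serialized as plain
+    # Python lists and numbers.
+    handler = JsonHandler()
+    print(handler.dump_to_str({'ids': np.arange(3), 'score': np.float32(0.5)}))
+    # -> {"ids": [0, 1, 2], "score": 0.5}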
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/pickle_handler.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/pickle_handler.py
new file mode 100755
index 0000000000000000000000000000000000000000..073856fd25a731b42f3cd19269ad95744b20598f
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/pickle_handler.py
@@ -0,0 +1,26 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+
+from .base import BaseFileHandler
+
+
+class PickleHandler(BaseFileHandler):
+
+ str_like = False
+
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super().load_from_path(filepath, mode='rb', **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super().dump_to_path(obj, filepath, mode='wb', **kwargs)
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/yaml_handler.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/yaml_handler.py
new file mode 100755
index 0000000000000000000000000000000000000000..1c1b077943d634b3ddcf5ee470855179b8308e9c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/yaml_handler.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import yaml
+
+try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+except ImportError:
+ from yaml import Loader, Dumper # type: ignore
+
+from .base import BaseFileHandler # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file, **kwargs):
+ kwargs.setdefault('Loader', Loader)
+ return yaml.load(file, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ yaml.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ return yaml.dump(obj, **kwargs)
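+
+
+if __name__ == '__main__':
+    # Illustrative sketch only, not part of the original module: a simple
+    # string round trip; the C loader/dumper is used when PyYAML was built
+    # with libyaml, otherwise the pure-Python fallback above is used.
+    from io import StringIO
+
+    handler = YamlHandler()
+    text = handler.dump_to_str({'model': 'DBNet', 'lr': 0.007})
+    assert handler.load_from_fileobj(StringIO(text)) == {'model': 'DBNet',
+                                                         'lr': 0.007}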
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py
new file mode 100755
index 0000000000000000000000000000000000000000..c5cf52499661bcb3e8db92b9e7659a5489e0e994
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, TextIO, Union
+
+from ..utils import is_list_of
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+FileLikeObject = Union[TextIO, StringIO, BytesIO]
+
+file_handlers = {
+ 'json': JsonHandler(),
+ 'yaml': YamlHandler(),
+ 'yml': YamlHandler(),
+ 'pickle': PickleHandler(),
+ 'pkl': PickleHandler()
+}
+
+
+def load(file: Union[str, Path, FileLikeObject],
+ file_format: Optional[str] = None,
+ file_client_args: Optional[Dict] = None,
+ **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified API for loading data from serialized files.
+
+ Note:
+ In v1.3.16 and later, ``load`` supports loading data from serialized
+ files that can be stored in different backends.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml" and
+ "pickle/pkl".
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`dbnet_cv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> load('/path/of/your/file') # file is stored on disk
+ >>> load('https://path/of/your/file') # file is stored on the Internet
+ >>> load('s3://path/of/your/file') # file is stored in petrel
+
+ Returns:
+ The content from the file.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None and isinstance(file, str):
+ file_format = file.split('.')[-1]
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ f: FileLikeObject
+ if isinstance(file, str):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO(file_client.get_text(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ else:
+ with BytesIO(file_client.get(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ elif hasattr(file, 'read'):
+ obj = handler.load_from_fileobj(file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filepath str or a file-object')
+ return obj
+
+
+def dump(obj: Any,
+ file: Optional[Union[str, Path, FileLikeObject]] = None,
+ file_format: Optional[str] = None,
+ file_client_args: Optional[Dict] = None,
+ **kwargs):
+ """Dump data to json/yaml/pickle strings or files.
+
+ This method provides a unified API for dumping data as strings or to files,
+ and also supports custom arguments for each file format.
+
+ Note:
+ In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+ files which can be saved to different backends.
+
+ Args:
+ obj (any): The python object to be dumped.
+ file (str or :obj:`Path` or file-like object, optional): If not
+ specified, then the object is dumped to a str, otherwise to a file
+ specified by the filename or file-like object.
+ file_format (str, optional): Same as :func:`load`.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`dbnet_cv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dump('hello world', '/path/of/your/file') # disk
+ >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
+
+ Returns:
+ str or None: The dumped string (or bytes for pickle) when ``file`` is
+ None, otherwise ``None``.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None:
+ if isinstance(file, str):
+ file_format = file.split('.')[-1]
+ elif file is None:
+ raise ValueError(
+ 'file_format must be specified since file is None')
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+ f: FileLikeObject
+ handler = file_handlers[file_format]
+ if file is None:
+ return handler.dump_to_str(obj, **kwargs)
+ elif isinstance(file, str):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put_text(f.getvalue(), file)
+ else:
+ with BytesIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put(f.getvalue(), file)
+ elif hasattr(file, 'write'):
+ handler.dump_to_fileobj(obj, file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler: BaseFileHandler,
+ file_formats: Union[str, List[str]]) -> None:
+ """Register a handler for some file extensions.
+
+ Args:
+ handler (:obj:`BaseFileHandler`): Handler to be registered.
+ file_formats (str or list[str]): File formats to be handled by this
+ handler.
+ """
+ if not isinstance(handler, BaseFileHandler):
+ raise TypeError(
+ f'handler must be a child of BaseFileHandler, not {type(handler)}')
+ if isinstance(file_formats, str):
+ file_formats = [file_formats]
+ if not is_list_of(file_formats, str):
+ raise TypeError('file_formats must be a str or a list of str')
+ for ext in file_formats:
+ file_handlers[ext] = handler
+
+
+def register_handler(file_formats: Union[str, list], **kwargs) -> Callable:
+
+ def wrap(cls):
+ _register_handler(cls(**kwargs), file_formats)
+ return cls
+
+ return wrap
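+
+
+if __name__ == '__main__':
+    # Illustrative sketch only, not part of the original module. It dumps a
+    # small dict to a JSON string, reloads it from an in-memory buffer, and
+    # registers the (hypothetical) extra extension 'json5' for JsonHandler.
+    cfg = {'backbone': 'resnet18', 'num_classes': 2}
+    text = dump(cfg, file_format='json')
+    assert load(StringIO(text), file_format='json') == cfg
+
+    @register_handler('json5')
+    class Json5Handler(JsonHandler):
+        pass
+
+    assert 'json5' in file_handlers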
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py
new file mode 100755
index 0000000000000000000000000000000000000000..298e1f99a23fabe39911e7c68f49854e6d50d43d
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from io import StringIO
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+from .file_client import FileClient
+
+
+def list_from_file(filename: Union[str, Path],
+ prefix: str = '',
+ offset: int = 0,
+ max_num: int = 0,
+ encoding: str = 'utf-8',
+ file_client_args: Optional[Dict] = None) -> List:
+ """Load a text file and parse the content as a list of strings.
+
+ Note:
+ In v1.3.16 and later, ``list_from_file`` supports loading a text file
+ which can be stored in different backends and parsing the content as
+ a list of strings.
+
+ Args:
+ filename (str): Filename.
+ prefix (str): The prefix to be inserted at the beginning of each item.
+ offset (int): The number of lines to skip at the beginning.
+ max_num (int): The maximum number of lines to be read. Zero and
+ negative values mean no limitation.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`dbnet_cv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> list_from_file('/path/of/your/file') # disk
+ ['hello', 'world']
+ >>> list_from_file('s3://path/of/your/file') # ceph or petrel
+ ['hello', 'world']
+
+ Returns:
+ list[str]: A list of strings.
+ """
+ cnt = 0
+ item_list = []
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for _ in range(offset):
+ f.readline()
+ for line in f:
+ if 0 < max_num <= cnt:
+ break
+ item_list.append(prefix + line.rstrip('\n\r'))
+ cnt += 1
+ return item_list
+
+
+def dict_from_file(filename: Union[str, Path],
+ key_type: type = str,
+ encoding: str = 'utf-8',
+ file_client_args: Optional[Dict] = None) -> Dict:
+ """Load a text file and parse the content as a dict.
+
+ Each line of the text file should contain two or more columns split by
+ whitespace or tabs. The first column will be parsed as dict keys, and
+ the following columns will be parsed as dict values.
+
+ Note:
+ In v1.3.16 and later, ``dict_from_file`` supports loading a text file
+ which can be stored in different backends and parsing the content as
+ a dict.
+
+ Args:
+ filename (str): Filename.
+ key_type (type): Type of the dict keys. str is used by default and
+ type conversion will be performed if specified.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`dbnet_cv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dict_from_file('/path/of/your/file') # disk
+ {'key1': 'value1', 'key2': 'value2'}
+ >>> dict_from_file('s3://path/of/your/file') # ceph or petrel
+ {'key1': 'value1', 'key2': 'value2'}
+
+ Returns:
+ dict: The parsed contents.
+ """
+ mapping = {}
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for line in f:
+ items = line.rstrip('\n').split()
+ assert len(items) >= 2
+ key = key_type(items[0])
+ val = items[1:] if len(items) > 2 else items[1]
+ mapping[key] = val
+ return mapping
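+
+
+if __name__ == '__main__':
+    # Illustrative sketch only, not part of the original module. The file
+    # names are hypothetical: 'classes.txt' is assumed to hold one class name
+    # per line and 'label_map.txt' two whitespace-separated columns per line.
+    classes = list_from_file('classes.txt', prefix='cls_')
+    label_map = dict_from_file('label_map.txt', key_type=int)
+    print(classes, label_map)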
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7bbcdefd0159d8b95cf67d6e4f21cd08b9f24e88
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/__init__.py
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+ gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+ rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+ impad_to_multiple, imrescale, imresize, imresize_like,
+ imresize_to_multiple, imrotate, imshear, imtranslate,
+ rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+ adjust_hue, adjust_lighting, adjust_sharpness,
+ auto_contrast, clahe, imdenormalize, imequalize,
+ iminvert, imnormalize, imnormalize_, lut_transform,
+ posterize, solarize)
+__all__ = [
+    'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+    'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+    'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+    'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+    'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+    'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+    'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+    'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+    'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+    'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting',
+    'adjust_hue'
+]
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/colorspace.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/colorspace.py
new file mode 100755
index 0000000000000000000000000000000000000000..08f9952408c8e0bb38b17c10e2089e900ed418c2
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/colorspace.py
@@ -0,0 +1,309 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, Union
+
+import cv2
+import numpy as np
+
+
+def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray:
+ """Convert an image from the src colorspace to dst colorspace.
+
+ Args:
+ img (ndarray): The input image.
+ src (str): The source colorspace, e.g., 'rgb', 'hsv'.
+ dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
+
+ Returns:
+ ndarray: The converted image.
+ """
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+
+def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
+ """Convert a BGR image to grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
+ """Convert a RGB image to grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def gray2bgr(img: np.ndarray) -> np.ndarray:
+ """Convert a grayscale image to BGR image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted BGR image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ return out_img
+
+
+def gray2rgb(img: np.ndarray) -> np.ndarray:
+ """Convert a grayscale image to RGB image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted RGB image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ return out_img
+
+
+def _convert_input_type_range(img: np.ndarray) -> np.ndarray:
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, '
+ f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(
+ img: np.ndarray, dst_type: Union[np.uint8, np.float32]) -> np.ndarray:
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, '
+ f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
+
+
+def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img: np.ndarray) -> np.ndarray:
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [
+ -222.921, 135.576, -276.836
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img: np.ndarray) -> np.ndarray:
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [
+ -276.836, 135.576, -222.921
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def convert_color_factory(src: str, dst: str) -> Callable:
+
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+ def convert_color(img: np.ndarray) -> np.ndarray:
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+ convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+ image.
+
+ Args:
+ img (ndarray or str): The input image.
+
+ Returns:
+ ndarray: The converted {dst.upper()} image.
+ """
+
+ return convert_color
+
+
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+
+bgr2hls = convert_color_factory('bgr', 'hls')
+
+hls2bgr = convert_color_factory('hls', 'bgr')
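+
+
+if __name__ == '__main__':
+    # Illustrative sketch only, not part of the original module: on a float32
+    # image in [0, 1], a BGR -> YCbCr -> BGR round trip through the BT.601
+    # transforms above recovers the input up to small rounding error.
+    img = np.random.rand(32, 32, 3).astype(np.float32)
+    restored = ycbcr2bgr(bgr2ycbcr(img))
+    assert np.allclose(restored, img, atol=1e-3)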
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/geometric.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/geometric.py
new file mode 100755
index 0000000000000000000000000000000000000000..38e440b6a9d86cd7958fc797cc9416b8df5a7e57
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/geometric.py
@@ -0,0 +1,741 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+
+import cv2
+import numpy as np
+
+from ..utils import to_2tuple
+from .io import imread_backend
+
+try:
+ from PIL import Image
+except ImportError:
+ Image = None
+
+
+def _scale_size(size, scale):
+ """Rescale a size by a ratio.
+
+ Args:
+ size (tuple[int]): (w, h).
+ scale (float | tuple(float)): Scaling factor.
+
+ Returns:
+ tuple[int]: scaled size.
+ """
+ if isinstance(scale, (float, int)):
+ scale = (scale, scale)
+ w, h = size
+ return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+cv2_interp_codes = {
+ 'nearest': cv2.INTER_NEAREST,
+ 'bilinear': cv2.INTER_LINEAR,
+ 'bicubic': cv2.INTER_CUBIC,
+ 'area': cv2.INTER_AREA,
+ 'lanczos': cv2.INTER_LANCZOS4
+}
+
+# Pillow >= 9.1.0 uses a slightly different naming scheme for filters.
+# Set pillow_interp_codes according to the naming scheme used.
+if Image is not None:
+ if hasattr(Image, 'Resampling'):
+ pillow_interp_codes = {
+ 'nearest': Image.Resampling.NEAREST,
+ 'bilinear': Image.Resampling.BILINEAR,
+ 'bicubic': Image.Resampling.BICUBIC,
+ 'box': Image.Resampling.BOX,
+ 'lanczos': Image.Resampling.LANCZOS,
+ 'hamming': Image.Resampling.HAMMING
+ }
+ else:
+ pillow_interp_codes = {
+ 'nearest': Image.NEAREST,
+ 'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'box': Image.BOX,
+ 'lanczos': Image.LANCZOS,
+ 'hamming': Image.HAMMING
+ }
+
+
+def imresize(img,
+ size,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+ """Resize image to a given size.
+
+ Args:
+ img (ndarray): The input image.
+ size (tuple[int]): Target size (w, h).
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``dbnet_cv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if backend is None:
+ backend = imread_backend
+ if backend not in ['cv2', 'pillow']:
+ raise ValueError(f'backend: {backend} is not supported for resize. '
+ f"Supported backends are 'cv2', 'pillow'")
+
+ if backend == 'pillow':
+ assert img.dtype == np.uint8, 'Pillow backend only supports uint8 type'
+ pil_image = Image.fromarray(img)
+ pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+ resized_img = np.array(pil_image)
+ else:
+ resized_img = cv2.resize(
+ img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+ if not return_scale:
+ return resized_img
+ else:
+ w_scale = size[0] / w
+ h_scale = size[1] / h
+ return resized_img, w_scale, h_scale
+
+
+def imresize_to_multiple(img,
+ divisor,
+ size=None,
+ scale_factor=None,
+ keep_ratio=False,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+ """Resize image according to a given size or scale factor and then rounds
+ up the the resized or rescaled image size to the nearest value that can be
+ divided by the divisor.
+
+ Args:
+ img (ndarray): The input image.
+ divisor (int | tuple): Resized image size will be a multiple of
+ divisor. If divisor is a tuple, divisor should be
+ (w_divisor, h_divisor).
+ size (None | int | tuple[int]): Target size (w, h). Default: None.
+ scale_factor (None | float | tuple[float]): Multiplier for spatial
+ size. Should match input size if it is a tuple and the 2D style is
+ (w_scale_factor, h_scale_factor). Default: None.
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image. Default: False.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``dbnet_cv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if size is not None and scale_factor is not None:
+ raise ValueError('only one of size or scale_factor should be defined')
+ elif size is None and scale_factor is None:
+ raise ValueError('one of size or scale_factor should be defined')
+ elif size is not None:
+ size = to_2tuple(size)
+ if keep_ratio:
+ size = rescale_size((w, h), size, return_scale=False)
+ else:
+ size = _scale_size((w, h), scale_factor)
+
+ divisor = to_2tuple(divisor)
+ size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
+ resized_img, w_scale, h_scale = imresize(
+ img,
+ size,
+ return_scale=True,
+ interpolation=interpolation,
+ out=out,
+ backend=backend)
+ if return_scale:
+ return resized_img, w_scale, h_scale
+ else:
+ return resized_img
+
+
+def imresize_like(img,
+ dst_img,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image to the same size of a given image.
+
+ Args:
+ img (ndarray): The input image.
+ dst_img (ndarray): The target image.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
+ Returns:
+ tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = dst_img.shape[:2]
+ return imresize(img, (w, h), return_scale, interpolation, backend=backend)
+
+
+def rescale_size(old_size, scale, return_scale=False):
+ """Calculate the new size to be rescaled to.
+
+ Args:
+ old_size (tuple[int]): The old size (w, h) of image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image size.
+
+ Returns:
+ tuple[int]: The new rescaled image size.
+ """
+ w, h = old_size
+ if isinstance(scale, (float, int)):
+ if scale <= 0:
+ raise ValueError(f'Invalid scale {scale}, must be positive.')
+ scale_factor = scale
+ elif isinstance(scale, tuple):
+ max_long_edge = max(scale)
+ max_short_edge = min(scale)
+ scale_factor = min(max_long_edge / max(h, w),
+ max_short_edge / min(h, w))
+ else:
+ raise TypeError(
+ f'Scale must be a number or tuple of int, but got {type(scale)}')
+
+ new_size = _scale_size((w, h), scale_factor)
+
+ if return_scale:
+ return new_size, scale_factor
+ else:
+ return new_size
+
+
+def imrescale(img,
+ scale,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image while keeping the aspect ratio.
+
+ Args:
+ img (ndarray): The input image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The rescaled image.
+ """
+ h, w = img.shape[:2]
+ new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+ rescaled_img = imresize(
+ img, new_size, interpolation=interpolation, backend=backend)
+ if return_scale:
+ return rescaled_img, scale_factor
+ else:
+ return rescaled_img
+
+
+def imflip(img, direction='horizontal'):
+ """Flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image.
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return np.flip(img, axis=1)
+ elif direction == 'vertical':
+ return np.flip(img, axis=0)
+ else:
+ return np.flip(img, axis=(0, 1))
+
+
+def imflip_(img, direction='horizontal'):
+ """Inplace flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image (inplace).
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return cv2.flip(img, 1, img)
+ elif direction == 'vertical':
+ return cv2.flip(img, 0, img)
+ else:
+ return cv2.flip(img, -1, img)
+
+
+def imrotate(img,
+ angle,
+ center=None,
+ scale=1.0,
+ border_value=0,
+ interpolation='bilinear',
+ auto_bound=False):
+ """Rotate an image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees, positive values mean
+ clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If not specified, the center of the image will be
+ used.
+ scale (float): Isotropic scale factor.
+ border_value (int): Border value.
+ interpolation (str): Same as :func:`resize`.
+ auto_bound (bool): Whether to adjust the image size to cover the whole
+ rotated image.
+
+ Returns:
+ ndarray: The rotated image.
+ """
+ if center is not None and auto_bound:
+ raise ValueError('`auto_bound` conflicts with `center`')
+ h, w = img.shape[:2]
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ assert isinstance(center, tuple)
+
+ matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+ if auto_bound:
+ cos = np.abs(matrix[0, 0])
+ sin = np.abs(matrix[0, 1])
+ new_w = h * sin + w * cos
+ new_h = h * cos + w * sin
+ matrix[0, 2] += (new_w - w) * 0.5
+ matrix[1, 2] += (new_h - h) * 0.5
+ w = int(np.round(new_w))
+ h = int(np.round(new_h))
+ rotated = cv2.warpAffine(
+ img,
+ matrix, (w, h),
+ flags=cv2_interp_codes[interpolation],
+ borderValue=border_value)
+ return rotated
+
+
+def bbox_clip(bboxes, img_shape):
+ """Clip bboxes to fit the image shape.
+
+ Args:
+ bboxes (ndarray): Shape (..., 4*k)
+ img_shape (tuple[int]): (height, width) of the image.
+
+ Returns:
+ ndarray: Clipped bboxes.
+ """
+ assert bboxes.shape[-1] % 4 == 0
+ cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
+ cmin[0::2] = img_shape[1] - 1
+ cmin[1::2] = img_shape[0] - 1
+ clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0)
+ return clipped_bboxes
+
+
+def bbox_scaling(bboxes, scale, clip_shape=None):
+ """Scaling bboxes w.r.t the box center.
+
+ Args:
+ bboxes (ndarray): Shape(..., 4).
+ scale (float): Scaling factor.
+ clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+ boundary will be clipped according to the given shape (h, w).
+
+ Returns:
+ ndarray: Scaled bboxes.
+ """
+ if float(scale) == 1.0:
+ scaled_bboxes = bboxes.copy()
+ else:
+ w = bboxes[..., 2] - bboxes[..., 0] + 1
+ h = bboxes[..., 3] - bboxes[..., 1] + 1
+ dw = (w * (scale - 1)) * 0.5
+ dh = (h * (scale - 1)) * 0.5
+ scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
+ if clip_shape is not None:
+ return bbox_clip(scaled_bboxes, clip_shape)
+ else:
+ return scaled_bboxes
+
+
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+ """Crop image patches.
+
+ 3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+ Args:
+ img (ndarray): Image to be cropped.
+ bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+ scale (float, optional): Scale ratio of bboxes, the default value
+ 1.0 means no scaling.
+ pad_fill (Number | list[Number]): Value to be filled for padding.
+ Default: None, which means no padding.
+
+ Returns:
+ list[ndarray] | ndarray: The cropped image patches.
+ """
+ chn = 1 if img.ndim == 2 else img.shape[2]
+ if pad_fill is not None:
+ if isinstance(pad_fill, (int, float)):
+ pad_fill = [pad_fill for _ in range(chn)]
+ assert len(pad_fill) == chn
+
+ _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+ scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+ clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
+
+ patches = []
+ for i in range(clipped_bbox.shape[0]):
+ x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
+ if pad_fill is None:
+ patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+ else:
+ _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
+ if chn == 1:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+ else:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
+ patch = np.array(
+ pad_fill, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ x_start = 0 if _x1 >= 0 else -_x1
+ y_start = 0 if _y1 >= 0 else -_y1
+ w = x2 - x1 + 1
+ h = y2 - y1 + 1
+ patch[y_start:y_start + h, x_start:x_start + w,
+ ...] = img[y1:y1 + h, x1:x1 + w, ...]
+ patches.append(patch)
+
+ if bboxes.ndim == 1:
+ return patches[0]
+ else:
+ return patches
+
+
+def impad(img,
+ *,
+ shape=None,
+ padding=None,
+ pad_val=0,
+ padding_mode='constant'):
+ """Pad the given image to a certain shape or pad on all sides with
+ specified padding mode and padding value.
+
+ Args:
+ img (ndarray): Image to be padded.
+ shape (tuple[int]): Expected padding shape (h, w). Default: None.
+ padding (int or tuple[int]): Padding on each border. If a single int is
+ provided this is used to pad all borders. If tuple of length 2 is
+ provided this is the padding on left/right and top/bottom
+ respectively. If a tuple of length 4 is provided this is the
+ padding for the left, top, right and bottom borders respectively.
+ Default: None. Note that `shape` and `padding` can not be both
+ set.
+ pad_val (Number | Sequence[Number]): Values to be filled in padding
+ areas when padding_mode is 'constant'. Default: 0.
+ padding_mode (str): Type of padding. Should be: constant, edge,
+ reflect or symmetric. Default: constant.
+ - constant: pads with a constant value, this value is specified
+ with pad_val.
+ - edge: pads with the last value at the edge of the image.
+ - reflect: pads with reflection of image without repeating the last
+ value on the edge. For example, padding [1, 2, 3, 4] with 2
+ elements on both sides in reflect mode will result in
+ [3, 2, 1, 2, 3, 4, 3, 2].
+ - symmetric: pads with reflection of image repeating the last value
+ on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
+ both sides in symmetric mode will result in
+ [2, 1, 1, 2, 3, 4, 4, 3]
+
+ Returns:
+ ndarray: The padded image.
+ """
+
+ assert (shape is not None) ^ (padding is not None)
+ if shape is not None:
+ width = max(shape[1] - img.shape[1], 0)
+ height = max(shape[0] - img.shape[0], 0)
+ padding = (0, 0, width, height)
+
+ # check pad_val
+ if isinstance(pad_val, tuple):
+ assert len(pad_val) == img.shape[-1]
+ elif not isinstance(pad_val, numbers.Number):
+ raise TypeError('pad_val must be an int or a tuple. '
+ f'But received {type(pad_val)}')
+
+ # check padding
+ if isinstance(padding, tuple) and len(padding) in [2, 4]:
+ if len(padding) == 2:
+ padding = (padding[0], padding[1], padding[0], padding[1])
+ elif isinstance(padding, numbers.Number):
+ padding = (padding, padding, padding, padding)
+ else:
+ raise ValueError('Padding must be an int or a 2- or 4-element tuple. '
+ f'But received {padding}')
+
+ # check padding mode
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+ border_type = {
+ 'constant': cv2.BORDER_CONSTANT,
+ 'edge': cv2.BORDER_REPLICATE,
+ 'reflect': cv2.BORDER_REFLECT_101,
+ 'symmetric': cv2.BORDER_REFLECT
+ }
+ img = cv2.copyMakeBorder(
+ img,
+ padding[1],
+ padding[3],
+ padding[0],
+ padding[2],
+ border_type[padding_mode],
+ value=pad_val)
+
+ return img
+
+
+def impad_to_multiple(img, divisor, pad_val=0):
+ """Pad an image to ensure each edge to be multiple to some number.
+
+ Args:
+ img (ndarray): Image to be padded.
+ divisor (int): Padded image edges will be multiples of the divisor.
+ pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
+ Returns:
+ ndarray: The padded image.
+ """
+ pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+ pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+ return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
+
+
+def cutout(img, shape, pad_val=0):
+ """Randomly cut out a rectangle from the original img.
+
+ Args:
+ img (ndarray): Image to be cutout.
+ shape (int | tuple[int]): Expected cutout shape (h, w). If given as an
+ int, the value will be used for both h and w.
+ pad_val (int | float | tuple[int | float]): Values to be filled in the
+ cut area. Defaults to 0.
+
+ Returns:
+ ndarray: The cutout image.
+ """
+
+ channels = 1 if img.ndim == 2 else img.shape[2]
+ if isinstance(shape, int):
+ cut_h, cut_w = shape, shape
+ else:
+ assert isinstance(shape, tuple) and len(shape) == 2, \
+ f'shape must be an int or a tuple with length 2, but got type ' \
+ f'{type(shape)} instead.'
+ cut_h, cut_w = shape
+ if isinstance(pad_val, (int, float)):
+ pad_val = tuple([pad_val] * channels)
+ elif isinstance(pad_val, tuple):
+ assert len(pad_val) == channels, \
+ 'Expected the number of elements in tuple to equal the channels ' \
+ 'of the input image. Found {} vs {}'.format(
+ len(pad_val), channels)
+ else:
+ raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+ img_h, img_w = img.shape[:2]
+ y0 = np.random.uniform(img_h)
+ x0 = np.random.uniform(img_w)
+
+ y1 = int(max(0, y0 - cut_h / 2.))
+ x1 = int(max(0, x0 - cut_w / 2.))
+ y2 = min(img_h, y1 + cut_h)
+ x2 = min(img_w, x1 + cut_w)
+
+ if img.ndim == 2:
+ patch_shape = (y2 - y1, x2 - x1)
+ else:
+ patch_shape = (y2 - y1, x2 - x1, channels)
+
+ img_cutout = img.copy()
+ patch = np.array(
+ pad_val, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ img_cutout[y1:y2, x1:x2, ...] = patch
+
+ return img_cutout
+
+
+def _get_shear_matrix(magnitude, direction='horizontal'):
+ """Generate the shear matrix for transformation.
+
+ Args:
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The shear direction, either "horizontal"
+ or "vertical".
+
+ Returns:
+ ndarray: The shear matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]])
+ elif direction == 'vertical':
+ shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]])
+ return shear_matrix
+
+
+def imshear(img,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear an image.
+
+ Args:
+ img (ndarray): Image to be sheared with format (h, w)
+ or (h, w, c).
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The shear direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The sheared image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+ 'Expected the number of elements in tuple to equal the channels ' \
+ 'of the input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`')
+ shear_matrix = _get_shear_matrix(magnitude, direction)
+ sheared = cv2.warpAffine(
+ img,
+ shear_matrix,
+ (width, height),
+ # Note: when `border_value` has more than 3 elements (e.g. when
+ # shearing masks with more than 3 channels), `cv2.warpAffine`
+ # raises a TypeError, so we simply slice the first 3 values of
+ # `border_value`.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return sheared
+
+
+def _get_translate_matrix(offset, direction='horizontal'):
+ """Generate the translate matrix.
+
+ Args:
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either
+ "horizontal" or "vertical".
+
+ Returns:
+ ndarray: The translate matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
+ elif direction == 'vertical':
+ translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
+ return translate_matrix
+
+
+def imtranslate(img,
+ offset,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Translate an image.
+
+ Args:
+ img (ndarray): Image to be translated with format
+ (h, w) or (h, w, c).
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The translated image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+ 'Expected the number of elements in tuple to equal the channels ' \
+ 'of the input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`.')
+ translate_matrix = _get_translate_matrix(offset, direction)
+ translated = cv2.warpAffine(
+ img,
+ translate_matrix,
+ (width, height),
+ # Note: when `border_value` has more than 3 elements (e.g. when
+ # translating masks with more than 3 channels), `cv2.warpAffine`
+ # raises a TypeError, so we simply slice the first 3 values of
+ # `border_value`.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return translated
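+
+
+if __name__ == '__main__':
+    # Illustrative sketch only, not part of the original module: a common
+    # detector-style preprocessing chain on a random image, rescaling the
+    # longer edge to at most 256 and padding to a multiple of 32.
+    img = np.random.randint(0, 256, (333, 500, 3), dtype=np.uint8)
+    rescaled, scale = imrescale(img, (256, 256), return_scale=True)
+    padded = impad_to_multiple(rescaled, divisor=32, pad_val=0)
+    print(rescaled.shape, scale, padded.shape)  # (170, 256, 3) 0.512 (192, 256, 3)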
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py
new file mode 100755
index 0000000000000000000000000000000000000000..ba0b7918a3b4cd132b060fddf285a8b856f5fad3
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py
@@ -0,0 +1,314 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os.path as osp
+import warnings
+from pathlib import Path
+
+import cv2
+import numpy as np
+from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+ IMREAD_UNCHANGED)
+
+from dbnet_cv.fileio import FileClient
+from dbnet_cv.utils import is_filepath, is_str
+
+try:
+ from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+except ImportError:
+ TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+
+try:
+ from PIL import Image, ImageOps
+except ImportError:
+ Image = None
+
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
+jpeg = None
+supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+
+imread_flags = {
+ 'color': IMREAD_COLOR,
+ 'grayscale': IMREAD_GRAYSCALE,
+ 'unchanged': IMREAD_UNCHANGED,
+ 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+ 'grayscale_ignore_orientation':
+ IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
+}
+
+imread_backend = 'cv2'
+
+
+def use_backend(backend):
+ """Select a backend for image decoding.
+
+ Args:
+ backend (str): The image decoding backend type. Options are `cv2`,
+ `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
+ and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg`
+ file format.
+ """
+ assert backend in supported_backends
+ global imread_backend
+ imread_backend = backend
+ if imread_backend == 'turbojpeg':
+ if TurboJPEG is None:
+ raise ImportError('`PyTurboJPEG` is not installed')
+ global jpeg
+ if jpeg is None:
+ jpeg = TurboJPEG()
+ elif imread_backend == 'pillow':
+ if Image is None:
+ raise ImportError('`Pillow` is not installed')
+ elif imread_backend == 'tifffile':
+ if tifffile is None:
+ raise ImportError('`tifffile` is not installed')
+
+
+def _jpegflag(flag='color', channel_order='bgr'):
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'color':
+ if channel_order == 'bgr':
+ return TJPF_BGR
+ elif channel_order == 'rgb':
+ return TJCS_RGB
+ elif flag == 'grayscale':
+ return TJPF_GRAY
+ else:
+ raise ValueError('flag must be "color" or "grayscale"')
+
+
+def _pillow2array(img, flag='color', channel_order='bgr'):
+ """Convert a pillow image to numpy array.
+
+ Args:
+ img (:obj:`PIL.Image.Image`): The image loaded using PIL
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are 'color', 'grayscale' and 'unchanged'.
+ Default to 'color'.
+ channel_order (str): The channel order of the output image array,
+ candidates are 'bgr' and 'rgb'. Default to 'bgr'.
+
+ Returns:
+ np.ndarray: The converted numpy array
+ """
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'unchanged':
+ array = np.array(img)
+ if array.ndim >= 3 and array.shape[2] >= 3: # color image
+ array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR
+ else:
+ # Handle exif orientation tag
+ if flag in ['color', 'grayscale']:
+ img = ImageOps.exif_transpose(img)
+ # If the image mode is not 'RGB', convert it to 'RGB' first.
+ if img.mode != 'RGB':
+ if img.mode != 'LA':
+ # Most formats except 'LA' can be directly converted to RGB
+ img = img.convert('RGB')
+ else:
+ # When the mode is 'LA', the default conversion will fill in
+ # the canvas with black, which sometimes shadows black objects
+ # in the foreground.
+ #
+ # Therefore, a random color (124, 117, 104) is used for canvas
+ img_rgba = img.convert('RGBA')
+ img = Image.new('RGB', img_rgba.size, (124, 117, 104))
+ img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha
+ if flag in ['color', 'color_ignore_orientation']:
+ array = np.array(img)
+ if channel_order != 'rgb':
+ array = array[:, :, ::-1] # RGB to BGR
+ elif flag in ['grayscale', 'grayscale_ignore_orientation']:
+ img = img.convert('L')
+ array = np.array(img)
+ else:
+ raise ValueError(
+ 'flag must be "color", "grayscale", "unchanged", '
+ f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+ f' but got {flag}')
+ return array
+
+
+def imread(img_or_path,
+ flag='color',
+ channel_order='bgr',
+ backend=None,
+ file_client_args=None):
+ """Read an image.
+
+ Note:
+ In v1.4.1 and later, add `file_client_args` parameters.
+
+ Args:
+ img_or_path (ndarray or str or Path): Either a numpy array or str or
+ pathlib.Path. If it is a numpy array (loaded image), then
+ it will be returned as is.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale`, `unchanged`,
+ `color_ignore_orientation` and `grayscale_ignore_orientation`.
+ By default, `cv2` and `pillow` backend would rotate the image
+ according to its EXIF info unless called with `unchanged` or
+ `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+ always ignore image's EXIF info regardless of the flag.
+ The `turbojpeg` backend only supports `color` and `grayscale`.
+ channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+ If backend is None, the global imread_backend specified by
+ ``dbnet_cv.use_backend()`` will be used. Default: None.
+ file_client_args (dict | None): Arguments to instantiate a
+ FileClient. See :class:`dbnet_cv.fileio.FileClient` for details.
+ Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+
+ Examples:
+ >>> import dbnet_cv
+ >>> img_path = '/path/to/img.jpg'
+ >>> img = dbnet_cv.imread(img_path)
+ >>> img = dbnet_cv.imread(img_path, flag='color', channel_order='rgb',
+ ... backend='cv2')
+ >>> img = dbnet_cv.imread(img_path, flag='color', channel_order='bgr',
+ ... backend='pillow')
+ >>> s3_img_path = 's3://bucket/img.jpg'
+ >>> # infer the file backend by the prefix s3
+ >>> img = dbnet_cv.imread(s3_img_path)
+ >>> # manually set the file backend petrel
+ >>> img = dbnet_cv.imread(s3_img_path, file_client_args={
+ ... 'backend': 'petrel'})
+ >>> http_img_path = 'http://path/to/img.jpg'
+ >>> img = dbnet_cv.imread(http_img_path)
+ >>> img = dbnet_cv.imread(http_img_path, file_client_args={
+ ... 'backend': 'http'})
+ """
+
+ if isinstance(img_or_path, Path):
+ img_or_path = str(img_or_path)
+
+ if isinstance(img_or_path, np.ndarray):
+ return img_or_path
+ elif is_str(img_or_path):
+ file_client = FileClient.infer_client(file_client_args, img_or_path)
+ img_bytes = file_client.get(img_or_path)
+ return imfrombytes(img_bytes, flag, channel_order, backend)
+ else:
+ raise TypeError('"img" must be a numpy array or a str or '
+ 'a pathlib.Path object')
+
+
+def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Same as :func:`imread`.
+ channel_order (str): The channel order of the output, candidates
+ are 'bgr' and 'rgb'. Default to 'bgr'.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is
+ None, the global imread_backend specified by ``dbnet_cv.use_backend()``
+ will be used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+
+ Examples:
+ >>> img_path = '/path/to/img.jpg'
+ >>> with open(img_path, 'rb') as f:
+ >>> img_buff = f.read()
+ >>> img = dbnet_cv.imfrombytes(img_buff)
+ >>> img = dbnet_cv.imfrombytes(img_buff, flag='color', channel_order='rgb')
+ >>> img = dbnet_cv.imfrombytes(img_buff, backend='pillow')
+ >>> img = dbnet_cv.imfrombytes(img_buff, backend='cv2')
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+ raise ValueError(
+ f'backend: {backend} is not supported. Supported '
+ "backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
+ if backend == 'turbojpeg':
+ img = jpeg.decode(content, _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ with io.BytesIO(content) as buff:
+ img = Image.open(buff)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ elif backend == 'tifffile':
+ with io.BytesIO(content) as buff:
+ img = tifffile.imread(buff)
+ return img
+ else:
+ img_np = np.frombuffer(content, np.uint8)
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imdecode(img_np, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+
+
+def imwrite(img,
+ file_path,
+ params=None,
+ auto_mkdir=None,
+ file_client_args=None):
+ """Write image to file.
+
+ Note:
+        In v1.4.1 and later, the `file_client_args` parameter was added.
+
+ Warning:
+ The parameter `auto_mkdir` will be deprecated in the future and every
+        file client will create directories automatically.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically. It will be deprecated.
+ file_client_args (dict | None): Arguments to instantiate a
+ FileClient. See :class:`dbnet_cv.fileio.FileClient` for details.
+ Default: None.
+
+ Returns:
+ bool: Successful or not.
+
+ Examples:
+ >>> # write to hard disk client
+ >>> ret = dbnet_cv.imwrite(img, '/path/to/img.jpg')
+ >>> # infer the file backend by the prefix s3
+ >>> ret = dbnet_cv.imwrite(img, 's3://bucket/img.jpg')
+ >>> # manually set the file backend petrel
+ >>> ret = dbnet_cv.imwrite(img, 's3://bucket/img.jpg', file_client_args={
+ ... 'backend': 'petrel'})
+ """
+ assert is_filepath(file_path)
+ file_path = str(file_path)
+ if auto_mkdir is not None:
+ warnings.warn(
+ 'The parameter `auto_mkdir` will be deprecated in the future and '
+            'every file client will create directories automatically.')
+ file_client = FileClient.infer_client(file_client_args, file_path)
+ img_ext = osp.splitext(file_path)[-1]
+ # Encode image according to image suffix.
+ # For example, if image path is '/path/your/img.jpg', the encode
+ # format is '.jpg'.
+ flag, img_buff = cv2.imencode(img_ext, img, params)
+ file_client.put(img_buff.tobytes(), file_path)
+ return flag
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py
new file mode 100755
index 0000000000000000000000000000000000000000..468f44c628057a39c1a68673bf11a87b7777480e
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import dbnet_cv
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def tensor2imgs(tensor, mean=None, std=None, to_rgb=True):
+ """Convert tensor to 3-channel images or 1-channel gray images.
+
+ Args:
+ tensor (torch.Tensor): Tensor that contains multiple images, shape (
+ N, C, H, W). :math:`C` can be either 3 or 1.
+ mean (tuple[float], optional): Mean of images. If None,
+ (0, 0, 0) will be used for tensor with 3-channel,
+ while (0, ) for tensor with 1-channel. Defaults to None.
+ std (tuple[float], optional): Standard deviation of images. If None,
+ (1, 1, 1) will be used for tensor with 3-channel,
+ while (1, ) for tensor with 1-channel. Defaults to None.
+ to_rgb (bool, optional): Whether the tensor was converted to RGB
+ format in the first place. If so, convert it back to BGR.
+ For the tensor with 1 channel, it must be False. Defaults to True.
+
+ Returns:
+ list[np.ndarray]: A list that contains multiple images.
+ """
+
+ if torch is None:
+ raise RuntimeError('pytorch is not installed')
+ assert torch.is_tensor(tensor) and tensor.ndim == 4
+ channels = tensor.size(1)
+ assert channels in [1, 3]
+ if mean is None:
+ mean = (0, ) * channels
+ if std is None:
+ std = (1, ) * channels
+ assert (channels == len(mean) == len(std) == 3) or \
+ (channels == len(mean) == len(std) == 1 and not to_rgb)
+
+ num_imgs = tensor.size(0)
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ imgs = []
+ for img_id in range(num_imgs):
+ img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+ img = dbnet_cv.imdenormalize(
+ img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+ imgs.append(np.ascontiguousarray(img))
+ return imgs
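Since `tensor2imgs` simply undoes the per-image normalization, a short usage sketch may help. The mean/std values below are the common ImageNet statistics, used purely as an illustration, and the call assumes `tensor2imgs` is re-exported at the `dbnet_cv` package level as in mmcv.

```python
# Illustrative sketch, not part of the library: converts a normalized batch
# back to a list of displayable uint8 BGR images.
import torch

import dbnet_cv

batch = torch.rand(2, 3, 32, 32)          # normalized (N, C, H, W) tensor
mean = (123.675, 116.28, 103.53)          # assumed ImageNet statistics
std = (58.395, 57.12, 57.375)

imgs = dbnet_cv.tensor2imgs(batch, mean=mean, std=std, to_rgb=True)
assert len(imgs) == 2 and imgs[0].shape == (32, 32, 3) and imgs[0].dtype == 'uint8'
```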
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/photometric.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/photometric.py
new file mode 100755
index 0000000000000000000000000000000000000000..bba818f6988f8c318645c2794d4ea51d2b1cf5c9
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/photometric.py
@@ -0,0 +1,471 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from ..utils import is_tuple_of
+from .colorspace import bgr2gray, gray2bgr
+
+
+def imnormalize(img, mean, std, to_rgb=True):
+ """Normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+        mean (ndarray): The mean to be used for normalization.
+        std (ndarray): The std to be used for normalization.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ img = img.copy().astype(np.float32)
+ return imnormalize_(img, mean, std, to_rgb)
+
+
+def imnormalize_(img, mean, std, to_rgb=True):
+ """Inplace normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+        mean (ndarray): The mean to be used for normalization.
+        std (ndarray): The std to be used for normalization.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ # cv2 inplace normalization does not accept uint8
+ assert img.dtype != np.uint8
+ mean = np.float64(mean.reshape(1, -1))
+ stdinv = 1 / np.float64(std.reshape(1, -1))
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+ cv2.subtract(img, mean, img) # inplace
+ cv2.multiply(img, stdinv, img) # inplace
+ return img
+
+
+def imdenormalize(img, mean, std, to_bgr=True):
+ assert img.dtype != np.uint8
+ mean = mean.reshape(1, -1).astype(np.float64)
+ std = std.reshape(1, -1).astype(np.float64)
+ img = cv2.multiply(img, std) # make a copy
+ cv2.add(img, mean, img) # inplace
+ if to_bgr:
+ cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace
+ return img
+
+
+def iminvert(img):
+ """Invert (negate) an image.
+
+ Args:
+ img (ndarray): Image to be inverted.
+
+ Returns:
+ ndarray: The inverted image.
+ """
+ return np.full_like(img, 255) - img
+
+
+def solarize(img, thr=128):
+ """Solarize an image (invert all pixel values above a threshold)
+
+ Args:
+ img (ndarray): Image to be solarized.
+ thr (int): Threshold for solarizing (0 - 255).
+
+ Returns:
+ ndarray: The solarized image.
+ """
+ img = np.where(img < thr, img, 255 - img)
+ return img
+
+
+def posterize(img, bits):
+ """Posterize an image (reduce the number of bits for each color channel)
+
+ Args:
+ img (ndarray): Image to be posterized.
+ bits (int): Number of bits (1 to 8) to use for posterizing.
+
+ Returns:
+ ndarray: The posterized image.
+ """
+ shift = 8 - bits
+ img = np.left_shift(np.right_shift(img, shift), shift)
+ return img
+
+
+def adjust_color(img, alpha=1, beta=None, gamma=0):
+ r"""It blends the source image and its gray image:
+
+ .. math::
+ output = img * alpha + gray\_img * beta + gamma
+
+ Args:
+ img (ndarray): The input source image.
+ alpha (int | float): Weight for the source image. Default 1.
+ beta (int | float): Weight for the converted gray image.
+ If None, it's assigned the value (1 - `alpha`).
+ gamma (int | float): Scalar added to each sum.
+ Same as :func:`cv2.addWeighted`. Default 0.
+
+ Returns:
+ ndarray: Colored image which has the same size and dtype as input.
+ """
+ gray_img = bgr2gray(img)
+ gray_img = np.tile(gray_img[..., None], [1, 1, 3])
+ if beta is None:
+ beta = 1 - alpha
+ colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
+ if not colored_img.dtype == np.uint8:
+        # Note when the dtype of `img` is not the default `np.uint8`
+        # (e.g. np.float32), the values in `colored_img` returned by cv2
+        # are not guaranteed to be in the range [0, 255], so clipping
+        # is needed here.
+ colored_img = np.clip(colored_img, 0, 255)
+ return colored_img
+
+
+def imequalize(img):
+ """Equalize the image histogram.
+
+ This function applies a non-linear mapping to the input image,
+ in order to create a uniform distribution of grayscale values
+ in the output image.
+
+ Args:
+ img (ndarray): Image to be equalized.
+
+ Returns:
+ ndarray: The equalized image.
+ """
+
+ def _scale_channel(im, c):
+ """Scale the data in the corresponding channel."""
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # For computing the step, filter out the nonzeros.
+ nonzero_histo = histo[histo > 0]
+ step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
+ if not step:
+ lut = np.array(range(256))
+ else:
+ # Compute the cumulative sum, shifted by step // 2
+ # and then normalized by step.
+ lut = (np.cumsum(histo) + (step // 2)) // step
+ # Shift lut, prepending with 0.
+ lut = np.concatenate([[0], lut[:-1]], 0)
+ # handle potential integer overflow
+ lut[lut > 255] = 255
+ # If step is zero, return the original image.
+ # Otherwise, index from lut.
+ return np.where(np.equal(step, 0), im, lut[im])
+
+ # Scales each channel independently and then stacks
+ # the result.
+ s1 = _scale_channel(img, 0)
+ s2 = _scale_channel(img, 1)
+ s3 = _scale_channel(img, 2)
+ equalized_img = np.stack([s1, s2, s3], axis=-1)
+ return equalized_img.astype(img.dtype)
+
+
+def adjust_brightness(img, factor=1.):
+ """Adjust image brightness.
+
+ This function controls the brightness of an image. An
+ enhancement factor of 0.0 gives a black image.
+ A factor of 1.0 gives the original image. This function
+ blends the source image and the degenerated black image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be brightened.
+        factor (float): A value that controls the enhancement. A factor of
+            1.0 returns the original image, lower factors mean less color
+            (brightness, contrast, etc.), and higher values mean more.
+            Default 1.
+
+ Returns:
+ ndarray: The brightened image.
+ """
+ degenerated = np.zeros_like(img)
+    # Note: manually convert the dtype to np.float32 to match the results
+    # of PIL.ImageEnhance.Brightness as closely as possible.
+    # Set beta=1-factor and gamma=0.
+ brightened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ brightened_img = np.clip(brightened_img, 0, 255)
+ return brightened_img.astype(img.dtype)
+
+
+def adjust_contrast(img, factor=1.):
+ """Adjust image contrast.
+
+ This function controls the contrast of an image. An
+ enhancement factor of 0.0 gives a solid grey
+ image. A factor of 1.0 gives the original image. It
+ blends the source image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ factor (float): Same as :func:`dbnet_cv.adjust_brightness`.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+ gray_img = bgr2gray(img)
+ hist = np.histogram(gray_img, 256, (0, 255))[0]
+ mean = round(np.sum(gray_img) / np.sum(hist))
+ degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
+ degenerated = gray2bgr(degenerated)
+ contrasted_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ contrasted_img = np.clip(contrasted_img, 0, 255)
+ return contrasted_img.astype(img.dtype)
+
+
+def auto_contrast(img, cutoff=0):
+ """Auto adjust image contrast.
+
+    This function maximizes (normalizes) image contrast by first removing cutoff
+ percent of the lightest and darkest pixels from the histogram and remapping
+ the image so that the darkest pixel becomes black (0), and the lightest
+ becomes white (255).
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ cutoff (int | float | tuple): The cutoff percent of the lightest and
+ darkest pixels to be removed. If given as tuple, it shall be
+ (low, high). Otherwise, the single value will be used for both.
+ Defaults to 0.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+
+ def _auto_contrast_channel(im, c, cutoff):
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # Remove cut-off percent pixels from histo
+ histo_sum = np.cumsum(histo)
+ cut_low = histo_sum[-1] * cutoff[0] // 100
+ cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
+ histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
+ histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)
+
+ # Compute mapping
+ low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
+        # If all the values have been cut off, return the original image
+ if low >= high:
+ return im
+ scale = 255.0 / (high - low)
+ offset = -low * scale
+ lut = np.array(range(256))
+ lut = lut * scale + offset
+ lut = np.clip(lut, 0, 255)
+ return lut[im]
+
+ if isinstance(cutoff, (int, float)):
+ cutoff = (cutoff, cutoff)
+ else:
+ assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+ f'float or tuple, but got {type(cutoff)} instead.'
+ # Auto adjusts contrast for each channel independently and then stacks
+ # the result.
+ s1 = _auto_contrast_channel(img, 0, cutoff)
+ s2 = _auto_contrast_channel(img, 1, cutoff)
+ s3 = _auto_contrast_channel(img, 2, cutoff)
+ contrasted_img = np.stack([s1, s2, s3], axis=-1)
+ return contrasted_img.astype(img.dtype)
+
+
+def adjust_sharpness(img, factor=1., kernel=None):
+ """Adjust image sharpness.
+
+ This function controls the sharpness of an image. An
+ enhancement factor of 0.0 gives a blurred image. A
+ factor of 1.0 gives the original image. And a factor
+ of 2.0 gives a sharpened image. It blends the source
+ image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be sharpened. BGR order.
+ factor (float): Same as :func:`dbnet_cv.adjust_brightness`.
+ kernel (np.ndarray, optional): Filter kernel to be applied on the img
+ to obtain the degenerated img. Defaults to None.
+
+ Note:
+        No value sanity check is enforced on the kernel set by users. With an
+        inappropriate kernel, ``adjust_sharpness`` may fail to perform the
+        operation its name suggests and instead apply whatever transform the
+        kernel determines.
+
+ Returns:
+ ndarray: The sharpened image.
+ """
+
+ if kernel is None:
+ # adopted from PIL.ImageFilter.SMOOTH
+ kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+ assert isinstance(kernel, np.ndarray), \
+ f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+ assert kernel.ndim == 2, \
+ f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+
+ degenerated = cv2.filter2D(img, -1, kernel)
+ sharpened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ sharpened_img = np.clip(sharpened_img, 0, 255)
+ return sharpened_img.astype(img.dtype)
+
+
+def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+ """AlexNet-style PCA jitter.
+
+ This data augmentation is proposed in `ImageNet Classification with Deep
+ Convolutional Neural Networks
+ `_.
+
+ Args:
+        img (ndarray): Image whose lighting is to be adjusted. BGR order.
+        eigval (ndarray): the eigenvalues of the covariance matrix of pixel
+            values.
+        eigvec (ndarray): the eigenvectors of the covariance matrix of pixel
+            values.
+ alphastd (float): The standard deviation for distribution of alpha.
+            Defaults to 0.1.
+ to_rgb (bool): Whether to convert img to rgb.
+
+ Returns:
+ ndarray: The adjusted image.
+ """
+ assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+ f'eigval and eigvec should both be of type np.ndarray, got ' \
+ f'{type(eigval)} and {type(eigvec)} instead.'
+
+ assert eigval.ndim == 1 and eigvec.ndim == 2
+ assert eigvec.shape == (3, eigval.shape[0])
+ n_eigval = eigval.shape[0]
+ assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+ f'got {type(alphastd)} instead.'
+
+ img = img.copy().astype(np.float32)
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+
+ alpha = np.random.normal(0, alphastd, n_eigval)
+ alter = eigvec \
+ * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
+ * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
+ alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
+ img_adjusted = img + alter
+ return img_adjusted
+
+
+def lut_transform(img, lut_table):
+ """Transform array by look-up table.
+
+ The function lut_transform fills the output array with values from the
+ look-up table. Indices of the entries are taken from the input array.
+
+ Args:
+ img (ndarray): Image to be transformed.
+ lut_table (ndarray): look-up table of 256 elements; in case of
+ multi-channel input array, the table should either have a single
+ channel (in this case the same table is used for all channels) or
+ the same number of channels as in the input array.
+
+ Returns:
+ ndarray: The transformed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert 0 <= np.min(img) and np.max(img) <= 255
+ assert isinstance(lut_table, np.ndarray)
+ assert lut_table.shape == (256, )
+
+ return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
+
+
+def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+ """Use CLAHE method to process the image.
+
+ See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+ Graphics Gems, 1994:474-485.` for more information.
+
+ Args:
+ img (ndarray): Image to be processed.
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+
+ Returns:
+ ndarray: The processed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert img.ndim == 2
+ assert isinstance(clip_limit, (float, int))
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+
+ clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
+ return clahe.apply(np.array(img, dtype=np.uint8))
+
+
+def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray:
+ """Adjust hue of an image.
+
+ The image hue is adjusted by converting the image to HSV and cyclically
+ shifting the intensities in the hue channel (H). The image is then
+ converted back to original image mode.
+
+ `hue_factor` is the amount of shift in H channel and must be in the
+ interval `[-0.5, 0.5]`.
+
+ Modified from
+ https://github.com/pytorch/vision/blob/main/torchvision/
+ transforms/functional.py
+
+ Args:
+ img (ndarray): Image to be adjusted.
+ hue_factor (float): How much to shift the hue channel. Should be in
+ [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+ HSV space in positive and negative direction respectively.
+ 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+ with complementary colors while 0 gives the original image.
+
+ Returns:
+ ndarray: Hue adjusted image.
+ """
+
+ if not (-0.5 <= hue_factor <= 0.5):
+ raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].')
+ if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})):
+ raise TypeError('img should be ndarray with dim=[2 or 3].')
+
+ dtype = img.dtype
+ img = img.astype(np.uint8)
+ hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL)
+ h, s, v = cv2.split(hsv_img)
+ h = h.astype(np.uint8)
+    # uint8 addition takes care of rotation across boundaries
+ with np.errstate(over='ignore'):
+ h += np.uint8(hue_factor * 255)
+ hsv_img = cv2.merge([h, s, v])
+ return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype)
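Most of the enhancement helpers above follow the same blend pattern, `output = img * factor + degenerated * (1 - factor)`, differing only in the degenerated image (black, mean, or smoothed). A small usage sketch, assuming these functions are re-exported at the `dbnet_cv` package level as in mmcv:

```python
# Illustrative only; the input is a random fake BGR image.
import numpy as np

import dbnet_cv

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)

bright = dbnet_cv.adjust_brightness(img, factor=1.5)  # blend with a black image
low_con = dbnet_cv.adjust_contrast(img, factor=0.5)   # blend with the mean image
sharp = dbnet_cv.adjust_sharpness(img, factor=2.0)    # blend with a smoothed image

# imnormalize / imdenormalize are inverses up to floating-point error.
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
norm = dbnet_cv.imnormalize(img, mean, std, to_rgb=True)
restored = dbnet_cv.imdenormalize(norm, mean, std, to_bgr=True)
assert np.allclose(restored, img, atol=1e-3)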
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/deprecated.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/deprecated.json
new file mode 100755
index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/deprecated.json
@@ -0,0 +1,6 @@
+{
+ "resnet50_caffe": "detectron/resnet50_caffe",
+ "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
+ "resnet101_caffe": "detectron/resnet101_caffe",
+ "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
+}
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/mmcls.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/mmcls.json
new file mode 100755
index 0000000000000000000000000000000000000000..c073a41d0aeb44ee0243f97ecc3558de538f9300
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/mmcls.json
@@ -0,0 +1,59 @@
+{
+ "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+ "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+ "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+ "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
+ "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+ "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+ "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+ "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+ "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
+ "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
+ "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
+ "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
+ "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
+ "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
+ "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
+ "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+ "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+ "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
+ "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
+ "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
+ "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
+ "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
+ "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
+ "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
+ "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
+ "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
+ "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
+ "mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
+ "mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
+ "repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
+ "repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
+ "repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
+ "repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
+ "repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
+ "repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
+ "repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
+ "repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
+ "repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
+ "repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
+ "repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
+ "repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
+ "res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
+ "res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
+ "res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
+ "swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
+ "swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
+ "swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
+ "swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
+ "t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
+ "t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
+ "t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
+ "tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
+ "vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
+ "vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
+ "vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
+}
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/open_mmlab.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/open_mmlab.json
new file mode 100755
index 0000000000000000000000000000000000000000..a2a0772ef3511694e4fb8b11349c264e8d41b4d7
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/open_mmlab.json
@@ -0,0 +1,50 @@
+{
+ "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
+ "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
+ "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
+ "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
+ "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
+ "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
+ "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
+ "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
+ "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
+ "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
+ "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
+ "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
+ "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
+ "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
+ "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
+ "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
+ "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
+ "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
+ "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
+ "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
+ "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
+ "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
+ "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
+ "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
+ "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_dbnet_detv2-f0a600f9.pth",
+ "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
+ "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
+ "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
+ "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
+ "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
+ "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
+ "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
+ "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
+ "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
+ "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
+ "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
+ "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
+ "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
+ "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
+ "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
+ "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
+ "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
+ "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
+ "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
+ "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+ "dbnet_det/mobilenet_v2": "https://download.openmmlab.com/dbnet_detection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
+}
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/torchvision_0.12.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/torchvision_0.12.json
new file mode 100755
index 0000000000000000000000000000000000000000..06defe67484dff91cf6f69109324cb1dd9d64bc3
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/torchvision_0.12.json
@@ -0,0 +1,57 @@
+{
+ "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
+ "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
+ "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
+ "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
+ "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
+ "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
+ "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
+ "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
+ "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
+ "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
+ "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
+ "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
+ "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
+ "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
+ "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
+ "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
+ "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
+ "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
+ "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
+ "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
+ "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
+ "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
+ "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
+ "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
+ "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
+ "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
+ "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
+ "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
+ "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
+ "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
+ "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
+ "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
+ "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
+ "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
+ "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
+ "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
+ "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
+ "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+ "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+ "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
+ "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
+ "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
+ "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
+ "shufflenetv2_x1.5": null,
+ "shufflenetv2_x2.0": null,
+ "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
+ "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
+ "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
+ "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
+ "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
+ "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
+ "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
+ "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
+ "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
+ "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
+}
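These model-zoo JSON files map short names to checkpoint URLs. The sketch below shows how such tables are typically consumed, under the assumption that dbnet_cv keeps mmcv's checkpoint-loading API (`dbnet_cv.runner.load_checkpoint`, not shown in this diff): the URL prefix selects the JSON table and the remainder is the lookup key.

```python
# Hedged sketch: assumes dbnet_cv mirrors mmcv's load_checkpoint behaviour.
import torchvision.models as models

from dbnet_cv.runner import load_checkpoint

model = models.resnet50()
# 'torchvision://resnet50' should resolve through torchvision_0.12.json to
# https://download.pytorch.org/models/resnet50-0676ba61.pth
load_checkpoint(model, 'torchvision://resnet50', map_location='cpu')
```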
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..035b051776596dab0909aa041b8f52ad6a6def66
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# from .active_rotated_filter import active_rotated_filter
+# from .assign_score_withk import assign_score_withk
+# from .ball_query import ball_query
+# from .bbox import bbox_overlaps
+# from .border_align import BorderAlign, border_align
+# from .box_iou_rotated import box_iou_rotated
+# from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+# from .cc_attention import CrissCrossAttention
+# from .chamfer_distance import chamfer_distance
+# from .contour_expand import contour_expand
+# from .convex_iou import convex_giou, convex_iou
+# from .corner_pool import CornerPool
+# from .correlation import Correlation
+from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
+from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
+ ModulatedDeformRoIPoolPack, deform_roi_pool)
+from .deprecated_wrappers import Conv2d_deprecated as Conv2d
+from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
+from .deprecated_wrappers import Linear_deprecated as Linear
+from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
+# from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
+# from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
+# sigmoid_focal_loss, softmax_focal_loss)
+# from .furthest_point_sample import (furthest_point_sample,
+# furthest_point_sample_with_dist)
+# from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
+# from .gather_points import gather_points
+# from .group_points import GroupAll, QueryAndGroup, grouping_operation
+from .info import (get_compiler_version, get_compiling_cuda_version,
+ get_onnxruntime_op_path)
+# from .iou3d import (boxes_iou3d, boxes_iou_bev, boxes_overlap_bev, nms3d,
+# nms3d_normal, nms_bev, nms_normal_bev)
+# from .knn import knn
+# from .masked_conv import MaskedConv2d, masked_conv2d
+# from .min_area_polygons import min_area_polygons
+from .modulated_deform_conv import (ModulatedDeformConv2d,
+ ModulatedDeformConv2dPack,
+ modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
+# from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+# from .pixel_group import pixel_group
+# from .point_sample import (SimpleRoIAlign, point_sample,
+# rel_roi_point_to_rel_img_point)
+# from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+# points_in_boxes_part)
+# from .points_in_polygons import points_in_polygons
+# from .points_sampler import PointsSampler
+# from .psa_mask import PSAMask
+# from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
+from .roi_align import RoIAlign, roi_align
+# from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
+from .roi_pool import RoIPool, roi_pool
+# from .roiaware_pool3d import RoIAwarePool3d
+# from .roipoint_pool3d import RoIPointPool3d
+# from .rotated_feature_align import rotated_feature_align
+# from .saconv import SAConv2d
+# from .scatter_points import DynamicScatter, dynamic_scatter
+# from .sparse_conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
+# SparseConvTranspose3d, SparseInverseConv2d,
+# SparseInverseConv3d, SubMConv2d, SubMConv3d)
+# from .sparse_modules import SparseModule, SparseSequential
+# from .sparse_pool import SparseMaxPool2d, SparseMaxPool3d
+# from .sparse_structure import SparseConvTensor, scatter_nd
+from .sync_bn import SyncBatchNorm
+# from .three_interpolate import three_interpolate
+# from .three_nn import three_nn
+# from .tin_shift import TINShift, tin_shift
+# from .upfirdn2d import upfirdn2d
+# from .voxelize import Voxelization, voxelization
+
+# __all__ = [
+# 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
+# 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
+# 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
+# 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
+# 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
+# 'get_compiler_version', 'get_compiling_cuda_version',
+# 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+# 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+# 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
+# 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
+# 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
+# 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
+# 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
+# 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
+# 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
+# 'rotated_feature_align', 'RiRoIAlignRotated', 'riroi_align_rotated',
+# 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+# 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+# 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+# 'border_align', 'gather_points', 'furthest_point_sample',
+# 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
+# 'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev', 'nms_bev',
+# 'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization',
+# 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d',
+# 'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d',
+# 'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d',
+# 'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d',
+# 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
+# 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
+# 'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
+# 'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance'
+# ]
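Only a handful of ops stay enabled in this trimmed-down `__init__`. A quick sketch of two of them follows, assuming the dbnet_cv C++/CUDA extension has been compiled, the interfaces match mmcv's, and a CUDA device is available for `DeformConv2d`.

```python
# Illustrative sketch of DeformConv2d and RoIAlign usage (assumed mmcv-style APIs).
import torch

from dbnet_cv.ops import DeformConv2d, RoIAlign

x = torch.randn(1, 16, 32, 32, device='cuda')

# Deformable convolution: the offset tensor has 2 * deform_groups * kH * kW
# channels (here 2 * 1 * 3 * 3 = 18) and the same spatial size as the output.
dconv = DeformConv2d(16, 32, kernel_size=3, padding=1).cuda()
offset = torch.zeros(1, 18, 32, 32, device='cuda')
out = dconv(x, offset)            # (1, 32, 32, 32)

# RoIAlign: rois are (batch_idx, x1, y1, x2, y2) in the input's coordinate frame.
roi_align = RoIAlign(output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
rois = torch.tensor([[0., 4., 4., 20., 20.]], device='cuda')
feat = roi_align(x, rois)         # (1, 16, 7, 7)
```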
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/README.md b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..3a979b10b549d6912d11eb556532a1164b212533
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/README.md
@@ -0,0 +1,170 @@
+# Code Structure of CUDA operators
+
+This folder contains all non-python code for DBNET_CV custom ops. Please follow the same architecture if you want to add new ops.
+
+## Directories Tree
+
+```folder
+.
+├── common
+│ ├── box_iou_rotated_utils.hpp
+│ ├── parrots_cpp_helper.hpp
+│ ├── parrots_cuda_helper.hpp
+│ ├── pytorch_cpp_helper.hpp
+│ ├── pytorch_cuda_helper.hpp
+│ ├── pytorch_device_registry.hpp
+│ └── cuda
+│ ├── common_cuda_helper.hpp
+│ ├── parrots_cudawarpfunction.cuh
+│ ├── ...
+│ └── ops_cuda_kernel.cuh
+├── onnxruntime
+│ ├── onnxruntime_register.h
+│ ├── onnxruntime_session_options_config_keys.h
+│ ├── ort_dbnet_cv_utils.h
+│ ├── ...
+│ ├── onnx_ops.h
+│ └── cpu
+│ ├── onnxruntime_register.cpp
+│ ├── ...
+│ └── onnx_ops_impl.cpp
+├── parrots
+│ ├── ...
+│ ├── ops.cpp
+│ ├── ops_parrots.cpp
+│ └── ops_pytorch.h
+├── pytorch
+│ ├── info.cpp
+│ ├── pybind.cpp
+│ ├── ...
+│ ├── ops.cpp
+│ ├── cuda
+│ │ ├── ...
+│ │ └── ops_cuda.cu
+│ └── cpu
+│ ├── ...
+│ └── ops.cpp
+└── tensorrt
+ ├── trt_cuda_helper.cuh
+ ├── trt_plugin_helper.hpp
+ ├── trt_plugin.hpp
+ ├── trt_serialize.hpp
+ ├── ...
+ ├── trt_ops.hpp
+ └── plugins
+ ├── trt_cuda_helper.cu
+ ├── trt_plugin.cpp
+ ├── ...
+ ├── trt_ops.cpp
+ └── trt_ops_kernel.cu
+```
+
+## Components
+
+- `common`: This directory contains all tools and shared codes.
+  - `cuda`: The CUDA kernels which can be shared by all backends. **HIP** kernels are also here since they share similar syntax.
+- `onnxruntime`: **ONNX Runtime** support for custom ops.
+ - `cpu`: CPU implementation of supported ops.
+- `parrots`: **Parrots** is a deep learning framework for model training and inference. Parrots custom ops are placed in this directory.
+- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The op implementations and binding code are placed in this directory.
+  - `cuda`: This directory contains the CUDA kernel launchers, which feed tensor memory pointers to the CUDA kernels in `common/cuda`. The launchers provide the C++ interface to the CUDA implementations of the corresponding custom ops.
+  - `cpu`: This directory contains the CPU implementations of the corresponding custom ops.
+- `tensorrt`: **TensorRT** support for custom ops.
+  - `plugins`: This directory contains the implementations of the supported custom ops. Some ops may also use the shared CUDA kernels in `common/cuda`.
+
+## How to add new PyTorch ops?
+
+1. (Optional) Add shared kernel in `common` to support special hardware platform.
+
+ ```c++
+ // src/common/cuda/new_ops_cuda_kernel.cuh
+
+   template <typename T>
+ __global__ void new_ops_forward_cuda_kernel(const T* input, T* output, ...) {
+ // forward here
+ }
+
+ ```
+
+ Add cuda kernel launcher in `pytorch/cuda`.
+
+ ```c++
+ // src/pytorch/cuda
+   #include <new_ops_cuda_kernel.cuh>
+
+ void NewOpsForwardCUDAKernelLauncher(Tensor input, Tensor output, ...){
+ // initialize
+ at::cuda::CUDAGuard device_guard(input.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ ...
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ input.scalar_type(), "new_ops_forward_cuda_kernel", ([&] {
+          new_ops_forward_cuda_kernel<scalar_t>
+              <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), ...);
+ }));
+ AT_CUDA_CHECK(cudaGetLastError());
+ }
+ ```
+
+2. Register implementation for different devices.
+
+ ```c++
+ // src/pytorch/cuda/cudabind.cpp
+ ...
+
+ Tensor new_ops_forward_cuda(Tensor input, Tensor output, ...){
+ // implement cuda forward here
+ // use `NewOpsForwardCUDAKernelLauncher` here
+ }
+ // declare interface here.
+ Tensor new_ops_forward_impl(Tensor input, Tensor output, ...);
+ // register the implementation for given device (CUDA here).
+ REGISTER_DEVICE_IMPL(new_ops_forward_impl, CUDA, new_ops_forward_cuda);
+ ```
+
+3. Add the op implementation in the `pytorch` directory. Select different implementations according to the device type.
+
+ ```c++
+ // src/pytorch/new_ops.cpp
+ Tensor new_ops_forward_impl(Tensor input, Tensor output, ...){
+ // dispatch the implementation according to the device type of input.
+ DISPATCH_DEVICE_IMPL(new_ops_forward_impl, input, output, ...);
+ }
+ ...
+
+ Tensor new_ops_forward(Tensor input, Tensor output, ...){
+ return new_ops_forward_impl(input, output, ...);
+ }
+ ```
+
+4. Bind the implementation in `pytorch/pybind.cpp`.
+
+ ```c++
+ // src/pytorch/pybind.cpp
+
+ ...
+
+ Tensor new_ops_forward(Tensor input, Tensor output, ...);
+
+ ...
+
+ // bind with pybind11
+ m.def("new_ops_forward", &new_ops_forward, "new_ops_forward",
+ py::arg("input"), py::arg("output"), ...);
+
+ ...
+
+ ```
+
+5. Build DBNET_CV again. The new ops can then be used in Python:
+
+ ```python
+ from ..utils import ext_loader
+ ext_module = ext_loader.load_ext('_ext', ['new_ops_forward'])
+
+ ...
+
+ ext_module.new_ops_forward(input, output, ...)
+
+ ```
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/common_cuda_helper.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/common_cuda_helper.hpp
new file mode 100755
index 0000000000000000000000000000000000000000..e18036bac56d401760be7e72f01c1abf241af1ef
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/common_cuda_helper.hpp
@@ -0,0 +1,120 @@
+#ifndef COMMON_CUDA_HELPER
+#define COMMON_CUDA_HELPER
+
+#include <cuda.h>
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x) \
+ for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
+ j += blockDim.y * gridDim.y)
+
+#define CUDA_2D_KERNEL_BLOCK_LOOP(i, n, j, m) \
+ for (size_t i = blockIdx.x; i < (n); i += gridDim.x) \
+ for (size_t j = blockIdx.y; j < (m); j += gridDim.y)
+
+#define THREADS_PER_BLOCK 512
+
+inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) {
+ int optimal_block_num = (N + num_threads - 1) / num_threads;
+ int max_block_num = 4096;
+ return std::min(optimal_block_num, max_block_num);
+}
+
+template <typename T>
+__device__ T bilinear_interpolate(const T* input, const int height,
+ const int width, T y, T x,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
+
+ if (y <= 0) y = 0;
+ if (x <= 0) x = 0;
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ // do bilinear interpolation
+ T v1 = input[y_low * width + x_low];
+ T v2 = input[y_low * width + x_high];
+ T v3 = input[y_high * width + x_low];
+ T v4 = input[y_high * width + x_high];
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ return val;
+}
+
+template <typename T>
+__device__ void bilinear_interpolate_gradient(
+ const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4,
+ int& x_low, int& x_high, int& y_low, int& y_high,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+
+ if (y <= 0) y = 0;
+ if (x <= 0) x = 0;
+
+ y_low = (int)y;
+ x_low = (int)x;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = input[y_low * width + x_low];
+ // T v2 = input[y_low * width + x_high];
+ // T v3 = input[y_high * width + x_low];
+ // T v4 = input[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+#endif // COMMON_CUDA_HELPER
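For reference, a NumPy sketch of the interpolation computed by the `bilinear_interpolate` device function above; it is illustrative only and not part of the library.

```python
# NumPy re-implementation of the device function above, for a single sample.
import numpy as np


def bilinear_interpolate_ref(feat, y, x):
    """feat: 2-D array (height, width); returns the interpolated value at (y, x)."""
    height, width = feat.shape
    # Out-of-boundary samples contribute zero, as in the CUDA kernel.
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return 0.0
    y, x = max(y, 0.0), max(x, 0.0)
    y_low, x_low = int(y), int(x)
    if y_low >= height - 1:
        y_high = y_low = height - 1
        y = float(y_low)
    else:
        y_high = y_low + 1
    if x_low >= width - 1:
        x_high = x_low = width - 1
        x = float(x_low)
    else:
        x_high = x_low + 1
    ly, lx = y - y_low, x - x_low
    hy, hx = 1.0 - ly, 1.0 - lx
    # Weighted sum of the four surrounding pixels.
    return (hy * hx * feat[y_low, x_low] + hy * lx * feat[y_low, x_high] +
            ly * hx * feat[y_high, x_low] + ly * lx * feat[y_high, x_high])


feat = np.arange(16, dtype=np.float32).reshape(4, 4)
assert np.isclose(bilinear_interpolate_ref(feat, 1.5, 1.5), 7.5)
```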
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..9e864bb8bada51caf485a3d06b8fe13c9e5dda34
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh
@@ -0,0 +1,367 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer
+ *****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+ *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer
+ *********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/dbnet_detection/dbnet_det/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#ifndef DEFORM_CONV_CUDA_KERNEL_CUH
+#define DEFORM_CONV_CUDA_KERNEL_CUH
+
+#include <float.h>
+#ifdef DBNET_CV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else // DBNET_CV_WITH_TRT
+#ifdef DBNET_CV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else // DBNET_CV_USE_PARROTS
+#include "pytorch_cuda_helper.hpp"
+#endif // DBNET_CV_USE_PARROTS
+#endif // DBNET_CV_WITH_TRT
+
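+// Bilinearly samples `input` at the fractional location (h, w); taps that fall
+// outside the feature map contribute zero.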
+template <typename T>
+__device__ T deformable_im2col_bilinear(const T *input, const int data_width,
+ const int height, const int width, T h,
+ T w) {
+ if (h <= -1 || height <= h || w <= -1 || width <= w) {
+ return 0;
+ }
+
+ int h_low = floorf(h);
+ int w_low = floorf(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ T lh = h - h_low;
+ T lw = w - w_low;
+ T hh = 1 - lh, hw = 1 - lw;
+
+ T v1 = 0;
+ if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
+ T v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = input[h_low * data_width + w_high];
+ T v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = input[h_high * data_width + w_low];
+ T v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = input[h_high * data_width + w_high];
+
+ T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
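+// Weight with which a column gradient is scattered back to the integer pixel
+// (h, w) adjacent to the fractional sampling point (argmax_h, argmax_w).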
+template <typename T>
+__device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h,
+ const int w, const int height,
+ const int width) {
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
+ argmax_w >= width) {
+ // empty
+ return 0;
+ }
+
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ T weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
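+// Derivative of the bilinear sample with respect to the sampling coordinate:
+// bp_dir == 0 differentiates along h, bp_dir == 1 along w.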
+template <typename T>
+__device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height,
+ const int width, const T *im_data,
+ const int data_width, const int bp_dir) {
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
+ argmax_w >= width) {
+ // empty
+ return 0;
+ }
+
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ T weight = 0;
+
+ if (bp_dir == 0) {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) *
+ im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) *
+ im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) *
+ im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) *
+ im_data[argmax_h_high * data_width + argmax_w_high];
+ } else if (bp_dir == 1) {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) *
+ im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) *
+ im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) *
+ im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) *
+ im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
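+// im2col for deformable convolution: each thread fills the column entries of
+// one (input channel, output location) pair, sampling the input with bilinear
+// interpolation at the offset-shifted positions.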
+template <typename T>
+__global__ void deformable_im2col_gpu_kernel(
+ const int n, const T *data_im, const T *data_offset, const int height,
+ const int width, const int kernel_h, const int kernel_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group, const int batch_size,
+ const int num_channels, const int deformable_group, const int height_col,
+ const int width_col, T *data_col) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ // index is the linear index into the output column matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ T *data_col_ptr =
+ data_col +
+ ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ const T *data_im_ptr =
+ data_im + (b_col * num_channels + c_im) * height * width;
+ const T *data_offset_ptr =
+ data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i) {
+ for (int j = 0; j < kernel_w; ++j) {
+ const int data_offset_h_ptr =
+ ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr =
+ ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
+ w_col;
+ const T offset_h = data_offset_ptr[data_offset_h_ptr];
+ const T offset_w = data_offset_ptr[data_offset_w_ptr];
+ T val = static_cast<T>(0);
+ const T h_im = h_in + i * dilation_h + offset_h;
+ const T w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width,
+ h_im, w_im);
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
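+// col2im for deformable convolution: scatters every column gradient back onto
+// the integer pixels surrounding its fractional sampling location, accumulating
+// into grad_im with atomicAdd.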
+template <typename T>
+__global__ void deformable_col2im_gpu_kernel(
+ const int n, const T *data_col, const T *data_offset, const int channels,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group, const int batch_size,
+ const int deformable_group, const int height_col, const int width_col,
+ T *grad_im) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i =
+ (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c =
+ index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const T *data_offset_ptr =
+ data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr =
+ ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr =
+ ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const T offset_h = data_offset_ptr[data_offset_h_ptr];
+ const T offset_w = data_offset_ptr[data_offset_w_ptr];
+ const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const T cur_top_grad = data_col[index];
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++) {
+ for (int dx = -2; dx <= 2; dx++) {
+ if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
+ cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1) {
+ int cur_bottom_grad_pos =
+ ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ T weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data,
+ cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
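+// Gradient with respect to the sampling offsets: for each offset channel,
+// accumulates coordinate-weight * column-gradient over all kernel taps that
+// read that offset.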
+template <typename T>
+__global__ void deformable_col2im_coord_gpu_kernel(
+ const int n, const T *data_col, const T *data_im, const T *data_offset,
+ const int channels, const int height, const int width, const int kernel_h,
+ const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group, const int batch_size,
+ const int offset_channels, const int deformable_group, const int height_col,
+ const int width_col, T *grad_offset) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ T val = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const T *data_col_ptr = data_col + deformable_group_index *
+ channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const T *data_im_ptr =
+ data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w *
+ height * width;
+ const T *data_offset_ptr =
+ data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
+ col_c += col_step) {
+ const int col_pos =
+ (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i =
+ (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr =
+ (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr =
+ (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
+ w_out);
+ const T offset_h = data_offset_ptr[data_offset_h_ptr];
+ const T offset_w = data_offset_ptr[data_offset_w_ptr];
+ T inv_h = h_in + i * dilation_h + offset_h;
+ T inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ inv_h = inv_w = -2;
+ const T weight = get_coordinate_weight(inv_h, inv_w, height, width,
+ data_im_ptr + cnt * height * width,
+ width, bp_dir);
+ val += weight * data_col_ptr[col_pos];
+ cnt += 1;
+ }
+
+ grad_offset[index] = val;
+ }
+}
+
+#endif // DEFORM_CONV_CUDA_KERNEL_CUH
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_roi_pool_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_roi_pool_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..a1c0520ddb8394ebf6eca012fb701c9c066c1f82
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_roi_pool_cuda_kernel.cuh
@@ -0,0 +1,186 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef DEFORM_ROI_POOL_CUDA_KERNEL_CUH
+#define DEFORM_ROI_POOL_CUDA_KERNEL_CUH
+
+#ifdef DBNET_CV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
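+// Forward pass of deformable RoI pooling: averages bilinear samples inside each
+// output bin after shifting the RoI start by the gamma-scaled learned offsets.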
+template <typename T>
+__global__ void deform_roi_pool_forward_cuda_kernel(
+ const int nthreads, const T* input, const T* rois, const T* offset,
+ T* output, const int pooled_height, const int pooled_width,
+ const T spatial_scale, const int sampling_ratio, const T gamma,
+ const int channels, const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T* offset_rois = rois + n * 5;
+ int roi_batch_ind = offset_rois[0];
+
+ // Do not use rounding; this implementation detail is critical
+ T roi_start_w = offset_rois[1] * spatial_scale - 0.5;
+ T roi_start_h = offset_rois[2] * spatial_scale - 0.5;
+ T roi_end_w = offset_rois[3] * spatial_scale - 0.5;
+ T roi_end_h = offset_rois[4] * spatial_scale - 0.5;
+
+ T roi_width = roi_end_w - roi_start_w;
+ T roi_height = roi_end_h - roi_start_h;
+
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ const T* offset_input =
+ input + (roi_batch_ind * channels + c) * height * width;
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_height / pooled_height));
+ int roi_bin_grid_w =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_width / pooled_width));
+
+ // Compute roi offset
+ if (offset != NULL) {
+ const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 +
+ ph * pooled_width + pw;
+ T offset_roi_w = gamma * roi_width * offset_cur_w[0];
+ T offset_roi_h =
+ gamma * roi_height * offset_cur_w[pooled_width * pooled_height];
+ roi_start_w += offset_roi_w;
+ roi_start_h += offset_roi_h;
+ }
+
+ // We do average pooling inside a bin
+ const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
+ T output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h);
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+ T val = bilinear_interpolate(offset_input, height, width, y, x, index);
+ output_val += val;
+ }
+ }
+ output[index] = output_val / count;
+ }
+}
+
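+// Backward pass: routes the output gradient to grad_input through the bilinear
+// weights and, when offsets are given, accumulates their gradient into
+// grad_offset.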
+template <typename T>
+__global__ void deform_roi_pool_backward_cuda_kernel(
+ const int nthreads, const T* grad_output, const T* input, const T* rois,
+ const T* offset, T* grad_input, T* grad_offset, const int pooled_height,
+ const int pooled_width, const T spatial_scale, const int sampling_ratio,
+ const T gamma, const int channels, const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T* offset_rois = rois + n * 5;
+ int roi_batch_ind = offset_rois[0];
+ const T* offset_input =
+ input + ((roi_batch_ind * channels + c) * height * width);
+ T* offset_grad_input =
+ grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+ // Do not use rounding; this implementation detail is critical
+ T roi_start_w = offset_rois[1] * spatial_scale - 0.5;
+ T roi_start_h = offset_rois[2] * spatial_scale - 0.5;
+ T roi_end_w = offset_rois[3] * spatial_scale - 0.5;
+ T roi_end_h = offset_rois[4] * spatial_scale - 0.5;
+
+ T roi_width = roi_end_w - roi_start_w;
+ T roi_height = roi_end_h - roi_start_h;
+
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_height / pooled_height));
+ int roi_bin_grid_w =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_width / pooled_width));
+
+ // Compute roi offset
+ if (offset != NULL) {
+ const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 +
+ ph * pooled_width + pw;
+ T offset_roi_w = gamma * roi_width * offset_cur_w[0];
+ T offset_roi_h =
+ gamma * roi_height * offset_cur_w[pooled_width * pooled_height];
+ roi_start_w += offset_roi_w;
+ roi_start_h += offset_roi_h;
+ }
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+ const T grad_output_this_bin = grad_output[index] / count;
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h);
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+ bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+ x_low, x_high, y_low, y_high, index);
+
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ atomicAdd(offset_grad_input + y_low * width + x_low,
+ grad_output_this_bin * w1);
+ atomicAdd(offset_grad_input + y_low * width + x_high,
+ grad_output_this_bin * w2);
+ atomicAdd(offset_grad_input + y_high * width + x_low,
+ grad_output_this_bin * w3);
+ atomicAdd(offset_grad_input + y_high * width + x_high,
+ grad_output_this_bin * w4);
+ if (offset != NULL) {
+ T input_00 = offset_input[y_low * width + x_low];
+ T input_10 = offset_input[y_low * width + x_high];
+ T input_01 = offset_input[y_high * width + x_low];
+ T input_11 = offset_input[y_high * width + x_high];
+ T ogx = gamma * roi_width * grad_output_this_bin *
+ (input_11 * (y - y_low) + input_10 * (y_high - y) +
+ input_01 * (y_low - y) + input_00 * (y - y_high));
+ T ogy = gamma * roi_height * grad_output_this_bin *
+ (input_11 * (x - x_low) + input_01 * (x_high - x) +
+ input_10 * (x_low - x) + input_00 * (x - x_high));
+ atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 +
+ ph * pooled_width + pw,
+ ogx);
+ atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 +
+ pooled_width * pooled_height + ph * pooled_width + pw,
+ ogy);
+ }
+ }
+ }
+ }
+ }
+}
+
+#endif // DEFORM_ROI_POOL_CUDA_KERNEL_CUH
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/modulated_deform_conv_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/modulated_deform_conv_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..5046c8d4f9496f790a7768ccdd9558026877ac54
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/modulated_deform_conv_cuda_kernel.cuh
@@ -0,0 +1,399 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer
+ *****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+ *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer
+ *********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/dbnet_detection/dbnet_det/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#ifndef MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
+#define MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
+
+#include <float.h>
+#ifdef DBNET_CV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else // DBNET_CV_WITH_TRT
+#ifdef DBNET_CV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else // DBNET_CV_USE_PARROTS
+#include "pytorch_cuda_helper.hpp"
+#endif // DBNET_CV_USE_PARROTS
+#endif // DBNET_CV_WITH_TRT
+
+template <typename T>
+__device__ T dmcn_im2col_bilinear(const T *input, const int data_width,
+ const int height, const int width, T h, T w) {
+ int h_low = floorf(h);
+ int w_low = floorf(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ T lh = h - h_low;
+ T lw = w - w_low;
+ T hh = 1 - lh, hw = 1 - lw;
+
+ T v1 = 0;
+ if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
+ T v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = input[h_low * data_width + w_high];
+ T v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = input[h_high * data_width + w_low];
+ T v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = input[h_high * data_width + w_high];
+
+ T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename T>
+__device__ T dmcn_get_gradient_weight(T argmax_h, T argmax_w, const int h,
+ const int w, const int height,
+ const int width) {
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
+ argmax_w >= width) {
+ // empty
+ return 0;
+ }
+
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ T weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename T>
+__device__ T dmcn_get_coordinate_weight(T argmax_h, T argmax_w,
+ const int height, const int width,
+ const T *im_data, const int data_width,
+ const int bp_dir) {
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
+ argmax_w >= width) {
+ // empty
+ return 0;
+ }
+
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ T weight = 0;
+
+ if (bp_dir == 0) {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) *
+ im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) *
+ im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) *
+ im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) *
+ im_data[argmax_h_high * data_width + argmax_w_high];
+ } else if (bp_dir == 1) {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) *
+ im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) *
+ im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) *
+ im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) *
+ im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
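+// Modulated (DCNv2) im2col: identical to the plain deformable version except
+// that every sampled value is additionally scaled by a learned mask.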
+template <typename T>
+__global__ void modulated_deformable_im2col_gpu_kernel(
+ const int n, const T *data_im, const T *data_offset, const T *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group, const int batch_size,
+ const int num_channels, const int deformable_group, const int height_col,
+ const int width_col, T *data_col) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ // index is the linear index into the output column matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ T *data_col_ptr =
+ data_col +
+ ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ const T *data_im_ptr =
+ data_im + (b_col * num_channels + c_im) * height * width;
+ const T *data_offset_ptr =
+ data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const T *data_mask_ptr =
+ data_mask + (b_col * deformable_group + deformable_group_index) *
+ kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i) {
+ for (int j = 0; j < kernel_w; ++j) {
+ const int data_offset_h_ptr =
+ ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr =
+ ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
+ w_col;
+ const int data_mask_hw_ptr =
+ ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const T offset_h = data_offset_ptr[data_offset_h_ptr];
+ const T offset_w = data_offset_ptr[data_offset_w_ptr];
+ const T mask = data_mask_ptr[data_mask_hw_ptr];
+ T val = static_cast<T>(0);
+ const T h_im = h_in + i * dilation_h + offset_h;
+ const T w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im,
+ w_im);
+ *data_col_ptr = val * mask;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
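+// Modulated col2im: the column gradient is scaled by the mask before being
+// scattered back into grad_im.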
+template <typename T>
+__global__ void modulated_deformable_col2im_gpu_kernel(
+ const int n, const T *data_col, const T *data_offset, const T *data_mask,
+ const int channels, const int height, const int width, const int kernel_h,
+ const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group, const int batch_size,
+ const int deformable_group, const int height_col, const int width_col,
+ T *grad_im) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i =
+ (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c =
+ index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const T *data_offset_ptr =
+ data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+ const T *data_mask_ptr =
+ data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
+ kernel_w * height_col * width_col;
+ const int data_offset_h_ptr =
+ ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr =
+ ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr =
+ ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const T offset_h = data_offset_ptr[data_offset_h_ptr];
+ const T offset_w = data_offset_ptr[data_offset_w_ptr];
+ const T mask = data_mask_ptr[data_mask_hw_ptr];
+ const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const T cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++) {
+ for (int dx = -2; dx <= 2; dx++) {
+ if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
+ cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1) {
+ int cur_bottom_grad_pos =
+ ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ T weight =
+ dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data,
+ cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
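+// Gradients w.r.t. offsets and masks: grad_offset accumulates coordinate
+// weights times the masked column gradient, while grad_mask accumulates the
+// column gradient times the unmasked bilinear sample.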
+template <typename T>
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(
+ const int n, const T *data_col, const T *data_im, const T *data_offset,
+ const T *data_mask, const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col, T *grad_offset, T *grad_mask) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ T val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const T *data_col_ptr = data_col + deformable_group_index *
+ channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const T *data_im_ptr =
+ data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w *
+ height * width;
+ const T *data_offset_ptr =
+ data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+ const T *data_mask_ptr =
+ data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
+ kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
+ col_c += col_step) {
+ const int col_pos =
+ (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i =
+ (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr =
+ (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr =
+ (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
+ w_out);
+ const int data_mask_hw_ptr =
+ (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const T offset_h = data_offset_ptr[data_offset_h_ptr];
+ const T offset_w = data_offset_ptr[data_offset_w_ptr];
+ const T mask = data_mask_ptr[data_mask_hw_ptr];
+ T inv_h = h_in + i * dilation_h + offset_h;
+ T inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ inv_h = inv_w = -2;
+ else
+ mval += data_col_ptr[col_pos] *
+ dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width,
+ height, width, inv_h, inv_w);
+ const T weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w, height, width, data_im_ptr + cnt * height * width,
+ width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group +
+ // deformable_group_index) * kernel_h * kernel_w + offset_c / 2) *
+ // height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
+ kernel_w +
+ offset_c / 2) *
+ height_col +
+ h) *
+ width_col +
+ w] = mval;
+ }
+}
+
+#endif // MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..12225ffdb3b1691ad9edabcd1663109f67ef1a6f
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh
@@ -0,0 +1,801 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from
+*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+#ifndef DEFORM_ATTN_CUDA_KERNEL
+#define DEFORM_ATTN_CUDA_KERNEL
+
+#include "common_cuda_helper.hpp"
+#include "pytorch_cuda_helper.hpp"
+
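+// Bilinear sampling for multi-scale deformable attention; bottom_data is laid
+// out as (H, W, num_heads, channels), so strides are expressed in head*channel
+// units.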
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(
+ const scalar_t *&bottom_data, const int &height, const int &width,
+ const int &nheads, const int &channels, const scalar_t &h,
+ const scalar_t &w, const int &m, const int &c) {
+ const int h_low = floorf(h);
+ const int w_low = floorf(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0) {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1) {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0) {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1) {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
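+// Backward of the bilinear sample: accumulates into grad_value with atomicAdd
+// and writes the per-thread gradients of the sampling location and attention
+// weight into the caller-provided scratch slots.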
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(
+ const scalar_t *&bottom_data, const int &height, const int &width,
+ const int &nheads, const int &channels, const scalar_t &h,
+ const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
+ const scalar_t &attn_weight, scalar_t *&grad_value,
+ scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
+ const int h_low = floorf(h);
+ const int w_low = floorf(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0) {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value + ptr1, w1 * top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1) {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value + ptr2, w2 * top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0) {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value + ptr3, w3 * top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1) {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value + ptr4, w4 * top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(
+ const scalar_t *&bottom_data, const int &height, const int &width,
+ const int &nheads, const int &channels, const scalar_t &h,
+ const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
+ const scalar_t &attn_weight, scalar_t *&grad_value,
+ scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
+ const int h_low = floorf(h);
+ const int w_low = floorf(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0) {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value + ptr1, w1 * top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1) {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value + ptr2, w2 * top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0) {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value + ptr3, w3 * top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1) {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value + ptr4, w4 * top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
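+// Forward kernel: each thread produces one output element by summing the
+// attention-weighted bilinear samples over all levels and sampling points.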
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(
+ const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index, const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight, const int batch_size,
+ const int spatial_size, const int num_heads, const int channels,
+ const int num_levels, const int num_query, const int num_point,
+ scalar_t *data_col) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col = 0; l_col < num_levels; ++l_col) {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr =
+ data_value +
+ (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col = 0; p_col < num_point; ++p_col) {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h,
+ spatial_w, num_heads, channels,
+ h_im, w_im, m_col, c_col) *
+ weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
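+// Backward kernel, variant 1: per-thread gradients are staged in statically
+// sized shared memory and reduced serially by thread 0 of the block.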
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
+ const int n, const scalar_t *grad_col, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size, const int num_heads,
+ const int channels, const int num_levels, const int num_query,
+ const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight) {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ const int qid_stride = num_heads * channels;
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ scalar_t *grad_sampling_loc_out =
+ grad_sampling_loc + (grad_sampling_ptr << 1);
+ scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col = 0; l_col < num_levels; ++l_col) {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset =
+ data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col = 0; p_col < num_point; ++p_col) {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight + threadIdx.x) = 0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+ w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc + (threadIdx.x << 1),
+ cache_grad_attn_weight + threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0) {
+ scalar_t _grad_w = cache_grad_sampling_loc[0],
+ _grad_h = cache_grad_sampling_loc[1],
+ _grad_a = cache_grad_attn_weight[0];
+ int sid = 2;
+ for (unsigned int _tid = 1; _tid < blockSize; ++_tid) {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[_tid];
+ sid += 2;
+ }
+
+ *grad_sampling_loc_out = _grad_w;
+ *(grad_sampling_loc_out + 1) = _grad_h;
+ *grad_attn_weight_out = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight_out += grad_weight_stride;
+ grad_sampling_loc_out += grad_loc_stride;
+ }
+ }
+ }
+}
+
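+// Backward kernel, variant 2: same staging, but the block combines the shared
+// buffers with a binary tree reduction.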
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
+ const int n, const scalar_t *grad_col, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size, const int num_heads,
+ const int channels, const int num_levels, const int num_query,
+ const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight) {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ scalar_t *grad_sampling_loc_out =
+ grad_sampling_loc + (grad_sampling_ptr << 1);
+ scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col = 0; l_col < num_levels; ++l_col) {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset =
+ data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col = 0; p_col < num_point; ++p_col) {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight + threadIdx.x) = 0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+ w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc + (threadIdx.x << 1),
+ cache_grad_attn_weight + threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] +=
+ cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0) {
+ *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight_out = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight_out += grad_weight_stride;
+ grad_sampling_loc_out += grad_loc_stride;
+ }
+ }
+ }
+}
+
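+// Backward kernel using dynamically sized shared memory; thread 0 performs a
+// serial reduction over the blockDim.x staged entries.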
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
+ const int n, const scalar_t *grad_col, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size, const int num_heads,
+ const int channels, const int num_levels, const int num_query,
+ const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight) {
+ extern __shared__ int _s[];
+ scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+ scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ scalar_t *grad_sampling_loc_out =
+ grad_sampling_loc + (grad_sampling_ptr << 1);
+ scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col = 0; l_col < num_levels; ++l_col) {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset =
+ data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col = 0; p_col < num_point; ++p_col) {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight + threadIdx.x) = 0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+ w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc + (threadIdx.x << 1),
+ cache_grad_attn_weight + threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0) {
+ scalar_t _grad_w = cache_grad_sampling_loc[0],
+ _grad_h = cache_grad_sampling_loc[1],
+ _grad_a = cache_grad_attn_weight[0];
+ int sid = 2;
+ for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[_tid];
+ sid += 2;
+ }
+
+ *grad_sampling_loc_out = _grad_w;
+ *(grad_sampling_loc_out + 1) = _grad_h;
+ *grad_attn_weight_out = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight_out += grad_weight_stride;
+ grad_sampling_loc_out += grad_loc_stride;
+ }
+ }
+ }
+}
+
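+// Dynamic shared memory with a tree reduction that also folds in the trailing
+// element when blockDim.x is not a power of two.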
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
+ const int n, const scalar_t *grad_col, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size, const int num_heads,
+ const int channels, const int num_levels, const int num_query,
+ const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight) {
+ extern __shared__ int _s[];
+ scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+ scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ scalar_t *grad_sampling_loc_out =
+ grad_sampling_loc + (grad_sampling_ptr << 1);
+ scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col = 0; l_col < num_levels; ++l_col) {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset =
+ data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col = 0; p_col < num_point; ++p_col) {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight + threadIdx.x) = 0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+ w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc + (threadIdx.x << 1),
+ cache_grad_attn_weight + threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
+ s >>= 1, spre >>= 1) {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] +=
+ cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre) {
+ cache_grad_attn_weight[tid] +=
+ cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] +=
+ cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] +=
+ cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0) {
+ *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight_out = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight_out += grad_weight_stride;
+ grad_sampling_loc_out += grad_loc_stride;
+ }
+ }
+ }
+}
+
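+// Same shared-memory tree reduction as above, but the per-point result is
+// written with atomicAdd so that multiple thread blocks can safely
+// accumulate into the same grad_sampling_loc / grad_attn_weight entries.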
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
+ const int n, const scalar_t *grad_col, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size, const int num_heads,
+ const int channels, const int num_levels, const int num_query,
+ const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight) {
+ extern __shared__ int _s[];
+ scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+ scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ scalar_t *grad_sampling_loc_out =
+ grad_sampling_loc + (grad_sampling_ptr << 1);
+ scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col = 0; l_col < num_levels; ++l_col) {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset =
+ data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col = 0; p_col < num_point; ++p_col) {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight + threadIdx.x) = 0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+ w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc + (threadIdx.x << 1),
+ cache_grad_attn_weight + threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
+ s >>= 1, spre >>= 1) {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] +=
+ cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre) {
+ cache_grad_attn_weight[tid] +=
+ cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] +=
+ cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] +=
+ cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0) {
+ atomicAdd(grad_sampling_loc_out, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc_out + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight_out, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight_out += grad_weight_stride;
+ grad_sampling_loc_out += grad_loc_stride;
+ }
+ }
+ }
+}
+
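+// Global-memory fallback: no shared-memory staging or block-level reduction;
+// the bilinear backward helper receives the global output pointers directly,
+// so every thread accumulates its gradients straight into global memory.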
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(
+ const int n, const scalar_t *grad_col, const scalar_t *data_value,
+ const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+ const int batch_size, const int spatial_size, const int num_heads,
+ const int channels, const int num_levels, const int num_query,
+ const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ scalar_t *grad_sampling_loc_out =
+ grad_sampling_loc + (grad_sampling_ptr << 1);
+ scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col = 0; l_col < num_levels; ++l_col) {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset =
+ data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col = 0; p_col < num_point; ++p_col) {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+ w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+ grad_sampling_loc_out, grad_attn_weight_out);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight_out += grad_weight_stride;
+ grad_sampling_loc_out += grad_loc_stride;
+ }
+ }
+ }
+}
+#endif // DEFORM_ATTN_CUDA_KERNEL
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_align_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_align_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..e608474a36a62ff6c54b7b24bd0bde0c7fed2c5c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_align_cuda_kernel.cuh
@@ -0,0 +1,212 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef ROI_ALIGN_CUDA_KERNEL_CUH
+#define ROI_ALIGN_CUDA_KERNEL_CUH
+
+#include <float.h>
+#ifdef DBNET_CV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else // DBNET_CV_WITH_TRT
+#ifdef DBNET_CV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else // DBNET_CV_USE_PARROTS
+#include "pytorch_cuda_helper.hpp"
+#endif // DBNET_CV_USE_PARROTS
+#endif // DBNET_CV_WITH_TRT
+
+/*** Forward ***/
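+// Each thread computes one pooled output element (n, c, ph, pw): the RoI is
+// projected onto the feature map, split into pooled_height x pooled_width
+// bins, and every bin is sampled on a roi_bin_grid_h x roi_bin_grid_w grid
+// with bilinear interpolation, followed by max or average pooling.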
+template <typename T>
+__global__ void roi_align_forward_cuda_kernel(
+ const int nthreads, const T* input, const T* rois, T* output, T* argmax_y,
+ T* argmax_x, const int pooled_height, const int pooled_width,
+ const T spatial_scale, const int sampling_ratio,
+ const int pool_mode, // 0 - max pool, 1 - avg pool
+ const bool aligned, const int channels, const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T* offset_rois = rois + n * 5;
+ int roi_batch_ind = offset_rois[0];
+
+ // Do not use rounding; this implementation detail is critical
+ T offset = aligned ? (T)0.5 : (T)0.0;
+ T roi_start_w = offset_rois[1] * spatial_scale - offset;
+ T roi_start_h = offset_rois[2] * spatial_scale - offset;
+ T roi_end_w = offset_rois[3] * spatial_scale - offset;
+ T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+ T roi_width = roi_end_w - roi_start_w;
+ T roi_height = roi_end_h - roi_start_h;
+ if (!aligned) { // for backward-compatibility only
+ roi_width = max(roi_width, (T)1.);
+ roi_height = max(roi_height, (T)1.);
+ }
+
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ const T* offset_input =
+ input + (roi_batch_ind * channels + c) * height * width;
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_height / pooled_height));
+ int roi_bin_grid_w =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_width / pooled_width));
+
+ if (pool_mode == 0) {
+ // We do max pooling inside a bin
+ T maxval = -FLT_MAX;
+ T maxidx_y = -1.f, maxidx_x = -1.f;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h);
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+ T val =
+ bilinear_interpolate(offset_input, height, width, y, x, index);
+ if (val > maxval) {
+ maxval = val;
+ maxidx_y = y;
+ maxidx_x = x;
+ }
+ }
+ }
+ output[index] = maxval;
+ argmax_y[index] = maxidx_y;
+ argmax_x[index] = maxidx_x;
+ } else if (pool_mode == 1) {
+ // We do average pooling inside a bin
+ const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
+ T output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h);
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+ T val =
+ bilinear_interpolate(offset_input, height, width, y, x, index);
+ output_val += val;
+ }
+ }
+ output[index] = output_val / count;
+ }
+ }
+}
+
+/*** Backward ***/
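+// Max-pool backward scatters the gradient to the four neighbours of the
+// saved argmax location; average-pool backward re-derives the sampling grid
+// and distributes the gradient over its bilinear weights. All writes use
+// atomicAdd because bins from different RoIs can overlap in the input.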
+template <typename T>
+__global__ void roi_align_backward_cuda_kernel(
+ const int nthreads, const T* grad_output, const T* rois, const T* argmax_y,
+ const T* argmax_x, T* grad_input, const int pooled_height,
+ const int pooled_width, const T spatial_scale, const int sampling_ratio,
+ const int pool_mode, // 0 - max pool, 1 - avg pool
+ const bool aligned, const int channels, const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T grad_output_this_bin = grad_output[index];
+
+ const T* offset_rois = rois + n * 5;
+ int roi_batch_ind = offset_rois[0];
+ T* offset_grad_input =
+ grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+ if (pool_mode == 0) {
+ T y = argmax_y[index], x = argmax_x[index];
+ if (y != -1.f) {
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+ bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+ x_low, x_high, y_low, y_high, index);
+
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ atomicAdd(offset_grad_input + y_low * width + x_low,
+ grad_output_this_bin * w1);
+ atomicAdd(offset_grad_input + y_low * width + x_high,
+ grad_output_this_bin * w2);
+ atomicAdd(offset_grad_input + y_high * width + x_low,
+ grad_output_this_bin * w3);
+ atomicAdd(offset_grad_input + y_high * width + x_high,
+ grad_output_this_bin * w4);
+ }
+ }
+ } else if (pool_mode == 1) {
+ // Do not use rounding; this implementation detail is critical
+ T offset = aligned ? (T)0.5 : (T)0.0;
+ T roi_start_w = offset_rois[1] * spatial_scale - offset;
+ T roi_start_h = offset_rois[2] * spatial_scale - offset;
+ T roi_end_w = offset_rois[3] * spatial_scale - offset;
+ T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+ T roi_width = roi_end_w - roi_start_w;
+ T roi_height = roi_end_h - roi_start_h;
+ if (!aligned) { // for backward-compatibility only
+ roi_width = max(roi_width, (T)1.);
+ roi_height = max(roi_height, (T)1.);
+ }
+
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_height / pooled_height));
+ int roi_bin_grid_w =
+ (sampling_ratio > 0)
+ ? sampling_ratio
+ : static_cast<int>(ceilf(roi_width / pooled_width));
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h);
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+ bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+ x_low, x_high, y_low, y_high, index);
+
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ atomicAdd(offset_grad_input + y_low * width + x_low,
+ grad_output_this_bin * w1 / count);
+ atomicAdd(offset_grad_input + y_low * width + x_high,
+ grad_output_this_bin * w2 / count);
+ atomicAdd(offset_grad_input + y_high * width + x_low,
+ grad_output_this_bin * w3 / count);
+ atomicAdd(offset_grad_input + y_high * width + x_high,
+ grad_output_this_bin * w4 / count);
+ }
+ }
+ }
+ }
+ }
+}
+
+#endif // ROI_ALIGN_CUDA_KERNEL_CUH
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..ba732685c1e10ae7710a0faee95bf31a3e8ce0dd
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh
@@ -0,0 +1,93 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef ROI_POOL_CUDA_KERNEL_CUH
+#define ROI_POOL_CUDA_KERNEL_CUH
+
+#ifdef DBNET_CV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
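+// RoI max pooling forward: each thread handles one pooled output element,
+// clips its bin to the feature map, takes the maximum over the bin and
+// records the winning flat offset in argmax for use in the backward pass.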
+template <typename T>
+__global__ void roi_pool_forward_cuda_kernel(
+ const int nthreads, const T* input, const T* rois, T* output, int* argmax,
+ const int pooled_height, const int pooled_width, const T spatial_scale,
+ const int channels, const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T* offset_rois = rois + n * 5;
+ int roi_batch_ind = offset_rois[0];
+ // calculate the roi region on feature maps
+ T roi_x1 = offset_rois[1] * spatial_scale;
+ T roi_y1 = offset_rois[2] * spatial_scale;
+ T roi_x2 = (offset_rois[3] + 1) * spatial_scale;
+ T roi_y2 = (offset_rois[4] + 1) * spatial_scale;
+
+ // skip malformed rois (non-positive width or height)
+ T roi_w = roi_x2 - roi_x1;
+ T roi_h = roi_y2 - roi_y1;
+ if (roi_w <= 0 || roi_h <= 0) continue;
+
+ T bin_size_w = roi_w / static_cast<T>(pooled_width);
+ T bin_size_h = roi_h / static_cast<T>(pooled_height);
+
+ // the corresponding bin region
+ int bin_x1 = floorf(static_cast<T>(pw) * bin_size_w + roi_x1);
+ int bin_y1 = floorf(static_cast<T>(ph) * bin_size_h + roi_y1);
+ int bin_x2 = ceilf(static_cast<T>(pw + 1) * bin_size_w + roi_x1);
+ int bin_y2 = ceilf(static_cast<T>(ph + 1) * bin_size_h + roi_y1);
+
+ // add roi offsets and clip to input boundaries
+ bin_x1 = min(max(bin_x1, 0), width);
+ bin_y1 = min(max(bin_y1, 0), height);
+ bin_x2 = min(max(bin_x2, 0), width);
+ bin_y2 = min(max(bin_y2, 0), height);
+ bool is_empty = (bin_y2 <= bin_y1) || (bin_x2 <= bin_x1);
+
+ const T* offset_input =
+ input + (roi_batch_ind * channels + c) * height * width;
+ // Define an empty pooling region to be zero
+ // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+ T max_val = is_empty ? 0 : -FLT_MAX;
+ int max_idx = -1;
+ for (int h = bin_y1; h < bin_y2; ++h) {
+ for (int w = bin_x1; w < bin_x2; ++w) {
+ int offset = h * width + w;
+ if (offset_input[offset] > max_val) {
+ max_val = offset_input[offset];
+ max_idx = offset;
+ }
+ }
+ }
+ output[index] = max_val;
+ if (argmax != NULL) argmax[index] = max_idx;
+ }
+}
+
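+// RoI max pooling backward: route each output gradient back to the input
+// location recorded in argmax; atomicAdd handles overlapping RoIs.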
+template <typename T>
+__global__ void roi_pool_backward_cuda_kernel(
+ const int nthreads, const T* grad_output, const T* rois, const int* argmax,
+ T* grad_input, const int pooled_height, const int pooled_width,
+ const int channels, const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c) is an element in the pooled output
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ int roi_batch_ind = rois[n * 5];
+ T* grad_input_offset =
+ grad_input + ((roi_batch_ind * channels + c) * height * width);
+ int argmax_index = argmax[index];
+
+ if (argmax_index != -1) {
+ atomicAdd(grad_input_offset + argmax_index, grad_output[index]);
+ }
+ }
+}
+
+#endif // ROI_POOL_CUDA_KERNEL_CUH
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/sync_bn_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/sync_bn_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..78d97bedd0e2ce9f6cf8568f9c43909bdd49b22a
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/sync_bn_cuda_kernel.cuh
@@ -0,0 +1,331 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef SYNCBN_CUDA_KERNEL_CUH
+#define SYNCBN_CUDA_KERNEL_CUH
+
+#ifdef DBNET_CV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
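+// SyncBN statistics: one block per channel. Each thread accumulates a strided
+// partial sum over the (num, spatial) elements of its channel into shared
+// memory, then a tree reduction yields the per-channel mean.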
+template <typename T>
+__global__ void sync_bn_forward_mean_cuda_kernel(const T *input, float *mean,
+ int num, int channels,
+ int spatial) {
+ __shared__ float buffer[THREADS_PER_BLOCK];
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ buffer[tid] = 0;
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ buffer[tid] += input[index];
+ }
+ __syncthreads();
+
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ buffer[tid] += buffer[tid + s];
+ }
+ __syncthreads();
+ }
+ int total = num * spatial;
+ if (tid == 0) {
+ mean[c] = buffer[0] / total;
+ }
+}
+
+template <>
+__global__ void sync_bn_forward_mean_cuda_kernel(const phalf *input,
+ float *mean, int num,
+ int channels, int spatial) {
+ __shared__ float buffer[THREADS_PER_BLOCK];
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ buffer[tid] = 0;
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ buffer[tid] += static_cast<float>(input[index]);
+ }
+ __syncthreads();
+
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ buffer[tid] += buffer[tid + s];
+ }
+ __syncthreads();
+ }
+ int total = num * spatial;
+ if (tid == 0) {
+ mean[c] = buffer[0] / total;
+ }
+}
+
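+// Same per-channel block reduction as the mean kernel, but over squared
+// deviations from the previously computed mean (biased variance).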
+template <typename T>
+__global__ void sync_bn_forward_var_cuda_kernel(const T *input,
+ const float *mean, float *var,
+ int num, int channels,
+ int spatial) {
+ __shared__ float buffer[THREADS_PER_BLOCK];
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ buffer[tid] = 0;
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ float td = input[index] - mean[c];
+ buffer[tid] += td * td;
+ }
+ __syncthreads();
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ buffer[tid] += buffer[tid + s];
+ }
+ __syncthreads();
+ }
+ int total = num * spatial;
+ if (tid == 0) {
+ var[c] = buffer[0] / total;
+ }
+}
+
+template <>
+__global__ void sync_bn_forward_var_cuda_kernel(const phalf *input,
+ const float *mean, float *var,
+ int num, int channels,
+ int spatial) {
+ __shared__ float buffer[THREADS_PER_BLOCK];
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ buffer[tid] = 0;
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ float td = static_cast<float>(input[index]) - mean[c];
+ buffer[tid] += td * td;
+ }
+ __syncthreads();
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ buffer[tid] += buffer[tid + s];
+ }
+ __syncthreads();
+ }
+ int total = num * spatial;
+ if (tid == 0) {
+ var[c] = buffer[0] / total;
+ }
+}
+
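+// Normalize the input with the gathered statistics, optionally apply the
+// affine transform and cache the normalized values in `norm`; thread 0 also
+// updates the running mean/var, using an unbiased variance estimate.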
+template <typename T>
+__global__ void sync_bn_forward_output_cuda_kernel(
+ const T *input, const float *mean, const float *var, float *running_mean,
+ float *running_var, const float *weight, const float *bias, float *norm,
+ float *std, T *output, int num, int channels, int spatial, float eps,
+ float momentum, int group_size) {
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ float mean_value = mean[c];
+ float std_value = sqrt(var[c] + eps);
+
+ if (weight != nullptr) {
+ float weight_value = weight[c];
+ float bias_value = bias[c];
+ if (norm != nullptr) {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ norm[index] = (input[index] - mean_value) / std_value;
+ output[index] = norm[index] * weight_value + bias_value;
+ }
+ } else {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ output[index] =
+ (input[index] - mean_value) / std_value * weight_value + bias_value;
+ }
+ }
+ } else {
+ if (norm != nullptr) {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ output[index] = norm[index] = (input[index] - mean_value) / std_value;
+ }
+ } else {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ output[index] = (input[index] - mean_value) / std_value;
+ }
+ }
+ }
+ if (tid == 0) {
+ if (std != nullptr) std[c] = std_value;
+ if (running_mean != nullptr) {
+ running_mean[c] =
+ momentum * mean_value + (1 - momentum) * running_mean[c];
+ int count = num * spatial * group_size;
+ float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
+ running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
+ }
+ }
+}
+
+template <>
+__global__ void sync_bn_forward_output_cuda_kernel(
+ const phalf *input, const float *mean, const float *var,
+ float *running_mean, float *running_var, const float *weight,
+ const float *bias, float *norm, float *std, phalf *output, int num,
+ int channels, int spatial, float eps, float momentum, int group_size) {
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ float mean_value = mean[c];
+ float std_value = sqrt(var[c] + eps);
+ if (weight != nullptr) {
+ float weight_value = weight[c];
+ float bias_value = bias[c];
+ if (norm != nullptr) {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ norm[index] =
+ (static_cast<float>(input[index]) - mean_value) / std_value;
+ output[index] =
+ static_cast<phalf>(norm[index] * weight_value + bias_value);
+ }
+ } else {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ output[index] =
+ static_cast<phalf>((static_cast<float>(input[index]) - mean_value) /
+ std_value * weight_value +
+ bias_value);
+ }
+ }
+ } else {
+ if (norm != nullptr) {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ norm[index] =
+ (static_cast<float>(input[index]) - mean_value) / std_value;
+ output[index] = static_cast<phalf>(norm[index]);
+ }
+ } else {
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index =
+ (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ output[index] = static_cast<phalf>(
+ (static_cast<float>(input[index]) - mean_value) / std_value);
+ }
+ }
+ }
+ if (tid == 0) {
+ if (std != nullptr) std[c] = std_value;
+ if (running_mean != nullptr) {
+ running_mean[c] =
+ momentum * mean_value + (1 - momentum) * running_mean[c];
+ int count = num * spatial * group_size;
+ float var_unbias = count > 1 ? var[c] * count / (count - 1) : var[c];
+ running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c];
+ }
+ }
+}
+
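+// Parameter gradients per channel: grad_weight = sum(grad_output * norm) and
+// grad_bias = sum(grad_output), reduced across the block in shared memory.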
+template <typename T>
+__global__ void sync_bn_backward_param_cuda_kernel(const T *grad_output,
+ const float *norm,
+ float *grad_weight,
+ float *grad_bias, int num,
+ int channels, int spatial) {
+ __shared__ float buffer1[THREADS_PER_BLOCK];
+ __shared__ float buffer2[THREADS_PER_BLOCK];
+
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ buffer1[tid] = buffer2[tid] = 0;
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ buffer1[tid] += grad_output[index] * norm[index];
+ buffer2[tid] += grad_output[index];
+ }
+ __syncthreads();
+
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ buffer1[tid] += buffer1[tid + s];
+ buffer2[tid] += buffer2[tid + s];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ grad_weight[c] = buffer1[0];
+ grad_bias[c] = buffer2[0];
+ }
+}
+
+template <>
+__global__ void sync_bn_backward_param_cuda_kernel(const phalf *grad_output,
+ const float *norm,
+ float *grad_weight,
+ float *grad_bias, int num,
+ int channels, int spatial) {
+ __shared__ float buffer1[THREADS_PER_BLOCK];
+ __shared__ float buffer2[THREADS_PER_BLOCK];
+
+ int tid = threadIdx.x;
+ int c = blockIdx.x;
+ buffer1[tid] = buffer2[tid] = 0;
+ for (int i = tid; i < num * spatial; i += blockDim.x) {
+ int index = (i / spatial) * channels * spatial + c * spatial + i % spatial;
+ buffer1[tid] += static_cast<float>(grad_output[index]) * norm[index];
+ buffer2[tid] += static_cast<float>(grad_output[index]);
+ }
+ __syncthreads();
+
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ buffer1[tid] += buffer1[tid + s];
+ buffer2[tid] += buffer2[tid + s];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ grad_weight[c] = buffer1[0];
+ grad_bias[c] = buffer2[0];
+ }
+}
+
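+// Input gradient of the normalization:
+// grad_input = weight / std *
+//              (grad_output - (grad_weight * norm + grad_bias) / (num * spatial)).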
+template <typename T>
+__global__ void sync_bn_backward_data_cuda_kernel(
+ int output_size, const T *grad_output, const float *weight,
+ const float *grad_weight, const float *grad_bias, const float *norm,
+ const float *std, T *grad_input, int num, int channels, int spatial) {
+ int factor = num * spatial;
+ CUDA_1D_KERNEL_LOOP(index, output_size) {
+ int c = (index / spatial) % channels;
+ grad_input[index] =
+ weight[c] *
+ (grad_output[index] -
+ (grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
+ std[c];
+ }
+}
+
+template <>
+__global__ void sync_bn_backward_data_cuda_kernel(
+ int output_size, const phalf *grad_output, const float *weight,
+ const float *grad_weight, const float *grad_bias, const float *norm,
+ const float *std, phalf *grad_input, int num, int channels, int spatial) {
+ int factor = num * spatial;
+ CUDA_1D_KERNEL_LOOP(index, output_size) {
+ int c = (index / spatial) % channels;
+ grad_input[index] = static_cast<phalf>(
+ weight[c] *
+ (static_cast<float>(grad_output[index]) -
+ (grad_weight[c] * norm[index] + grad_bias[c]) / factor) /
+ std[c]);
+ }
+}
+
+#endif // SYNCBN_CUDA_KERNEL_CUH
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cpp_helper.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cpp_helper.hpp
new file mode 100755
index 0000000000000000000000000000000000000000..f68e8740561ef833c09e1ba9f999922f5d04bce5
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cpp_helper.hpp
@@ -0,0 +1,27 @@
+#ifndef PYTORCH_CPP_HELPER
+#define PYTORCH_CPP_HELPER
+#include <torch/extension.h>
+
+#include <vector>
+
+using namespace at;
+
+#define CHECK_CUDA(x) \
+ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_MLU(x) \
+ TORCH_CHECK(x.device().type() == at::kMLU, #x " must be a MLU tensor")
+#define CHECK_CPU(x) \
+ TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
+#define CHECK_CONTIGUOUS(x) \
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA_INPUT(x) \
+ CHECK_CUDA(x); \
+ CHECK_CONTIGUOUS(x)
+#define CHECK_MLU_INPUT(x) \
+ CHECK_MLU(x); \
+ CHECK_CONTIGUOUS(x)
+#define CHECK_CPU_INPUT(x) \
+ CHECK_CPU(x); \
+ CHECK_CONTIGUOUS(x)
+
+#endif // PYTORCH_CPP_HELPER
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cuda_helper.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cuda_helper.hpp
new file mode 100755
index 0000000000000000000000000000000000000000..9869b535f8a1de758b0c35612dbd4ac2a1701ad9
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cuda_helper.hpp
@@ -0,0 +1,19 @@
+#ifndef PYTORCH_CUDA_HELPER
+#define PYTORCH_CUDA_HELPER
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <THC/THCAtomics.cuh>
+
+#include "common_cuda_helper.hpp"
+
+using at::Half;
+using at::Tensor;
+using phalf = at::Half;
+
+#define __PHALF(x) (x)
+
+#endif // PYTORCH_CUDA_HELPER
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_device_registry.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_device_registry.hpp
new file mode 100755
index 0000000000000000000000000000000000000000..2a32b7270c3521f960394af7d18cbbd03ba50df1
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_device_registry.hpp
@@ -0,0 +1,141 @@
+#ifndef PYTORCH_DEVICE_REGISTRY_H
+#define PYTORCH_DEVICE_REGISTRY_H
+
+// Using <torch/extension.h> is recommended in the official documentation in
+// https://pytorch.org/tutorials/advanced/cpp_extension.html#writing-the-c-op.
+// However, we use <torch/types.h> for compatibility with CUDA 9.0.
+// Read https://github.com/pytorch/extension-cpp/issues/35 for more details.
+#include <torch/types.h>
+
+#include <cassert>
+#include <functional>
+#include <map>