From fb16f2c2f16e28e830ca78161120debb9dd7392a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=B0=B8=E4=B9=90?= Date: Thu, 23 Feb 2023 03:19:01 +0000 Subject: [PATCH 1/3] =?UTF-8?q?Signed-off-by:=20=E5=90=B4=E6=B0=B8?= =?UTF-8?q?=E4=B9=90=20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add dbnet pytorch model link #I6BFUO add dbnet pytorch model --- cv/ocr/dbnet/pytorch/.gitignore | 49 + cv/ocr/dbnet/pytorch/CITATION.cff | 9 + cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO | 380 ++++++ cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt | 62 + .../pytorch/EGG-INFO/dependency_links.txt | 1 + cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe | 1 + cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt | 59 + cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt | 1 + cv/ocr/dbnet/pytorch/LICENSE | 203 +++ cv/ocr/dbnet/pytorch/README.md | 65 + .../pytorch/configs/_base_/default_runtime.py | 17 + .../configs/_base_/det_datasets/icdar2015.py | 18 + .../det_models/dbnet_mobilenetv3_fpnc.py | 32 + .../_base_/det_models/dbnet_r18_fpnc.py | 31 + .../_base_/det_models/dbnet_r50dcnv2_fpnc.py | 23 + .../_base_/det_pipelines/dbnet_pipeline.py | 88 ++ .../_base_/schedules/schedule_adam_1200e.py | 8 + .../_base_/schedules/schedule_sgd_1200e.py | 8 + .../pytorch/configs/textdet/dbnet/README.md | 33 + .../dbnet_mobilenetv3_fpnc_1200e_icdar2015.py | 33 + .../dbnet/dbnet_r18_fpnc_1200e_icdar2015.py | 33 + .../dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | 37 + .../configs/textdet/dbnet/metafile.yml | 40 + cv/ocr/dbnet/pytorch/dbnet/__init__.py | 70 + cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py | 12 + cv/ocr/dbnet/pytorch/dbnet/apis/inference.py | 238 ++++ cv/ocr/dbnet/pytorch/dbnet/apis/test.py | 157 +++ cv/ocr/dbnet/pytorch/dbnet/apis/train.py | 185 +++ cv/ocr/dbnet/pytorch/dbnet/apis/utils.py | 123 ++ cv/ocr/dbnet/pytorch/dbnet/core/__init__.py | 16 + .../pytorch/dbnet/core/evaluation/__init__.py | 13 + .../pytorch/dbnet/core/evaluation/hmean.py | 172 +++ .../dbnet/core/evaluation/hmean_ic13.py | 217 +++ .../dbnet/core/evaluation/hmean_iou.py | 117 ++ .../pytorch/dbnet/core/evaluation/utils.py | 547 ++++++++ cv/ocr/dbnet/pytorch/dbnet/core/mask.py | 102 ++ cv/ocr/dbnet/pytorch/dbnet/core/visualize.py | 888 +++++++++++++ .../dbnet/pytorch/dbnet/datasets/__init__.py | 25 + .../pytorch/dbnet/datasets/base_dataset.py | 169 +++ .../dbnet/pytorch/dbnet/datasets/builder.py | 15 + .../pytorch/dbnet/datasets/icdar_dataset.py | 191 +++ .../dbnet/datasets/pipelines/__init__.py | 37 + .../dbnet/datasets/pipelines/box_utils.py | 53 + .../pytorch/dbnet/datasets/pipelines/crop.py | 125 ++ .../pipelines/custom_format_bundle.py | 66 + .../datasets/pipelines/dbnet_transforms.py | 325 +++++ .../dbnet/datasets/pipelines/loading.py | 189 +++ .../pipelines/textdet_targets/__init__.py | 13 + .../textdet_targets/base_textdet_targets.py | 168 +++ .../textdet_targets/dbnet_targets.py | 250 ++++ .../datasets/pipelines/transform_wrappers.py | 128 ++ .../dbnet/datasets/pipelines/transforms.py | 1020 ++++++++++++++ .../dbnet/datasets/text_det_dataset.py | 131 ++ .../dbnet/datasets/uniform_concat_dataset.py | 151 +++ .../pytorch/dbnet/datasets/utils/__init__.py | 8 + .../pytorch/dbnet/datasets/utils/backend.py | 198 +++ .../pytorch/dbnet/datasets/utils/loader.py | 112 ++ .../pytorch/dbnet/datasets/utils/parser.py | 82 ++ cv/ocr/dbnet/pytorch/dbnet/models/__init__.py | 19 + cv/ocr/dbnet/pytorch/dbnet/models/builder.py | 152 +++ .../pytorch/dbnet/models/common/__init__.py | 6 + .../dbnet/models/common/backbones/__init__.py | 4 + 
.../dbnet/models/common/backbones/unet.py | 516 ++++++++ .../dbnet/models/common/detectors/__init__.py | 4 + .../models/common/detectors/single_stage.py | 39 + .../dbnet/models/common/losses/__init__.py | 6 + .../dbnet/models/common/losses/dice_loss.py | 31 + .../dbnet/models/common/losses/focal_loss.py | 31 + .../pytorch/dbnet/models/textdet/__init__.py | 11 + .../models/textdet/dense_heads/__init__.py | 18 + .../models/textdet/dense_heads/db_head.py | 96 ++ .../models/textdet/dense_heads/head_mixin.py | 91 ++ .../models/textdet/detectors/__init__.py | 19 + .../dbnet/models/textdet/detectors/dbnet.py | 27 + .../detectors/single_stage_text_detector.py | 61 + .../textdet/detectors/text_detector_mixin.py | 81 ++ .../dbnet/models/textdet/losses/__init__.py | 13 + .../dbnet/models/textdet/losses/db_loss.py | 166 +++ .../dbnet/models/textdet/necks/__init__.py | 8 + .../dbnet/models/textdet/necks/fpn_cat.py | 271 ++++ .../models/textdet/postprocess/__init__.py | 18 + .../textdet/postprocess/base_postprocessor.py | 15 + .../textdet/postprocess/db_postprocessor.py | 94 ++ .../dbnet/models/textdet/postprocess/utils.py | 482 +++++++ cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py | 25 + cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py | 199 +++ .../pytorch/dbnet/utils/check_argument.py | 72 + .../dbnet/pytorch/dbnet/utils/collect_env.py | 17 + cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py | 38 + cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py | 128 ++ cv/ocr/dbnet/pytorch/dbnet/utils/logger.py | 25 + cv/ocr/dbnet/pytorch/dbnet/utils/model.py | 51 + cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py | 878 ++++++++++++ cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py | 47 + .../dbnet/pytorch/dbnet/utils/string_util.py | 36 + cv/ocr/dbnet/pytorch/dbnet/version.py | 4 + cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py | 16 + cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py | 55 + .../pytorch/dbnet_cv/cnn/bricks/__init__.py | 18 + .../pytorch/dbnet_cv/cnn/bricks/activation.py | 93 ++ .../dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py | 44 + .../cnn/bricks/conv2d_adaptive_padding.py | 62 + .../dbnet_cv/cnn/bricks/conv_module.py | 206 +++ .../dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py | 65 + .../pytorch/dbnet_cv/cnn/bricks/hsigmoid.py | 48 + .../pytorch/dbnet_cv/cnn/bricks/hswish.py | 35 + .../dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py | 144 ++ .../pytorch/dbnet_cv/cnn/bricks/padding.py | 36 + .../pytorch/dbnet_cv/cnn/bricks/plugin.py | 89 ++ .../pytorch/dbnet_cv/cnn/bricks/registry.py | 16 + .../pytorch/dbnet_cv/cnn/bricks/upsample.py | 84 ++ .../pytorch/dbnet_cv/cnn/bricks/wrappers.py | 180 +++ cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py | 30 + cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py | 322 +++++ .../pytorch/dbnet_cv/cnn/utils/__init__.py | 19 + .../dbnet_cv/cnn/utils/flops_counter.py | 603 +++++++++ .../dbnet_cv/cnn/utils/fuse_conv_bn.py | 59 + .../pytorch/dbnet_cv/cnn/utils/sync_bn.py | 61 + .../pytorch/dbnet_cv/cnn/utils/weight_init.py | 708 ++++++++++ .../dbnet/pytorch/dbnet_cv/engine/__init__.py | 8 + cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py | 213 +++ .../dbnet/pytorch/dbnet_cv/fileio/__init__.py | 11 + .../pytorch/dbnet_cv/fileio/file_client.py | 1173 +++++++++++++++++ .../dbnet_cv/fileio/handlers/__init__.py | 7 + .../pytorch/dbnet_cv/fileio/handlers/base.py | 30 + .../dbnet_cv/fileio/handlers/json_handler.py | 36 + .../fileio/handlers/pickle_handler.py | 26 + .../dbnet_cv/fileio/handlers/yaml_handler.py | 25 + cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py | 163 +++ cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py | 99 ++ 
.../dbnet/pytorch/dbnet_cv/image/__init__.py | 33 + .../pytorch/dbnet_cv/image/colorspace.py | 309 +++++ .../dbnet/pytorch/dbnet_cv/image/geometric.py | 741 +++++++++++ cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py | 314 +++++ cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py | 53 + .../pytorch/dbnet_cv/image/photometric.py | 471 +++++++ .../dbnet_cv/model_zoo/deprecated.json | 6 + .../pytorch/dbnet_cv/model_zoo/mmcls.json | 59 + .../dbnet_cv/model_zoo/open_mmlab.json | 50 + .../dbnet_cv/model_zoo/torchvision_0.12.json | 57 + cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py | 104 ++ .../dbnet/pytorch/dbnet_cv/ops/csrc/README.md | 170 +++ .../csrc/common/cuda/common_cuda_helper.hpp | 120 ++ .../common/cuda/deform_conv_cuda_kernel.cuh | 367 ++++++ .../cuda/deform_roi_pool_cuda_kernel.cuh | 186 +++ .../modulated_deform_conv_cuda_kernel.cuh | 399 ++++++ .../cuda/ms_deform_attn_cuda_kernel.cuh | 801 +++++++++++ .../common/cuda/roi_align_cuda_kernel.cuh | 212 +++ .../csrc/common/cuda/roi_pool_cuda_kernel.cuh | 93 ++ .../csrc/common/cuda/sync_bn_cuda_kernel.cuh | 331 +++++ .../ops/csrc/common/pytorch_cpp_helper.hpp | 27 + .../ops/csrc/common/pytorch_cuda_helper.hpp | 19 + .../csrc/common/pytorch_device_registry.hpp | 141 ++ .../ops/csrc/pytorch/cuda/cudabind.cpp | 372 ++++++ .../ops/csrc/pytorch/cuda/deform_conv_cuda.cu | 105 ++ .../csrc/pytorch/cuda/deform_roi_pool_cuda.cu | 55 + .../cuda/modulated_deform_conv_cuda.cu | 96 ++ .../csrc/pytorch/cuda/ms_deform_attn_cuda.cu | 351 +++++ .../ops/csrc/pytorch/cuda/roi_align_cuda.cu | 58 + .../ops/csrc/pytorch/cuda/roi_pool_cuda.cu | 50 + .../ops/csrc/pytorch/cuda/sync_bn_cuda.cu | 110 ++ .../dbnet_cv/ops/csrc/pytorch/deform_conv.cpp | 517 ++++++++ .../ops/csrc/pytorch/deform_roi_pool.cpp | 42 + .../dbnet_cv/ops/csrc/pytorch/info.cpp | 56 + .../csrc/pytorch/modulated_deform_conv.cpp | 237 ++++ .../ops/csrc/pytorch/ms_deform_attn.cpp | 60 + .../dbnet_cv/ops/csrc/pytorch/pybind.cpp | 218 +++ .../dbnet_cv/ops/csrc/pytorch/roi_align.cpp | 41 + .../dbnet_cv/ops/csrc/pytorch/roi_pool.cpp | 31 + .../dbnet_cv/ops/csrc/pytorch/sync_bn.cpp | 69 + .../dbnet/pytorch/dbnet_cv/ops/deform_conv.py | 408 ++++++ .../pytorch/dbnet_cv/ops/deform_roi_pool.py | 209 +++ .../dbnet_cv/ops/deprecated_wrappers.py | 46 + cv/ocr/dbnet/pytorch/dbnet_cv/ops/info.py | 36 + .../dbnet_cv/ops/modulated_deform_conv.py | 283 ++++ .../dbnet_cv/ops/multi_scale_deform_attn.py | 357 +++++ .../dbnet/pytorch/dbnet_cv/ops/roi_align.py | 224 ++++ cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_pool.py | 86 ++ cv/ocr/dbnet/pytorch/dbnet_cv/ops/sync_bn.py | 279 ++++ .../pytorch/dbnet_cv/parallel/__init__.py | 13 + .../pytorch/dbnet_cv/parallel/_functions.py | 76 ++ .../pytorch/dbnet_cv/parallel/collate.py | 84 ++ .../dbnet_cv/parallel/data_container.py | 89 ++ .../dbnet_cv/parallel/data_parallel.py | 97 ++ .../pytorch/dbnet_cv/parallel/distributed.py | 138 ++ .../parallel/distributed_deprecated.py | 70 + .../pytorch/dbnet_cv/parallel/registry.py | 8 + .../dbnet_cv/parallel/scatter_gather.py | 59 + .../dbnet/pytorch/dbnet_cv/parallel/utils.py | 30 + .../dbnet/pytorch/dbnet_cv/runner/__init__.py | 74 ++ .../pytorch/dbnet_cv/runner/base_module.py | 213 +++ .../pytorch/dbnet_cv/runner/base_runner.py | 544 ++++++++ .../dbnet/pytorch/dbnet_cv/runner/builder.py | 25 + .../pytorch/dbnet_cv/runner/checkpoint.py | 800 +++++++++++ .../dbnet_cv/runner/default_constructor.py | 47 + .../pytorch/dbnet_cv/runner/dist_utils.py | 209 +++ .../dbnet_cv/runner/epoch_based_runner.py | 191 +++ .../pytorch/dbnet_cv/runner/fp16_utils.py | 435 
++++++ .../pytorch/dbnet_cv/runner/hooks/__init__.py | 48 + .../dbnet_cv/runner/hooks/checkpoint.py | 167 +++ .../pytorch/dbnet_cv/runner/hooks/closure.py | 11 + .../pytorch/dbnet_cv/runner/hooks/ema.py | 89 ++ .../dbnet_cv/runner/hooks/evaluation.py | 511 +++++++ .../pytorch/dbnet_cv/runner/hooks/hook.py | 92 ++ .../dbnet_cv/runner/hooks/iter_timer.py | 18 + .../dbnet_cv/runner/hooks/logger/__init__.py | 18 + .../dbnet_cv/runner/hooks/logger/base.py | 172 +++ .../dbnet_cv/runner/hooks/logger/clearml.py | 63 + .../dbnet_cv/runner/hooks/logger/dvclive.py | 69 + .../dbnet_cv/runner/hooks/logger/mlflow.py | 81 ++ .../dbnet_cv/runner/hooks/logger/neptune.py | 89 ++ .../dbnet_cv/runner/hooks/logger/pavi.py | 132 ++ .../dbnet_cv/runner/hooks/logger/segmind.py | 48 + .../runner/hooks/logger/tensorboard.py | 69 + .../dbnet_cv/runner/hooks/logger/text.py | 256 ++++ .../dbnet_cv/runner/hooks/logger/wandb.py | 107 ++ .../dbnet_cv/runner/hooks/lr_updater.py | 751 +++++++++++ .../pytorch/dbnet_cv/runner/hooks/memory.py | 25 + .../dbnet_cv/runner/hooks/momentum_updater.py | 566 ++++++++ .../dbnet_cv/runner/hooks/optimizer.py | 554 ++++++++ .../pytorch/dbnet_cv/runner/hooks/profiler.py | 180 +++ .../dbnet_cv/runner/hooks/sampler_seed.py | 20 + .../dbnet_cv/runner/hooks/sync_buffer.py | 22 + .../dbnet_cv/runner/iter_based_runner.py | 277 ++++ .../pytorch/dbnet_cv/runner/log_buffer.py | 41 + .../dbnet_cv/runner/optimizer/__init__.py | 9 + .../dbnet_cv/runner/optimizer/builder.py | 45 + .../runner/optimizer/default_constructor.py | 258 ++++ .../dbnet/pytorch/dbnet_cv/runner/priority.py | 61 + cv/ocr/dbnet/pytorch/dbnet_cv/runner/utils.py | 99 ++ .../dbnet/pytorch/dbnet_cv/utils/__init__.py | 78 ++ cv/ocr/dbnet/pytorch/dbnet_cv/utils/config.py | 741 +++++++++++ .../pytorch/dbnet_cv/utils/device_type.py | 24 + cv/ocr/dbnet/pytorch/dbnet_cv/utils/env.py | 120 ++ .../pytorch/dbnet_cv/utils/ext_loader.py | 72 + cv/ocr/dbnet/pytorch/dbnet_cv/utils/hub.py | 131 ++ .../dbnet/pytorch/dbnet_cv/utils/logging.py | 111 ++ cv/ocr/dbnet/pytorch/dbnet_cv/utils/misc.py | 377 ++++++ .../pytorch/dbnet_cv/utils/parrots_jit.py | 41 + .../pytorch/dbnet_cv/utils/parrots_wrapper.py | 114 ++ cv/ocr/dbnet/pytorch/dbnet_cv/utils/path.py | 101 ++ .../pytorch/dbnet_cv/utils/progressbar.py | 208 +++ .../dbnet/pytorch/dbnet_cv/utils/registry.py | 340 +++++ cv/ocr/dbnet/pytorch/dbnet_cv/utils/seed.py | 23 + .../dbnet/pytorch/dbnet_cv/utils/testing.py | 141 ++ cv/ocr/dbnet/pytorch/dbnet_cv/utils/timer.py | 118 ++ cv/ocr/dbnet/pytorch/dbnet_cv/utils/trace.py | 24 + .../pytorch/dbnet_cv/utils/version_utils.py | 90 ++ cv/ocr/dbnet/pytorch/dbnet_cv/version.py | 35 + cv/ocr/dbnet/pytorch/dist_train.sh | 19 + cv/ocr/dbnet/pytorch/requirements.txt | 12 + cv/ocr/dbnet/pytorch/setup.py | 327 +++++ cv/ocr/dbnet/pytorch/tools/analyze_logs.py | 190 +++ .../pytorch/tools/benchmark_processing.py | 60 + .../dbnet/pytorch/tools/misc/print_config.py | 55 + cv/ocr/dbnet/pytorch/tools/publish_model.py | 39 + cv/ocr/dbnet/pytorch/train.py | 232 ++++ 257 files changed, 37559 insertions(+) create mode 100755 cv/ocr/dbnet/pytorch/.gitignore create mode 100755 cv/ocr/dbnet/pytorch/CITATION.cff create mode 100755 cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO create mode 100755 cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt create mode 100755 cv/ocr/dbnet/pytorch/EGG-INFO/dependency_links.txt create mode 100755 cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe create mode 100755 cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt create mode 100755 cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt 
create mode 100755 cv/ocr/dbnet/pytorch/LICENSE create mode 100755 cv/ocr/dbnet/pytorch/README.md create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/default_runtime.py create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/det_datasets/icdar2015.py create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_mobilenetv3_fpnc.py create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r18_fpnc.py create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/det_pipelines/dbnet_pipeline.py create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_adam_1200e.py create mode 100755 cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_sgd_1200e.py create mode 100755 cv/ocr/dbnet/pytorch/configs/textdet/dbnet/README.md create mode 100755 cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py create mode 100755 cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py create mode 100755 cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py create mode 100755 cv/ocr/dbnet/pytorch/configs/textdet/dbnet/metafile.yml create mode 100755 cv/ocr/dbnet/pytorch/dbnet/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/apis/inference.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/apis/test.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/apis/train.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/apis/utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/evaluation/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_ic13.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_iou.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/evaluation/utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/mask.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/core/visualize.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/base_dataset.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/builder.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/icdar_dataset.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/box_utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/crop.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/custom_format_bundle.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/dbnet_transforms.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/loading.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/base_textdet_targets.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/dbnet_targets.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transform_wrappers.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transforms.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/text_det_dataset.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/uniform_concat_dataset.py create mode 100755 
cv/ocr/dbnet/pytorch/dbnet/datasets/utils/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/utils/backend.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/utils/loader.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/datasets/utils/parser.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/builder.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/unet.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/single_stage.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/losses/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/losses/dice_loss.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/common/losses/focal_loss.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/db_head.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/head_mixin.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/dbnet.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/single_stage_text_detector.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/text_detector_mixin.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/db_loss.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/fpn_cat.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/base_postprocessor.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/db_postprocessor.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/check_argument.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/collect_env.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/logger.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/model.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/utils/string_util.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet/version.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/activation.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py create mode 100755 
cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv2d_adaptive_padding.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv_module.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hsigmoid.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hswish.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/padding.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/plugin.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/registry.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/upsample.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/wrappers.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/flops_counter.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/fuse_conv_bn.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/sync_bn.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/weight_init.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/engine/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/file_client.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/base.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/json_handler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/pickle_handler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/yaml_handler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/image/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/image/colorspace.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/image/geometric.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/image/photometric.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/deprecated.json create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/mmcls.json create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/open_mmlab.json create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/torchvision_0.12.json create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/README.md create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/common_cuda_helper.hpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_roi_pool_cuda_kernel.cuh create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/modulated_deform_conv_cuda_kernel.cuh create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_align_cuda_kernel.cuh create mode 100755 
cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/sync_bn_cuda_kernel.cuh create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cpp_helper.hpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cuda_helper.hpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_device_registry.hpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/cudabind.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_roi_pool_cuda.cu create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_align_cuda.cu create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_pool_cuda.cu create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/sync_bn_cuda.cu create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_conv.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_roi_pool.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/info.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/modulated_deform_conv.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/ms_deform_attn.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/pybind.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_align.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_pool.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/sync_bn.cpp create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_conv.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_roi_pool.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/deprecated_wrappers.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/info.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/modulated_deform_conv.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/multi_scale_deform_attn.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_align.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_pool.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/ops/sync_bn.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/_functions.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/collate.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_container.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_parallel.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed_deprecated.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/registry.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/scatter_gather.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/parallel/utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_module.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_runner.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/builder.py create mode 
100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/checkpoint.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/default_constructor.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/dist_utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/epoch_based_runner.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/fp16_utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/checkpoint.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/closure.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/ema.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/evaluation.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/hook.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/iter_timer.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/base.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/clearml.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/dvclive.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/mlflow.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/neptune.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/pavi.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/segmind.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/tensorboard.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/text.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/wandb.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/lr_updater.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/memory.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/momentum_updater.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/optimizer.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/profiler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sampler_seed.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sync_buffer.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/iter_based_runner.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/log_buffer.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/builder.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/default_constructor.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/priority.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/runner/utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/config.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/device_type.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/env.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/ext_loader.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/hub.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/logging.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/misc.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_jit.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_wrapper.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/path.py create mode 100755 
cv/ocr/dbnet/pytorch/dbnet_cv/utils/progressbar.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/registry.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/seed.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/testing.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/timer.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/trace.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/utils/version_utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_cv/version.py create mode 100755 cv/ocr/dbnet/pytorch/dist_train.sh create mode 100755 cv/ocr/dbnet/pytorch/requirements.txt create mode 100755 cv/ocr/dbnet/pytorch/setup.py create mode 100755 cv/ocr/dbnet/pytorch/tools/analyze_logs.py create mode 100755 cv/ocr/dbnet/pytorch/tools/benchmark_processing.py create mode 100755 cv/ocr/dbnet/pytorch/tools/misc/print_config.py create mode 100755 cv/ocr/dbnet/pytorch/tools/publish_model.py create mode 100755 cv/ocr/dbnet/pytorch/train.py diff --git a/cv/ocr/dbnet/pytorch/.gitignore b/cv/ocr/dbnet/pytorch/.gitignore new file mode 100755 index 00000000..628f7f26 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/.gitignore @@ -0,0 +1,49 @@ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# PyTorch checkpoint +*.pth + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +OCR_Detect/ +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache + diff --git a/cv/ocr/dbnet/pytorch/CITATION.cff b/cv/ocr/dbnet/pytorch/CITATION.cff new file mode 100755 index 00000000..7d1d93a7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/CITATION.cff @@ -0,0 +1,9 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +title: "OpenMMLab Text Detection, Recognition and Understanding Toolbox" +authors: + - name: "MMOCR Contributors" +version: 0.3.0 +date-released: 2020-08-15 +repository-code: "https://github.com/open-mmlab/mmocr" +license: Apache-2.0 diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO b/cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO new file mode 100755 index 00000000..9280d7ec --- /dev/null +++ b/cv/ocr/dbnet/pytorch/EGG-INFO/PKG-INFO @@ -0,0 +1,380 @@ +Metadata-Version: 2.1 +Name: dbnet_det +Version: 2.25.0 +Summary: OpenMMLab Detection Toolbox and Benchmark +Home-page: https://github.com/open-mmlab/dbnet_detection +Author: dbnet_detection Contributors +Author-email: openmmlab@gmail.com +License: Apache License 2.0 +Keywords: computer vision,object detection +Classifier: Development Status :: 5 - Production/Stable +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Description-Content-Type: text/markdown +Provides-Extra: all +Provides-Extra: tests +Provides-Extra: build +Provides-Extra: optional + +
+<!-- Centered header: OpenMMLab logo with "OpenMMLab website" and "OpenMMLab platform" links -->
+ +[![PyPI](https://img.shields.io/pypi/v/dbnet_det)](https://pypi.org/project/dbnet_det) +[![docs](https://img.shields.io/badge/docs-latest-blue)](https://dbnet_detection.readthedocs.io/en/latest/) +[![badge](https://github.com/open-mmlab/dbnet_detection/workflows/build/badge.svg)](https://github.com/open-mmlab/dbnet_detection/actions) +[![codecov](https://codecov.io/gh/open-mmlab/dbnet_detection/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/dbnet_detection) +[![license](https://img.shields.io/github/license/open-mmlab/dbnet_detection.svg)](https://github.com/open-mmlab/dbnet_detection/blob/master/LICENSE) +[![open issues](https://isitmaintained.com/badge/open/open-mmlab/dbnet_detection.svg)](https://github.com/open-mmlab/dbnet_detection/issues) +[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/dbnet_detection.svg)](https://github.com/open-mmlab/dbnet_detection/issues) + +[📘Documentation](https://dbnet_detection.readthedocs.io/en/stable/) | +[🛠️Installation](https://dbnet_detection.readthedocs.io/en/stable/get_started.html) | +[👀Model Zoo](https://dbnet_detection.readthedocs.io/en/stable/model_zoo.html) | +[🆕Update News](https://dbnet_detection.readthedocs.io/en/stable/changelog.html) | +[🚀Ongoing Projects](https://github.com/open-mmlab/dbnet_detection/projects) | +[🤔Reporting Issues](https://github.com/open-mmlab/dbnet_detection/issues/new/choose) + +
+ +
+ +English | [简体中文](README_zh-CN.md) + +
+ +## Introduction + +dbnet_detection is an open source object detection toolbox based on PyTorch. It is +a part of the [OpenMMLab](https://openmmlab.com/) project. + +The master branch works with **PyTorch 1.5+**. + + + +
+Major features + +- **Modular Design** + + We decompose the detection framework into different components and one can easily construct a customized object detection framework by combining different modules. + +- **Support of multiple frameworks out of box** + + The toolbox directly supports popular and contemporary detection frameworks, *e.g.* Faster RCNN, Mask RCNN, RetinaNet, etc. + +- **High efficiency** + + All basic bbox and mask operations run on GPUs. The training speed is faster than or comparable to other codebases, including [Detectron2](https://github.com/facebookresearch/detectron2), [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) and [SimpleDet](https://github.com/TuSimple/simpledet). + +- **State of the art** + + The toolbox stems from the codebase developed by the *dbnet_det* team, who won [COCO Detection Challenge](http://cocodataset.org/#detection-leaderboard) in 2018, and we keep pushing it forward. + +
+ +Apart from dbnet_detection, we also released a library [dbnet_cv](https://github.com/open-mmlab/dbnet_cv) for computer vision research, which is heavily depended on by this toolbox. + +## What's New + +**2.25.0** was released in 1/6/2022: + +- Support dedicated `dbnet_detWandbHook` hook +- Support [ConvNeXt](configs/convnext), [DDOD](configs/ddod), [SOLOv2](configs/solov2) +- Support [Mask2Former](configs/mask2former) for instance segmentation +- Rename [config files of Mask2Former](configs/mask2former) + +Please refer to [changelog.md](docs/en/changelog.md) for details and release history. + +For compatibility changes between different versions of dbnet_detection, please refer to [compatibility.md](docs/en/compatibility.md). + +## Installation + +Please refer to [Installation](docs/en/get_started.md/#Installation) for installation instructions. + +## Getting Started + +Please see [get_started.md](docs/en/get_started.md) for the basic usage of dbnet_detection. We provide [colab tutorial](demo/dbnet_det_Tutorial.ipynb) and [instance segmentation colab tutorial](demo/dbnet_det_InstanceSeg_Tutorial.ipynb), and other tutorials for: + +- [with existing dataset](docs/en/1_exist_data_model.md) +- [with new dataset](docs/en/2_new_data_model.md) +- [with existing dataset_new_model](docs/en/3_exist_data_new_model.md) +- [learn about configs](docs/en/tutorials/config.md) +- [customize_datasets](docs/en/tutorials/customize_dataset.md) +- [customize data pipelines](docs/en/tutorials/data_pipeline.md) +- [customize_models](docs/en/tutorials/customize_models.md) +- [customize runtime settings](docs/en/tutorials/customize_runtime.md) +- [customize_losses](docs/en/tutorials/customize_losses.md) +- [finetuning models](docs/en/tutorials/finetune.md) +- [export a model to ONNX](docs/en/tutorials/pytorch2onnx.md) +- [export ONNX to TRT](docs/en/tutorials/onnx2tensorrt.md) +- [weight initialization](docs/en/tutorials/init_cfg.md) +- [how to xxx](docs/en/tutorials/how_to.md) + +## Overview of Benchmark and Model Zoo + +Results and models are available in the [model zoo](docs/en/model_zoo.md). + +
+<!-- Architectures overview table: Object Detection | Instance Segmentation | Panoptic Segmentation | Other (Contrastive Learning, Distillation) -->
+<!-- Components overview table: Backbones | Necks | Loss | Common -->
+
+Some other methods are also supported in [projects using dbnet_detection](./docs/en/projects.md).
+
+## FAQ
+
+Please refer to [FAQ](docs/en/faq.md) for frequently asked questions.
+
+## Contributing
+
+We appreciate all contributions to improve dbnet_detection. Ongoing projects can be found in our [GitHub Projects](https://github.com/open-mmlab/dbnet_detection/projects). We welcome community users to participate in these projects. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
+
+## Acknowledgement
+
+dbnet_detection is an open source project that is contributed by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedback.
+We wish that the toolbox and benchmark could serve the growing research community by providing a flexible toolkit to reimplement existing methods and develop their own new detectors.
+
+## Citation
+
+If you use this toolbox or benchmark in your research, please cite this project.
+
+```
+@article{dbnet_detection,
+  title   = {{dbnet_detection}: Open MMLab Detection Toolbox and Benchmark},
+  author  = {Chen, Kai and Wang, Jiaqi and Pang, Jiangmiao and Cao, Yuhang and
+             Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and
+             Liu, Ziwei and Xu, Jiarui and Zhang, Zheng and Cheng, Dazhi and
+             Zhu, Chenchen and Cheng, Tianheng and Zhao, Qijie and Li, Buyu and
+             Lu, Xin and Zhu, Rui and Wu, Yue and Dai, Jifeng and Wang, Jingdong
+             and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua},
+  journal = {arXiv preprint arXiv:1906.07155},
+  year    = {2019}
+}
+```
+
+## License
+
+This project is released under the [Apache 2.0 license](LICENSE).
+
+## Projects in OpenMMLab
+
+- [DBNET_CV](https://github.com/open-mmlab/dbnet_cv): OpenMMLab foundational library for computer vision.
+- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
+- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
+- [dbnet_detection](https://github.com/open-mmlab/dbnet_detection): OpenMMLab detection toolbox and benchmark.
+- [dbnet_detection3D](https://github.com/open-mmlab/dbnet_detection3d): OpenMMLab's next-generation platform for general 3D object detection.
+- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
+- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
+- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
+- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
+- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
+- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
+- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
+- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
+- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
+- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark. +- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox. +- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox. +- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework. diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt new file mode 100755 index 00000000..4f42864d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/EGG-INFO/SOURCES.txt @@ -0,0 +1,62 @@ +README.md +setup.py +dbnet_det/__init__.py +dbnet_det/version.py +dbnet_det.egg-info/PKG-INFO +dbnet_det.egg-info/SOURCES.txt +dbnet_det.egg-info/dependency_links.txt +dbnet_det.egg-info/not-zip-safe +dbnet_det.egg-info/requires.txt +dbnet_det.egg-info/top_level.txt +dbnet_det/core/__init__.py +dbnet_det/core/bbox/__init__.py +dbnet_det/core/bbox/transforms.py +dbnet_det/core/evaluation/__init__.py +dbnet_det/core/evaluation/bbox_overlaps.py +dbnet_det/core/evaluation/class_names.py +dbnet_det/core/evaluation/eval_hooks.py +dbnet_det/core/evaluation/mean_ap.py +dbnet_det/core/evaluation/panoptic_utils.py +dbnet_det/core/evaluation/recall.py +dbnet_det/core/mask/__init__.py +dbnet_det/core/mask/structures.py +dbnet_det/core/mask/utils.py +dbnet_det/core/utils/__init__.py +dbnet_det/core/utils/dist_utils.py +dbnet_det/core/utils/misc.py +dbnet_det/core/visualization/__init__.py +dbnet_det/core/visualization/image.py +dbnet_det/core/visualization/palette.py +dbnet_det/datasets/__init__.py +dbnet_det/datasets/builder.py +dbnet_det/datasets/coco.py +dbnet_det/datasets/custom.py +dbnet_det/datasets/dataset_wrappers.py +dbnet_det/datasets/utils.py +dbnet_det/datasets/api_wrappers/__init__.py +dbnet_det/datasets/api_wrappers/coco_api.py +dbnet_det/datasets/api_wrappers/panoptic_evaluation.py +dbnet_det/datasets/pipelines/__init__.py +dbnet_det/datasets/pipelines/compose.py +dbnet_det/datasets/pipelines/formatting.py +dbnet_det/datasets/pipelines/loading.py +dbnet_det/datasets/pipelines/test_time_aug.py +dbnet_det/datasets/pipelines/transforms.py +dbnet_det/datasets/samplers/__init__.py +dbnet_det/datasets/samplers/class_aware_sampler.py +dbnet_det/datasets/samplers/distributed_sampler.py +dbnet_det/datasets/samplers/group_sampler.py +dbnet_det/datasets/samplers/infinite_sampler.py +dbnet_det/models/__init__.py +dbnet_det/models/builder.py +dbnet_det/models/backbones/__init__.py +dbnet_det/models/backbones/resnet.py +dbnet_det/models/detectors/__init__.py +dbnet_det/models/detectors/base.py +dbnet_det/models/detectors/single_stage.py +dbnet_det/models/utils/__init__.py +dbnet_det/models/utils/res_layer.py +dbnet_det/utils/__init__.py +dbnet_det/utils/logger.py +dbnet_det/utils/profiling.py +dbnet_det/utils/util_distribution.py \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/dependency_links.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/dependency_links.txt new file mode 100755 index 00000000..8b137891 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/EGG-INFO/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe b/cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe new file mode 100755 index 00000000..8b137891 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/EGG-INFO/not-zip-safe @@ -0,0 +1 @@ + diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt new file mode 100755 index 00000000..22c49def --- /dev/null 
+++ b/cv/ocr/dbnet/pytorch/EGG-INFO/requires.txt @@ -0,0 +1,59 @@ +matplotlib +numpy +pycocotools +six +terminaltables + +[all] +cython +numpy +cityscapesscripts +imagecorruptions +scipy +sklearn +timm +matplotlib +pycocotools +six +terminaltables +asynctest +codecov +flake8 +interrogate +isort==4.3.21 +kwarray +mmtrack +onnx==1.7.0 +onnxruntime>=1.8.0 +protobuf<=3.20.1 +pytest +ubelt +xdoctest>=0.10.0 +yapf + +[build] +cython +numpy + +[optional] +cityscapesscripts +imagecorruptions +scipy +sklearn +timm + +[tests] +asynctest +codecov +flake8 +interrogate +isort==4.3.21 +kwarray +mmtrack +onnx==1.7.0 +onnxruntime>=1.8.0 +protobuf<=3.20.1 +pytest +ubelt +xdoctest>=0.10.0 +yapf diff --git a/cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt b/cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt new file mode 100755 index 00000000..2d2f43c4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/EGG-INFO/top_level.txt @@ -0,0 +1 @@ +dbnet_det diff --git a/cv/ocr/dbnet/pytorch/LICENSE b/cv/ocr/dbnet/pytorch/LICENSE new file mode 100755 index 00000000..3076a437 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/LICENSE @@ -0,0 +1,203 @@ +Copyright (c) MMOCR Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 MMOCR Authors. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/cv/ocr/dbnet/pytorch/README.md b/cv/ocr/dbnet/pytorch/README.md new file mode 100755 index 00000000..86641bbb --- /dev/null +++ b/cv/ocr/dbnet/pytorch/README.md @@ -0,0 +1,65 @@ +# DBnet +## Model description +Recently, segmentation-based methods are quite popular in scene text detection, as the segmentation results can more accurately describe scene text of various shapes such as curve text. However, the post-processing of binarization is essential for segmentation-based detection, which converts probability maps produced by a segmentation method into bounding boxes/regions of text. In this paper, we propose a module named Differentiable Binarization (DB), which can perform the binarization process in a segmentation network. Optimized along with a DB module, a segmentation network can adaptively set the thresholds for binarization, which not only simplifies the post-processing but also enhances the performance of text detection. 
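+For reference, the Differentiable Binarization step described above can be written as B = 1 / (1 + exp(-k(P - T))), where P is the probability map and T is the learned threshold map. The snippet below is a minimal PyTorch sketch of that formula; the function name and toy tensor shapes are illustrative only, and the amplification factor k = 50 follows the DB paper rather than this repository's `DBHead` implementation.
+
+```python
+import torch
+
+def differentiable_binarization(prob_map: torch.Tensor,
+                                thresh_map: torch.Tensor,
+                                k: float = 50.0) -> torch.Tensor:
+    """Approximate binary map B = 1 / (1 + exp(-k * (P - T))).
+
+    Unlike a hard threshold (P > T), this stays differentiable, so the
+    threshold map T can be learned jointly with the probability map P.
+    """
+    return torch.sigmoid(k * (prob_map - thresh_map))
+
+# Toy example: a 1x1x4x4 probability map and a constant threshold map.
+P = torch.rand(1, 1, 4, 4)
+T = torch.full_like(P, 0.3)
+B = differentiable_binarization(P, T)  # values pushed towards 0 or 1
+```
+
+In DB, the approximate binary map is mainly used during training; at inference the probability map can be thresholded directly.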
Based on a simple segmentation network, we validate the performance improvements of DB on five benchmark datasets, which consistently achieves state-of-the-art results, in terms of both detection accuracy and speed. In particular, with a light-weight backbone, the performance improvements by DB are significant so that we can look for an ideal tradeoff between detection accuracy and efficiency.
+
+## Preparing datasets
+
+```shell
+$ mkdir data
+$ cd data
+```
+
+### ICDAR 2015
+
+Download the following archives from the [ICDAR 2015 downloads page](https://rrc.cvc.uab.es/?ch=4&com=downloads):
+ch4_training_images.zip, ch4_test_images.zip, ch4_training_localization_transcription_gt.zip and Challenge4_Test_Task1_GT.zip.
+
+```shell
+mkdir icdar2015 && cd icdar2015
+mkdir imgs && mkdir annotations
+
+mv ch4_training_images imgs/training
+mv ch4_test_images imgs/test
+
+mv ch4_training_localization_transcription_gt annotations/training
+mv Challenge4_Test_Task1_GT annotations/test
+```
+
+Then download [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) and [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) into `icdar2015`, so that the directory looks like:
+
+```
+├── icdar2015
+│   ├── imgs
+│   ├── instances_test.json
+│   └── instances_training.json
+```
+
+### Build Extension
+
+```shell
+$ DBNET_CV_WITH_OPS=1 python3 setup.py build && cp build/lib.linux*/dbnet_cv/_ext.cpython* dbnet_cv
+```
+
+### Install packages
+
+```shell
+$ pip3 install -r requirements.txt
+```
+
+### Training on a single card
+```shell
+$ python3 train.py configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py
+```
+
+### Training on multiple cards
+```shell
+$ bash dist_train.sh configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py 8
+```
+
+## Results on BI-V100
+
+| approach | GPUs | train mem | train FPS |
+| :------: | :--------: | :-------: | :-------: |
+| dbnet | BI-V100 x8 | 5426 | 54.375 |
+
+| 0_hmean-iou:recall: | 0_hmean-iou:precision: | 0_hmean-iou:hmean: |
+| :-----------------: | :--------------------: | :----------------: |
+| 0.7111 | 0.8062 | 0.7557 |
+
+## Reference
+https://github.com/open-mmlab/mmocr
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/default_runtime.py b/cv/ocr/dbnet/pytorch/configs/_base_/default_runtime.py
new file mode 100755
index 00000000..de7f9650
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/default_runtime.py
@@ -0,0 +1,17 @@
+# yapf:disable
+log_config = dict(
+    interval=5,
+    hooks=[
+        dict(type='TextLoggerHook')
+    ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+# disable opencv multithreading to avoid system being overloaded
+opencv_num_threads = 0
+# set multi-process start method as `fork` to speed up the training
+mp_start_method = 'fork'
diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_datasets/icdar2015.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_datasets/icdar2015.py
new file mode 100755
index 00000000..f711c06d
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_datasets/icdar2015.py
@@ -0,0 +1,18 @@
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015'
+
+train = dict(
+    type=dataset_type,
+    ann_file=f'{data_root}/instances_training.json',
+    img_prefix=f'{data_root}/imgs',
+    pipeline=None)
+
+test = dict(
+    type=dataset_type,
+    ann_file=f'{data_root}/instances_test.json',
+    img_prefix=f'{data_root}/imgs',
+    pipeline=None)
+
+train_list = [train]
+
+test_list = [test]
diff --git
a/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_mobilenetv3_fpnc.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_mobilenetv3_fpnc.py new file mode 100755 index 00000000..e6bff598 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_mobilenetv3_fpnc.py @@ -0,0 +1,32 @@ +model = dict( + type='DBNet', + backbone=dict( + type='dbnet_det.MobileNetV3', + arch='large', + # num_stages=3, + out_indices=(3, 6, 12, 16), + norm_cfg=dict(type='BN', requires_grad=True), + init_cfg=dict(type='Pretrained', checkpoint='https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth') + ), + # backbone=dict( + # type='dbnet_det.ResNet', + # depth=18, + # num_stages=4, + # out_indices=(0, 1, 2, 3), + # frozen_stages=-1, + # norm_cfg=dict(type='BN', requires_grad=True), + # # init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'), + # norm_eval=False, + # style='caffe'), + neck=dict( + type='FPNC', in_channels=[24, 40, 112, 960], lateral_channels=256), + bbox_head=dict( + type='DBHead', + in_channels=256, + loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=False), + postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')), + train_cfg=None, + test_cfg=None) + + + diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r18_fpnc.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r18_fpnc.py new file mode 100755 index 00000000..b26391f4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r18_fpnc.py @@ -0,0 +1,31 @@ +model = dict( + type='DBNet', + backbone=dict( + type='dbnet_det.ResNet', + depth=18, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + # init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'), + norm_eval=False, + style='caffe'), + neck=dict( + type='FPNC', in_channels=[64, 128, 256, 512], lateral_channels=256), + bbox_head=dict( + type='DBHead', + in_channels=256, + loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=False), + postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')), + train_cfg=None, + test_cfg=None) + + + +# backbone=dict( +# type='MobileNetV3', +# arch='small', +# out_indices=(0, 1, 12), +# norm_cfg=dict(type='BN', requires_grad=True), +# init_cfg=dict(type='Pretrained', checkpoint='open-mmlab://contrib/mobilenet_v3_small') +# ), \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py new file mode 100755 index 00000000..32b09f50 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py @@ -0,0 +1,23 @@ +model = dict( + type='DBNet', + backbone=dict( + type='dbnet_det.ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + style='pytorch', + dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), + # init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), + stage_with_dcn=(False, True, True, True)), + neck=dict( + type='FPNC', in_channels=[256, 512, 1024, 2048], lateral_channels=256), + bbox_head=dict( + type='DBHead', + in_channels=256, + loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True), + postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')), + train_cfg=None, + test_cfg=None) diff --git 
a/cv/ocr/dbnet/pytorch/configs/_base_/det_pipelines/dbnet_pipeline.py b/cv/ocr/dbnet/pytorch/configs/_base_/det_pipelines/dbnet_pipeline.py new file mode 100755 index 00000000..aeb94487 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/_base_/det_pipelines/dbnet_pipeline.py @@ -0,0 +1,88 @@ +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline_r18 = [ + dict(type='LoadImageFromFile', color_type='color_ignore_orientation'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='ImgAug', + args=[['Fliplr', 0.5], + dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]), + dict(type='EastRandomCrop', target_size=(640, 640)), + dict(type='DBNetTargets', shrink_ratio=0.4), + dict(type='Pad', size_divisor=32), + dict( + type='CustomFormatBundle', + keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'], + visualize=dict(flag=False, boundary_key='gt_shrink')), + dict( + type='Collect', + keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask']) +] + +test_pipeline_1333_736 = [ + dict(type='LoadImageFromFile', color_type='color_ignore_orientation'), + dict( + type='MultiScaleFlipAug', + img_scale=(1280, 736), # used by Resize + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +# for dbnet_r50dcnv2_fpnc +img_norm_cfg_r50dcnv2 = dict( + mean=[122.67891434, 116.66876762, 104.00698793], + std=[58.395, 57.12, 57.375], + to_rgb=True) + +train_pipeline_r50dcnv2 = [ + dict(type='LoadImageFromFile', color_type='color_ignore_orientation'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg_r50dcnv2), + dict( + type='ImgAug', + args=[['Fliplr', 0.5], + dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]), + dict(type='EastRandomCrop', target_size=(640, 640)), + dict(type='DBNetTargets', shrink_ratio=0.4), + dict(type='Pad', size_divisor=32), + dict( + type='CustomFormatBundle', + keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'], + visualize=dict(flag=False, boundary_key='gt_shrink')), + dict( + type='Collect', + keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask']) +] + +test_pipeline_4068_1024 = [ + dict(type='LoadImageFromFile', color_type='color_ignore_orientation'), + dict( + type='MultiScaleFlipAug', + img_scale=(4068, 1024), # used by Resize + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='Normalize', **img_norm_cfg_r50dcnv2), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_adam_1200e.py b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_adam_1200e.py new file mode 100755 index 00000000..6be9df00 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_adam_1200e.py @@ -0,0 +1,8 @@ +# optimizer +optimizer = dict(type='AdamW', lr=1e-3,betas=(0.9, 0.999), weight_decay=0.05) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='poly', power=0.9) +# running 
settings +runner = dict(type='EpochBasedRunner', max_epochs=1200) +checkpoint_config = dict(interval=50) \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_sgd_1200e.py b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_sgd_1200e.py new file mode 100755 index 00000000..bc7fbf69 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/_base_/schedules/schedule_sgd_1200e.py @@ -0,0 +1,8 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.007, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True) +# running settings +runner = dict(type='EpochBasedRunner', max_epochs=1200) +checkpoint_config = dict(interval=100) diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/README.md b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/README.md new file mode 100755 index 00000000..d2007c72 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/README.md @@ -0,0 +1,33 @@ +# DBNet + +> [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947) + + + +## Abstract + +Recently, segmentation-based methods are quite popular in scene text detection, as the segmentation results can more accurately describe scene text of various shapes such as curve text. However, the post-processing of binarization is essential for segmentation-based detection, which converts probability maps produced by a segmentation method into bounding boxes/regions of text. In this paper, we propose a module named Differentiable Binarization (DB), which can perform the binarization process in a segmentation network. Optimized along with a DB module, a segmentation network can adaptively set the thresholds for binarization, which not only simplifies the post-processing but also enhances the performance of text detection. Based on a simple segmentation network, we validate the performance improvements of DB on five benchmark datasets, which consistently achieves state-of-the-art results, in terms of both detection accuracy and speed. In particular, with a light-weight backbone, the performance improvements by DB are significant so that we can look for an ideal tradeoff between detection accuracy and efficiency. Specifically, with a backbone of ResNet-18, our detector achieves an F-measure of 82.8, running at 62 FPS, on the MSRA-TD500 dataset. + +
    + +
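+For intuition, the differentiable binarization described above replaces the hard threshold with a scaled sigmoid, B = 1 / (1 + exp(-k * (P - T))), where P is the probability map, T is the learned threshold map and k is an amplifying factor (50 in the paper). The sketch below only illustrates that formula and is not taken from this repository's code:
+
+```python
+import torch
+
+def approximate_binarization(prob_map: torch.Tensor,
+                             thresh_map: torch.Tensor,
+                             k: float = 50.0) -> torch.Tensor:
+    """Differentiable stand-in for hard binarization: a steep sigmoid of P - T."""
+    # sigmoid(k * (P - T)) == 1 / (1 + exp(-k * (P - T)))
+    return torch.sigmoid(k * (prob_map - thresh_map))
+```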
    + +## Results and models + +### ICDAR2015 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :---------------------------------------: | :-------------------------------------------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------: | +| [DBNet_r18](/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 1200 | 736 | 0.731 | 0.871 | 0.795 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json) | +| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.814 | 0.868 | 0.840 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.log.json) | + +## Citation + +```bibtex +@article{Liao_Wan_Yao_Chen_Bai_2020, + title={Real-Time Scene Text Detection with Differentiable Binarization}, + journal={Proceedings of the AAAI Conference on Artificial Intelligence}, + author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang}, + year={2020}, + pages={11474-11481}} +``` diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py new file mode 100755 index 00000000..e8bcd2bd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_mobilenetv3_fpnc_1200e_icdar2015.py @@ -0,0 +1,33 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/schedules/schedule_sgd_1200e.py', + '../../_base_/det_models/dbnet_mobilenetv3_fpnc.py', + '../../_base_/det_datasets/icdar2015.py', + '../../_base_/det_pipelines/dbnet_pipeline.py' +] + +train_list = {{_base_.train_list}} +test_list = {{_base_.test_list}} + +train_pipeline_r18 = {{_base_.train_pipeline_r18}} +test_pipeline_1333_736 = {{_base_.test_pipeline_1333_736}} + +data = dict( + samples_per_gpu=8, + workers_per_gpu=8, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type='UniformConcatDataset', + datasets=train_list, + pipeline=train_pipeline_r18), + val=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_1333_736), + test=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_1333_736)) +fp16 = dict(loss_scale='dynamic') +evaluation = dict(interval=50, metric='hmean-iou') diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py new file mode 100755 index 00000000..9ea1bcc1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py @@ -0,0 +1,33 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/schedules/schedule_sgd_1200e.py', + '../../_base_/det_models/dbnet_r18_fpnc.py', + '../../_base_/det_datasets/icdar2015.py', + 
'../../_base_/det_pipelines/dbnet_pipeline.py' +] + +train_list = {{_base_.train_list}} +test_list = {{_base_.test_list}} + +train_pipeline_r18 = {{_base_.train_pipeline_r18}} +test_pipeline_1333_736 = {{_base_.test_pipeline_1333_736}} + +data = dict( + samples_per_gpu=16, + workers_per_gpu=8, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type='UniformConcatDataset', + datasets=train_list, + pipeline=train_pipeline_r18), + val=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_1333_736), + test=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_1333_736)) +fp16 = dict(loss_scale='dynamic') +evaluation = dict(interval=100, metric='hmean-iou') diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py new file mode 100755 index 00000000..06e2545b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py @@ -0,0 +1,37 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/schedules/schedule_sgd_1200e.py', + '../../_base_/det_models/dbnet_r50dcnv2_fpnc.py', + '../../_base_/det_datasets/icdar2015.py', + '../../_base_/det_pipelines/dbnet_pipeline.py' +] + +train_list = {{_base_.train_list}} +test_list = {{_base_.test_list}} + +train_pipeline_r50dcnv2 = {{_base_.train_pipeline_r50dcnv2}} +test_pipeline_4068_1024 = {{_base_.test_pipeline_4068_1024}} + +# load_from = 'checkpoints/textdet/dbnet/res50dcnv2_synthtext.pth' +# fp16 = dict(loss_scale='dynamic') + +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type='UniformConcatDataset', + datasets=train_list, + pipeline=train_pipeline_r50dcnv2), + val=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_4068_1024), + test=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_4068_1024)) +fp16 = dict(loss_scale='dynamic') +evaluation = dict(interval=1, metric='hmean-iou') +# fp16 = dict(loss_scale='dynamic') diff --git a/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/metafile.yml b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/metafile.yml new file mode 100755 index 00000000..c6abdbca --- /dev/null +++ b/cv/ocr/dbnet/pytorch/configs/textdet/dbnet/metafile.yml @@ -0,0 +1,40 @@ +Collections: +- Name: DBNet + Metadata: + Training Data: ICDAR2015 + Training Techniques: + - SGD with Momentum + - Weight Decay + Training Resources: 1x GeForce GTX 1080 Ti + Architecture: + - ResNet + - FPNC + Paper: + URL: https://arxiv.org/pdf/1911.08947.pdf + Title: 'Real-time Scene Text Detection with Differentiable Binarization' + README: configs/textdet/dbnet/README.md + +Models: + - Name: dbnet_r18_fpnc_1200e_icdar2015 + In Collection: DBNet + Config: configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py + Metadata: + Training Data: ICDAR2015 + Results: + - Task: Text Detection + Dataset: ICDAR2015 + Metrics: + hmean-iou: 0.795 + Weights: https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth + + - Name: dbnet_r50dcnv2_fpnc_1200e_icdar2015 + In Collection: DBNet + Config: configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py + Metadata: + Training Data: ICDAR2015 + Results: + - Task: Text Detection + Dataset: ICDAR2015 + Metrics: + hmean-iou: 0.840 + Weights: 
https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth diff --git a/cv/ocr/dbnet/pytorch/dbnet/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/__init__.py new file mode 100755 index 00000000..ed46a671 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import dbnet_cv +import dbnet_det +from packaging.version import parse + +from .version import __version__, short_version + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +dbnet_cv_minimum_version = '1.3.8' +dbnet_cv_maximum_version = '1.7.0' +dbnet_cv_version = digit_version(dbnet_cv.__version__) + +assert (dbnet_cv_version >= digit_version(dbnet_cv_minimum_version) + and dbnet_cv_version <= digit_version(dbnet_cv_maximum_version)), \ + f'DBNET_CV {dbnet_cv.__version__} is incompatible with MMOCR {__version__}. ' \ + f'Please use DBNET_CV >= {dbnet_cv_minimum_version}, ' \ + f'<= {dbnet_cv_maximum_version} instead.' + +# dbnet_det_minimum_version = '2.21.0' +# dbnet_det_maximum_version = '3.0.0' +# dbnet_det_version = digit_version(dbnet_det.__version__) + +# assert (dbnet_det_version >= digit_version(dbnet_det_minimum_version) +# and dbnet_det_version <= digit_version(dbnet_det_maximum_version)), \ +# f'dbnet_detection {dbnet_det.__version__} is incompatible ' \ +# f'with MMOCR {__version__}. ' \ +# f'Please use dbnet_detection >= {dbnet_det_minimum_version}, ' \ +# f'<= {dbnet_det_maximum_version} instead.' + +__all__ = ['__version__', 'short_version', 'digit_version'] diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py new file mode 100755 index 00000000..998771be --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/apis/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
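+# Illustrative usage of the APIs exported below (the checkpoint and image paths
+# are placeholders, not files shipped with this patch):
+#
+#   from dbnet.apis import init_detector, model_inference
+#   cfg = 'configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py'
+#   model = init_detector(cfg, checkpoint='path/to/checkpoint.pth', device='cuda:0')
+#   result = model_inference(model, 'path/to/image.jpg')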
+from .inference import init_detector, model_inference +from .test import single_gpu_test +from .train import init_random_seed, train_detector +# from .utils import (disable_text_recog_aug_test, replace_image_to_tensor, +# tensor2grayimgs) + +# __all__ = [ +# 'model_inference', 'train_detector', 'init_detector', 'init_random_seed', +# 'replace_image_to_tensor', 'disable_text_recog_aug_test', +# 'single_gpu_test', 'tensor2grayimgs' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/inference.py b/cv/ocr/dbnet/pytorch/dbnet/apis/inference.py new file mode 100755 index 00000000..33be6297 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/apis/inference.py @@ -0,0 +1,238 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import dbnet_cv +import numpy as np +import torch +from dbnet_cv.ops import RoIPool +from dbnet_cv.parallel import collate, scatter +from dbnet_cv.runner import load_checkpoint +from dbnet_det.core import get_classes +from dbnet_det.datasets import replace_ImageToTensor +from dbnet_det.datasets.pipelines import Compose + +from dbnet.models import build_detector +from dbnet.utils import is_2dlist +from .utils import disable_text_recog_aug_test + + +def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None): + """Initialize a detector from config file. + + Args: + config (str or :obj:`dbnet_cv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + cfg_options (dict): Options to override some settings in the used + config. + + Returns: + nn.Module: The constructed detector. + """ + if isinstance(config, str): + config = dbnet_cv.Config.fromfile(config) + elif not isinstance(config, dbnet_cv.Config): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(config)}') + if cfg_options is not None: + config.merge_from_dict(cfg_options) + if config.model.get('pretrained'): + config.model.pretrained = None + config.model.train_cfg = None + model = build_detector(config.model, test_cfg=config.get('test_cfg')) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + warnings.simplefilter('once') + warnings.warn('Class names are not saved in the checkpoint\'s ' + 'meta data, use COCO classes by default.') + model.CLASSES = get_classes('coco') + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +def model_inference(model, + imgs, + ann=None, + batch_mode=False, + return_data=False): + """Inference image(s) with the detector. + + Args: + model (nn.Module): The loaded detector. + imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]): + Either image files or loaded images. + batch_mode (bool): If True, use batch mode for inference. + ann (dict): Annotation info for key information extraction. + return_data: Return postprocessed data. + Returns: + result (dict): Predicted results. 
+ """ + + if isinstance(imgs, (list, tuple)): + is_batch = True + if len(imgs) == 0: + raise Exception('empty imgs provided, please check and try again') + if not isinstance(imgs[0], (np.ndarray, str)): + raise AssertionError('imgs must be strings or numpy arrays') + + elif isinstance(imgs, (np.ndarray, str)): + imgs = [imgs] + is_batch = False + else: + raise AssertionError('imgs must be strings or numpy arrays') + + is_ndarray = isinstance(imgs[0], np.ndarray) + + cfg = model.cfg + + if batch_mode: + cfg = disable_text_recog_aug_test(cfg, set_types=['test']) + + device = next(model.parameters()).device # model device + + if cfg.data.test.get('pipeline', None) is None: + if is_2dlist(cfg.data.test.datasets): + cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline + else: + cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline + if is_2dlist(cfg.data.test.pipeline): + cfg.data.test.pipeline = cfg.data.test.pipeline[0] + + if is_ndarray: + cfg = cfg.copy() + # set loading pipeline type + cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray' + + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) + test_pipeline = Compose(cfg.data.test.pipeline) + + datas = [] + for img in imgs: + # prepare data + if is_ndarray: + # directly add img + data = dict( + img=img, + ann_info=ann, + img_info=dict(width=img.shape[1], height=img.shape[0]), + bbox_fields=[]) + else: + # add information into dict + data = dict( + img_info=dict(filename=img), + img_prefix=None, + ann_info=ann, + bbox_fields=[]) + if ann is not None: + data.update(dict(**ann)) + + # build the data pipeline + data = test_pipeline(data) + # get tensor from list to stack for batch mode (text detection) + if batch_mode: + if cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug': + for key, value in data.items(): + data[key] = value[0] + datas.append(data) + + if isinstance(datas[0]['img'], list) and len(datas) > 1: + raise Exception('aug test does not support ' + f'inference with batch size ' + f'{len(datas)}') + + data = collate(datas, samples_per_gpu=len(imgs)) + + # process img_metas + if isinstance(data['img_metas'], list): + data['img_metas'] = [ + img_metas.data[0] for img_metas in data['img_metas'] + ] + else: + data['img_metas'] = data['img_metas'].data + + if isinstance(data['img'], list): + data['img'] = [img.data for img in data['img']] + if isinstance(data['img'][0], list): + data['img'] = [img[0] for img in data['img']] + else: + data['img'] = data['img'].data + + # for KIE models + if ann is not None: + data['relations'] = data['relations'].data[0] + data['gt_bboxes'] = data['gt_bboxes'].data[0] + data['texts'] = data['texts'].data[0] + data['img'] = data['img'][0] + data['img_metas'] = data['img_metas'][0] + + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + for m in model.modules(): + assert not isinstance( + m, RoIPool + ), 'CPU inference with RoIPool is not supported currently.' + + # forward the model + with torch.no_grad(): + results = model(return_loss=False, rescale=True, **data) + + if not is_batch: + if not return_data: + return results[0] + return results[0], datas[0] + else: + if not return_data: + return results + return results, datas + + +def text_model_inference(model, input_sentence): + """Inference text(s) with the entity recognizer. + + Args: + model (nn.Module): The loaded recognizer. + input_sentence (str): A text entered by the user. + + Returns: + result (dict): Predicted results. 
+ """ + + assert isinstance(input_sentence, str) + + cfg = model.cfg + if cfg.data.test.get('pipeline', None) is None: + if is_2dlist(cfg.data.test.datasets): + cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline + else: + cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline + if is_2dlist(cfg.data.test.pipeline): + cfg.data.test.pipeline = cfg.data.test.pipeline[0] + test_pipeline = Compose(cfg.data.test.pipeline) + data = {'text': input_sentence, 'label': {}} + + # build the data pipeline + data = test_pipeline(data) + if isinstance(data['img_metas'], dict): + img_metas = data['img_metas'] + else: + img_metas = data['img_metas'].data + + assert isinstance(img_metas, dict) + img_metas = { + 'input_ids': img_metas['input_ids'].unsqueeze(0), + 'attention_masks': img_metas['attention_masks'].unsqueeze(0), + 'token_type_ids': img_metas['token_type_ids'].unsqueeze(0), + 'labels': img_metas['labels'].unsqueeze(0) + } + # forward the model + with torch.no_grad(): + result = model(None, img_metas, return_loss=False) + return result diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/test.py b/cv/ocr/dbnet/pytorch/dbnet/apis/test.py new file mode 100755 index 00000000..fdfa7470 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/apis/test.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import dbnet_cv +import numpy as np +import torch +from dbnet_cv.image import tensor2imgs +from dbnet_cv.parallel import DataContainer +from dbnet_det.core import encode_mask_results + +from .utils import tensor2grayimgs + + +def retrieve_img_tensor_and_meta(data): + """Retrieval img_tensor, img_metas and img_norm_cfg. + + Args: + data (dict): One batch data from data_loader. + + Returns: + tuple: Returns (img_tensor, img_metas, img_norm_cfg). + + - | img_tensor (Tensor): Input image tensor with shape + :math:`(N, C, H, W)`. + - | img_metas (list[dict]): The metadata of images. + - | img_norm_cfg (dict): Config for image normalization. 
+ """ + + if isinstance(data['img'], torch.Tensor): + # for textrecog with batch_size > 1 + # and not use 'DefaultFormatBundle' in pipeline + img_tensor = data['img'] + img_metas = data['img_metas'].data[0] + elif isinstance(data['img'], list): + if isinstance(data['img'][0], torch.Tensor): + # for textrecog with aug_test and batch_size = 1 + img_tensor = data['img'][0] + elif isinstance(data['img'][0], DataContainer): + # for textdet with 'MultiScaleFlipAug' + # and 'DefaultFormatBundle' in pipeline + img_tensor = data['img'][0].data[0] + img_metas = data['img_metas'][0].data[0] + elif isinstance(data['img'], DataContainer): + # for textrecog with 'DefaultFormatBundle' in pipeline + img_tensor = data['img'].data[0] + img_metas = data['img_metas'].data[0] + + must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape'] + for key in must_keys: + if key not in img_metas[0]: + raise KeyError( + f'Please add {key} to the "meta_keys" in the pipeline') + + img_norm_cfg = img_metas[0]['img_norm_cfg'] + if max(img_norm_cfg['mean']) <= 1: + img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']] + img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']] + + return img_tensor, img_metas, img_norm_cfg + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + is_kie=False, + show_score_thr=0.3): + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = dbnet_cv.ProgressBar(len(dataset)) + for data in data_loader: + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + + batch_size = len(result) + if show or out_dir: + if is_kie: + img_tensor = data['img'].data[0] + if img_tensor.shape[0] != 1: + raise KeyError('Visualizing KIE outputs in batches is' + 'currently not supported.') + gt_bboxes = data['gt_bboxes'].data[0] + img_metas = data['img_metas'].data[0] + must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape'] + for key in must_keys: + if key not in img_metas[0]: + raise KeyError( + f'Please add {key} to the "meta_keys" in config.') + # for no visual model + if np.prod(img_tensor.shape) == 0: + imgs = [] + for img_meta in img_metas: + try: + img = dbnet_cv.imread(img_meta['filename']) + except Exception as e: + print(f'Load image with error: {e}, ' + 'use empty image instead.') + img = np.ones( + img_meta['img_shape'], dtype=np.uint8) + imgs.append(img) + else: + imgs = tensor2imgs(img_tensor, + **img_metas[0]['img_norm_cfg']) + for i, img in enumerate(imgs): + h, w, _ = img_metas[i]['img_shape'] + img_show = img[:h, :w, :] + if out_dir: + out_file = osp.join(out_dir, + img_metas[i]['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result[i], + gt_bboxes[i], + show=show, + out_file=out_file) + else: + img_tensor, img_metas, img_norm_cfg = \ + retrieve_img_tensor_and_meta(data) + + if img_tensor.size(1) == 1: + imgs = tensor2grayimgs(img_tensor, **img_norm_cfg) + else: + imgs = tensor2imgs(img_tensor, **img_norm_cfg) + assert len(imgs) == len(img_metas) + + for j, (img, img_meta) in enumerate(zip(imgs, img_metas)): + img_shape, ori_shape = img_meta['img_shape'], img_meta[ + 'ori_shape'] + img_show = img[:img_shape[0], :img_shape[1]] + img_show = dbnet_cv.imresize(img_show, + (ori_shape[1], ori_shape[0])) + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result[j], + show=show, + out_file=out_file, + score_thr=show_score_thr) + + # encode mask results + if isinstance(result[0], 
tuple): + result = [(bbox_results, encode_mask_results(mask_results)) + for bbox_results, mask_results in result] + results.extend(result) + + for _ in range(batch_size): + prog_bar.update() + return results diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/train.py b/cv/ocr/dbnet/pytorch/dbnet/apis/train.py new file mode 100755 index 00000000..e1008be2 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/apis/train.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import dbnet_cv +import numpy as np +import torch +import torch.distributed as dist +from dbnet_cv.parallel import MMDataParallel, MMDistributedDataParallel +from dbnet_cv.runner import (DistSamplerSeedHook, EpochBasedRunner, + Fp16OptimizerHook, OptimizerHook, build_optimizer, + build_runner, get_dist_info) +from dbnet_det.core import DistEvalHook, EvalHook +from dbnet_det.datasets import build_dataloader, build_dataset + +from dbnet import digit_version +from dbnet.apis.utils import (disable_text_recog_aug_test, + replace_image_to_tensor) +from dbnet.utils import get_root_logger + + +def train_detector(model, + dataset, + cfg, + distributed=False, + validate=False, + timestamp=None, + meta=None): + logger = get_root_logger(cfg.log_level) + + # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + # step 1: give default values and override (if exist) from cfg.data + default_loader_cfg = { + **dict( + num_gpus=len(cfg.gpu_ids), + dist=distributed, + seed=cfg.get('seed'), + drop_last=False, + persistent_workers=False), + **({} if torch.__version__ != 'parrots' else dict( + prefetch_num=2, + pin_memory=False, + )), + } + # update overall dataloader(for train, val and test) setting + default_loader_cfg.update({ + k: v + for k, v in cfg.data.items() if k not in [ + 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', + 'test_dataloader' + ] + }) + + # step 2: cfg.data.train_dataloader has highest priority + train_loader_cfg = dict(default_loader_cfg, + **cfg.data.get('train_dataloader', {})) + + data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset] + + # put model on gpus + if distributed: + find_unused_parameters = cfg.get('find_unused_parameters', False) + # Sets the `find_unused_parameters` parameter in + # torch.nn.parallel.DistributedDataParallel + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + if not torch.cuda.is_available(): + assert digit_version(dbnet_cv.__version__) >= digit_version('1.4.4'), \ + 'Please use DBNET_CV >= 1.4.4 for CPU training!' 
+ model = MMDataParallel(model, device_ids=cfg.gpu_ids) + + # build runner + optimizer = build_optimizer(model, cfg.optimizer) + + if 'runner' not in cfg: + cfg.runner = { + 'type': 'EpochBasedRunner', + 'max_epochs': cfg.total_epochs + } + warnings.warn( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + else: + if 'total_epochs' in cfg: + assert cfg.total_epochs == cfg.runner.max_epochs + + runner = build_runner( + cfg.runner, + default_args=dict( + model=model, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # an ugly workaround to make .log and .log.json filenames the same + runner.timestamp = timestamp + + # fp16 setting + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + optimizer_config = Fp16OptimizerHook( + **cfg.optimizer_config, **fp16_cfg, distributed=distributed) + elif distributed and 'type' not in cfg.optimizer_config: + optimizer_config = OptimizerHook(**cfg.optimizer_config) + else: + optimizer_config = cfg.optimizer_config + + # register hooks + runner.register_training_hooks( + cfg.lr_config, + optimizer_config, + cfg.checkpoint_config, + cfg.log_config, + cfg.get('momentum_config', None), + custom_hooks_config=cfg.get('custom_hooks', None)) + if distributed: + if isinstance(runner, EpochBasedRunner): + runner.register_hook(DistSamplerSeedHook()) + + # register eval hooks + if validate: + val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get( + 'samples_per_gpu', cfg.data.get('samples_per_gpu', 1)) + if val_samples_per_gpu > 1: + # Support batch_size > 1 in test for text recognition + # by disable MultiRotateAugOCR since it is useless for most case + cfg = disable_text_recog_aug_test(cfg) + cfg = replace_image_to_tensor(cfg) + + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + + val_loader_cfg = { + **default_loader_cfg, + **dict(shuffle=False, drop_last=False), + **cfg.data.get('val_dataloader', {}), + **dict(samples_per_gpu=val_samples_per_gpu) + } + + val_dataloader = build_dataloader(val_dataset, **val_loader_cfg) + + eval_cfg = cfg.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) + + if cfg.resume_from: + runner.resume(cfg.resume_from) + elif cfg.load_from: + runner.load_checkpoint(cfg.load_from) + runner.run(data_loaders, cfg.workflow) + + +def init_random_seed(seed=None, device='cuda'): + """Initialize random seed. If the seed is None, it will be replaced by a + random number, and then broadcasted to all processes. + + Args: + seed (int, Optional): The seed. + device (str): The device where the seed will be put on. + + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to prevent + # some potential bugs. 
Please refer to + # https://github.com/open-mmlab/dbnet_detection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2**31) + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() diff --git a/cv/ocr/dbnet/pytorch/dbnet/apis/utils.py b/cv/ocr/dbnet/pytorch/dbnet/apis/utils.py new file mode 100755 index 00000000..b724713a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/apis/utils.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings + +import dbnet_cv +import numpy as np +import torch +from dbnet_det.datasets import replace_ImageToTensor + +from dbnet.utils import is_2dlist, is_type_list + + +def update_pipeline(cfg, idx=None): + if idx is None: + if cfg.pipeline is not None: + cfg.pipeline = replace_ImageToTensor(cfg.pipeline) + else: + cfg.pipeline[idx] = replace_ImageToTensor(cfg.pipeline[idx]) + + +def replace_image_to_tensor(cfg, set_types=None): + """Replace 'ImageToTensor' to 'DefaultFormatBundle'.""" + assert set_types is None or isinstance(set_types, list) + if set_types is None: + set_types = ['val', 'test'] + + cfg = copy.deepcopy(cfg) + for set_type in set_types: + assert set_type in ['val', 'test'] + uniform_pipeline = cfg.data[set_type].get('pipeline', None) + if is_type_list(uniform_pipeline, dict): + update_pipeline(cfg.data[set_type]) + elif is_2dlist(uniform_pipeline): + for idx, _ in enumerate(uniform_pipeline): + update_pipeline(cfg.data[set_type], idx) + + for dataset in cfg.data[set_type].get('datasets', []): + if isinstance(dataset, list): + for each_dataset in dataset: + update_pipeline(each_dataset) + else: + update_pipeline(dataset) + + return cfg + + +def update_pipeline_recog(cfg, idx=None): + warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \ + 'inference since samples_per_gpu > 1.' + if idx is None: + if cfg.get('pipeline', + None) and cfg.pipeline[1].type == 'MultiRotateAugOCR': + warnings.warn(warning_msg) + cfg.pipeline = [cfg.pipeline[0], *cfg.pipeline[1].transforms] + else: + if cfg[idx][1].type == 'MultiRotateAugOCR': + warnings.warn(warning_msg) + cfg[idx] = [cfg[idx][0], *cfg[idx][1].transforms] + + +def disable_text_recog_aug_test(cfg, set_types=None): + """Remove aug_test from test pipeline for text recognition. + Args: + cfg (dbnet_cv.Config): Input config. + set_types (list[str]): Type of dataset source. Should be + None or sublist of ['test', 'val']. 
+ """ + assert set_types is None or isinstance(set_types, list) + if set_types is None: + set_types = ['val', 'test'] + + cfg = copy.deepcopy(cfg) + warnings.simplefilter('once') + for set_type in set_types: + assert set_type in ['val', 'test'] + dataset_type = cfg.data[set_type].type + if dataset_type not in [ + 'ConcatDataset', 'UniformConcatDataset', 'OCRDataset', + 'OCRSegDataset' + ]: + continue + + uniform_pipeline = cfg.data[set_type].get('pipeline', None) + if is_type_list(uniform_pipeline, dict): + update_pipeline_recog(cfg.data[set_type]) + elif is_2dlist(uniform_pipeline): + for idx, _ in enumerate(uniform_pipeline): + update_pipeline_recog(cfg.data[set_type].pipeline, idx) + + for dataset in cfg.data[set_type].get('datasets', []): + if isinstance(dataset, list): + for each_dataset in dataset: + update_pipeline_recog(each_dataset) + else: + update_pipeline_recog(dataset) + + return cfg + + +def tensor2grayimgs(tensor, mean=(127, ), std=(127, ), **kwargs): + """Convert tensor to 1-channel gray images. + Args: + tensor (torch.Tensor): Tensor that contains multiple images, shape ( + N, C, H, W). + mean (tuple[float], optional): Mean of images. Defaults to (127). + std (tuple[float], optional): Standard deviation of images. + Defaults to (127). + Returns: + list[np.ndarray]: A list that contains multiple images. + """ + + assert torch.is_tensor(tensor) and tensor.ndim == 4 + assert tensor.size(1) == len(mean) == len(std) == 1 + + num_imgs = tensor.size(0) + mean = np.array(mean, dtype=np.float32) + std = np.array(std, dtype=np.float32) + imgs = [] + for img_id in range(num_imgs): + img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0) + img = dbnet_cv.imdenormalize(img, mean, std, to_bgr=False).astype(np.uint8) + imgs.append(np.ascontiguousarray(img)) + return imgs \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/core/__init__.py new file mode 100755 index 00000000..beae1ba4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from . import evaluation +from .evaluation import * # NOQA +from .mask import extract_boundary, points2boundary, seg2boundary +from .visualize import (det_recog_show_result, imshow_edge, imshow_node, + imshow_pred_boundary, imshow_text_char_boundary, + imshow_text_label, overlay_mask_img, show_feature, + show_img_boundary, show_pred_gt) + +__all__ = [ + 'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img', + 'show_feature', 'show_img_boundary', 'show_pred_gt', + 'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label', + 'imshow_node', 'det_recog_show_result', 'imshow_edge' +] +__all__ += evaluation.__all__ diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/__init__.py new file mode 100755 index 00000000..3c23516f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .hmean import eval_hmean +from .hmean_ic13 import eval_hmean_ic13 +from .hmean_iou import eval_hmean_iou +# from .kie_metric import compute_f1_score +# from .ner_metric import eval_ner_f1 +# from .ocr_metric import eval_ocr_metric + +__all__ = [ + 'eval_hmean_ic13', 'eval_hmean_iou', + # 'eval_ocr_metric', 'eval_hmean', + # 'compute_f1_score', 'eval_ner_f1' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean.py new file mode 100755 index 00000000..e9797760 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from operator import itemgetter + +import dbnet_cv +import numpy as np +from dbnet_cv.utils import print_log + +import dbnet.utils as utils +from dbnet.core.evaluation import hmean_ic13, hmean_iou +from dbnet.core.evaluation.utils import (filter_2dlist_result, + select_top_boundary) +from dbnet.core.mask import extract_boundary + + +def output_ranklist(img_results, img_infos, out_file): + """Output the worst results for debugging. + + Args: + img_results (list[dict]): Image result list. + img_infos (list[dict]): Image information list. + out_file (str): The output file path. + + Returns: + sorted_results (list[dict]): Image results sorted by hmean. + """ + assert utils.is_type_list(img_results, dict) + assert utils.is_type_list(img_infos, dict) + assert isinstance(out_file, str) + assert out_file.endswith('json') + + sorted_results = [] + for idx, result in enumerate(img_results): + name = img_infos[idx]['file_name'] + img_result = result + img_result['file_name'] = name + sorted_results.append(img_result) + sorted_results = sorted( + sorted_results, key=itemgetter('hmean'), reverse=False) + + dbnet_cv.dump(sorted_results, file=out_file) + + return sorted_results + + +def get_gt_masks(ann_infos): + """Get ground truth masks and ignored masks. + + Args: + ann_infos (list[dict]): Each dict contains annotation + infos of one image, containing following keys: + masks, masks_ignore. + Returns: + gt_masks (list[list[list[int]]]): Ground truth masks. + gt_masks_ignore (list[list[list[int]]]): Ignored masks. + """ + assert utils.is_type_list(ann_infos, dict) + + gt_masks = [] + gt_masks_ignore = [] + for ann_info in ann_infos: + masks = ann_info['masks'] + mask_gt = [] + for mask in masks: + assert len(mask[0]) >= 8 and len(mask[0]) % 2 == 0 + mask_gt.append(mask[0]) + gt_masks.append(mask_gt) + + masks_ignore = ann_info['masks_ignore'] + mask_gt_ignore = [] + for mask_ignore in masks_ignore: + assert len(mask_ignore[0]) >= 8 and len(mask_ignore[0]) % 2 == 0 + mask_gt_ignore.append(mask_ignore[0]) + gt_masks_ignore.append(mask_gt_ignore) + + return gt_masks, gt_masks_ignore + + +def eval_hmean(results, + img_infos, + ann_infos, + metrics={'hmean-iou'}, + score_thr=None, + min_score_thr=0.3, + max_score_thr=0.9, + step=0.1, + rank_list=None, + logger=None, + **kwargs): + """Evaluation in hmean metric. It conducts grid search over a range of + boundary score thresholds and reports the best result. + + Args: + results (list[dict]): Each dict corresponds to one image, + containing the following keys: boundary_result + img_infos (list[dict]): Each dict corresponds to one image, + containing the following keys: filename, height, width + ann_infos (list[dict]): Each dict corresponds to one image, + containing the following keys: masks, masks_ignore + score_thr (float): Deprecated. Please use min_score_thr instead. 
+ min_score_thr (float): Minimum score threshold of prediction map. + max_score_thr (float): Maximum score threshold of prediction map. + step (float): The spacing between score thresholds. + metrics (set{str}): Hmean metric set, should be one or all of + {'hmean-iou', 'hmean-ic13'} + Returns: + dict[str: float] + """ + assert utils.is_type_list(results, dict) + assert utils.is_type_list(img_infos, dict) + assert utils.is_type_list(ann_infos, dict) + + if score_thr: + warnings.warn('score_thr is deprecated. Please use min_score_thr ' + 'instead.') + min_score_thr = score_thr + + assert 0 <= min_score_thr <= max_score_thr <= 1 + assert 0 <= step <= 1 + assert len(results) == len(img_infos) == len(ann_infos) + assert isinstance(metrics, set) + + min_score_thr = float(min_score_thr) + max_score_thr = float(max_score_thr) + step = float(step) + + gts, gts_ignore = get_gt_masks(ann_infos) + + preds = [] + pred_scores = [] + for result in results: + _, texts, scores = extract_boundary(result) + if len(texts) > 0: + assert utils.valid_boundary(texts[0], False) + valid_texts, valid_text_scores = filter_2dlist_result( + texts, scores, min_score_thr) + preds.append(valid_texts) + pred_scores.append(valid_text_scores) + + eval_results = {} + + for metric in metrics: + msg = f'Evaluating {metric}...' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + best_result = dict(hmean=-1) + for thr in np.arange(min_score_thr, min(max_score_thr + step, 1.0), + step): + top_preds = select_top_boundary(preds, pred_scores, thr) + if metric == 'hmean-iou': + result, img_result = hmean_iou.eval_hmean_iou( + top_preds, gts, gts_ignore) + elif metric == 'hmean-ic13': + result, img_result = hmean_ic13.eval_hmean_ic13( + top_preds, gts, gts_ignore) + else: + raise NotImplementedError + if rank_list is not None: + output_ranklist(img_result, img_infos, rank_list) + + print_log( + 'thr {0:.2f}, recall: {1[recall]:.3f}, ' + 'precision: {1[precision]:.3f}, ' + 'hmean: {1[hmean]:.3f}'.format(thr, result), + logger=logger) + if result['hmean'] > best_result['hmean']: + best_result = result + eval_results[metric + ':recall'] = best_result['recall'] + eval_results[metric + ':precision'] = best_result['precision'] + eval_results[metric + ':hmean'] = best_result['hmean'] + return eval_results diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_ic13.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_ic13.py new file mode 100755 index 00000000..a7139cab --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_ic13.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +import dbnet.utils as utils +from . import utils as eval_utils + + +def compute_recall_precision(gt_polys, pred_polys): + """Compute the recall and the precision matrices between gt and predicted + polygons. + + Args: + gt_polys (list[Polygon]): List of gt polygons. + pred_polys (list[Polygon]): List of predicted polygons. + + Returns: + recall (ndarray): Recall matrix of size gt_num x det_num. + precision (ndarray): Precision matrix of size gt_num x det_num. 
+ """ + assert isinstance(gt_polys, list) + assert isinstance(pred_polys, list) + + gt_num = len(gt_polys) + det_num = len(pred_polys) + sz = [gt_num, det_num] + + recall = np.zeros(sz) + precision = np.zeros(sz) + # compute area recall and precision for each (gt, det) pair + # in one img + for gt_id in range(gt_num): + for pred_id in range(det_num): + gt = gt_polys[gt_id] + det = pred_polys[pred_id] + + inter_area = eval_utils.poly_intersection(det, gt) + gt_area = gt.area + det_area = det.area + if gt_area != 0: + recall[gt_id, pred_id] = inter_area / gt_area + if det_area != 0: + precision[gt_id, pred_id] = inter_area / det_area + + return recall, precision + + +def eval_hmean_ic13(det_boxes, + gt_boxes, + gt_ignored_boxes, + precision_thr=0.4, + recall_thr=0.8, + center_dist_thr=1.0, + one2one_score=1., + one2many_score=0.8, + many2one_score=1.): + """Evaluate hmean of text detection using the icdar2013 standard. + + Args: + det_boxes (list[list[list[float]]]): List of arrays of shape (n, 2k). + Each element is the det_boxes for one img. k>=4. + gt_boxes (list[list[list[float]]]): List of arrays of shape (m, 2k). + Each element is the gt_boxes for one img. k>=4. + gt_ignored_boxes (list[list[list[float]]]): List of arrays of + (l, 2k). Each element is the ignored gt_boxes for one img. k>=4. + precision_thr (float): Precision threshold of the iou of one + (gt_box, det_box) pair. + recall_thr (float): Recall threshold of the iou of one + (gt_box, det_box) pair. + center_dist_thr (float): Distance threshold of one (gt_box, det_box) + center point pair. + one2one_score (float): Reward when one gt matches one det_box. + one2many_score (float): Reward when one gt matches many det_boxes. + many2one_score (float): Reward when many gts match one det_box. + + Returns: + hmean (tuple[dict]): Tuple of dicts which encodes the hmean for + the dataset and all images. + """ + assert utils.is_3dlist(det_boxes) + assert utils.is_3dlist(gt_boxes) + assert utils.is_3dlist(gt_ignored_boxes) + + assert 0 <= precision_thr <= 1 + assert 0 <= recall_thr <= 1 + assert center_dist_thr > 0 + assert 0 <= one2one_score <= 1 + assert 0 <= one2many_score <= 1 + assert 0 <= many2one_score <= 1 + + img_num = len(det_boxes) + assert img_num == len(gt_boxes) + assert img_num == len(gt_ignored_boxes) + + dataset_gt_num = 0 + dataset_pred_num = 0 + dataset_hit_recall = 0.0 + dataset_hit_prec = 0.0 + + img_results = [] + + for i in range(img_num): + gt = gt_boxes[i] + gt_ignored = gt_ignored_boxes[i] + pred = det_boxes[i] + + gt_num = len(gt) + ignored_num = len(gt_ignored) + pred_num = len(pred) + + accum_recall = 0. + accum_precision = 0. + + gt_points = gt + gt_ignored + gt_polys = [eval_utils.points2polygon(p) for p in gt_points] + gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))] + gt_num = len(gt_polys) + + pred_polys, pred_points, pred_ignored_index = eval_utils.ignore_pred( + pred, gt_ignored_index, gt_polys, precision_thr) + + if pred_num > 0 and gt_num > 0: + + gt_hit = np.zeros(gt_num, np.int8).tolist() + pred_hit = np.zeros(pred_num, np.int8).tolist() + + # compute area recall and precision for each (gt, pred) pair + # in one img. + recall_mat, precision_mat = compute_recall_precision( + gt_polys, pred_polys) + + # match one gt to one pred box. 
+ for gt_id in range(gt_num): + for pred_id in range(pred_num): + if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0 + or gt_id in gt_ignored_index + or pred_id in pred_ignored_index): + continue + match = eval_utils.one2one_match_ic13( + gt_id, pred_id, recall_mat, precision_mat, recall_thr, + precision_thr) + + if match: + gt_point = np.array(gt_points[gt_id]) + det_point = np.array(pred_points[pred_id]) + + norm_dist = eval_utils.box_center_distance( + det_point, gt_point) + norm_dist /= eval_utils.box_diag( + det_point) + eval_utils.box_diag(gt_point) + norm_dist *= 2.0 + + if norm_dist < center_dist_thr: + gt_hit[gt_id] = 1 + pred_hit[pred_id] = 1 + accum_recall += one2one_score + accum_precision += one2one_score + + # match one gt to many det boxes. + for gt_id in range(gt_num): + if gt_id in gt_ignored_index: + continue + match, match_det_set = eval_utils.one2many_match_ic13( + gt_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_hit, pred_hit, pred_ignored_index) + + if match: + gt_hit[gt_id] = 1 + accum_recall += one2many_score + accum_precision += one2many_score * len(match_det_set) + for pred_id in match_det_set: + pred_hit[pred_id] = 1 + + # match many gt to one det box. One pair of (det,gt) are matched + # successfully if their recall, precision, normalized distance + # meet some thresholds. + for pred_id in range(pred_num): + if pred_id in pred_ignored_index: + continue + + match, match_gt_set = eval_utils.many2one_match_ic13( + pred_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_hit, pred_hit, gt_ignored_index) + + if match: + pred_hit[pred_id] = 1 + accum_recall += many2one_score * len(match_gt_set) + accum_precision += many2one_score + for gt_id in match_gt_set: + gt_hit[gt_id] = 1 + + gt_care_number = gt_num - ignored_num + pred_care_number = pred_num - len(pred_ignored_index) + + r, p, h = eval_utils.compute_hmean(accum_recall, accum_precision, + gt_care_number, pred_care_number) + + img_results.append({'recall': r, 'precision': p, 'hmean': h}) + + dataset_gt_num += gt_care_number + dataset_pred_num += pred_care_number + dataset_hit_recall += accum_recall + dataset_hit_prec += accum_precision + + total_r, total_p, total_h = eval_utils.compute_hmean( + dataset_hit_recall, dataset_hit_prec, dataset_gt_num, dataset_pred_num) + + dataset_results = { + 'num_gts': dataset_gt_num, + 'num_dets': dataset_pred_num, + 'num_recall': dataset_hit_recall, + 'num_precision': dataset_hit_prec, + 'recall': total_r, + 'precision': total_p, + 'hmean': total_h + } + + return dataset_results, img_results diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_iou.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_iou.py new file mode 100755 index 00000000..ade2aaa9 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/hmean_iou.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +import dbnet.utils as utils +from . import utils as eval_utils + + +def eval_hmean_iou(pred_boxes, + gt_boxes, + gt_ignored_boxes, + iou_thr=0.5, + precision_thr=0.5): + """Evaluate hmean of text detection using IOU standard. + + Args: + pred_boxes (list[list[list[float]]]): Text boxes for an img list. Each + box has 2k (>=8) values. + gt_boxes (list[list[list[float]]]): Ground truth text boxes for an img + list. Each box has 2k (>=8) values. + gt_ignored_boxes (list[list[list[float]]]): Ignored ground truth text + boxes for an img list. Each box has 2k (>=8) values. 
+ iou_thr (float): Iou threshold when one (gt_box, det_box) pair is + matched. + precision_thr (float): Precision threshold when one (gt_box, det_box) + pair is matched. + + Returns: + hmean (tuple[dict]): Tuple of dicts indicates the hmean for the dataset + and all images. + """ + assert utils.is_3dlist(pred_boxes) + assert utils.is_3dlist(gt_boxes) + assert utils.is_3dlist(gt_ignored_boxes) + assert 0 <= iou_thr <= 1 + assert 0 <= precision_thr <= 1 + + img_num = len(pred_boxes) + assert img_num == len(gt_boxes) + assert img_num == len(gt_ignored_boxes) + + dataset_gt_num = 0 + dataset_pred_num = 0 + dataset_hit_num = 0 + + img_results = [] + + for i in range(img_num): + gt = gt_boxes[i] + gt_ignored = gt_ignored_boxes[i] + pred = pred_boxes[i] + + gt_num = len(gt) + gt_ignored_num = len(gt_ignored) + pred_num = len(pred) + + hit_num = 0 + + # get gt polygons. + gt_all = gt + gt_ignored + gt_polys = [eval_utils.points2polygon(p) for p in gt_all] + gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))] + gt_num = len(gt_polys) + pred_polys, _, pred_ignored_index = eval_utils.ignore_pred( + pred, gt_ignored_index, gt_polys, precision_thr) + + # match. + if gt_num > 0 and pred_num > 0: + sz = [gt_num, pred_num] + iou_mat = np.zeros(sz) + + gt_hit = np.zeros(gt_num, np.int8) + pred_hit = np.zeros(pred_num, np.int8) + + for gt_id in range(gt_num): + for pred_id in range(pred_num): + gt_pol = gt_polys[gt_id] + det_pol = pred_polys[pred_id] + + iou_mat[gt_id, + pred_id] = eval_utils.poly_iou(det_pol, gt_pol) + + for gt_id in range(gt_num): + for pred_id in range(pred_num): + if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0 + or gt_id in gt_ignored_index + or pred_id in pred_ignored_index): + continue + if iou_mat[gt_id, pred_id] > iou_thr: + gt_hit[gt_id] = 1 + pred_hit[pred_id] = 1 + hit_num += 1 + + gt_care_number = gt_num - gt_ignored_num + pred_care_number = pred_num - len(pred_ignored_index) + + r, p, h = eval_utils.compute_hmean(hit_num, hit_num, gt_care_number, + pred_care_number) + + img_results.append({'recall': r, 'precision': p, 'hmean': h}) + + dataset_hit_num += hit_num + dataset_gt_num += gt_care_number + dataset_pred_num += pred_care_number + + dataset_r, dataset_p, dataset_h = eval_utils.compute_hmean( + dataset_hit_num, dataset_hit_num, dataset_gt_num, dataset_pred_num) + + dataset_results = { + 'num_gts': dataset_gt_num, + 'num_dets': dataset_pred_num, + 'num_match': dataset_hit_num, + 'recall': dataset_r, + 'precision': dataset_p, + 'hmean': dataset_h + } + + return dataset_results, img_results diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/utils.py b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/utils.py new file mode 100755 index 00000000..4c0db33b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/evaluation/utils.py @@ -0,0 +1,547 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from shapely.geometry import Polygon as plg + +import dbnet.utils as utils + + +def ignore_pred(pred_boxes, gt_ignored_index, gt_polys, precision_thr): + """Ignore the predicted box if it hits any ignored ground truth. + + Args: + pred_boxes (list[ndarray or list]): The predicted boxes of one image. + gt_ignored_index (list[int]): The ignored ground truth index list. + gt_polys (list[Polygon]): The polygon list of one image. + precision_thr (float): The precision threshold. + + Returns: + pred_polys (list[Polygon]): The predicted polygon list. + pred_points (list[list]): The predicted box list represented + by point sequences. 
+ pred_ignored_index (list[int]): The ignored text index list. + """ + + assert isinstance(pred_boxes, list) + assert isinstance(gt_ignored_index, list) + assert isinstance(gt_polys, list) + assert 0 <= precision_thr <= 1 + + pred_polys = [] + pred_points = [] + pred_ignored_index = [] + + gt_ignored_num = len(gt_ignored_index) + # get detection polygons + for box_id, box in enumerate(pred_boxes): + poly = points2polygon(box) + pred_polys.append(poly) + pred_points.append(box) + + if gt_ignored_num < 1: + continue + + # ignore the current detection box + # if its overlap with any ignored gt > precision_thr + for ignored_box_id in gt_ignored_index: + ignored_box = gt_polys[ignored_box_id] + inter_area = poly_intersection(poly, ignored_box) + area = poly.area + precision = 0 if area == 0 else inter_area / area + if precision > precision_thr: + pred_ignored_index.append(box_id) + break + + return pred_polys, pred_points, pred_ignored_index + + +def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num): + """Compute hmean given hit number, ground truth number and prediction + number. + + Args: + accum_hit_recall (int|float): Accumulated hits for computing recall. + accum_hit_prec (int|float): Accumulated hits for computing precision. + gt_num (int): Ground truth number. + pred_num (int): Prediction number. + + Returns: + recall (float): The recall value. + precision (float): The precision value. + hmean (float): The hmean value. + """ + + assert isinstance(accum_hit_recall, (float, int)) + assert isinstance(accum_hit_prec, (float, int)) + + assert isinstance(gt_num, int) + assert isinstance(pred_num, int) + assert accum_hit_recall >= 0.0 + assert accum_hit_prec >= 0.0 + assert gt_num >= 0.0 + assert pred_num >= 0.0 + + if gt_num == 0: + recall = 1.0 + precision = 0.0 if pred_num > 0 else 1.0 + else: + recall = float(accum_hit_recall) / gt_num + precision = 0.0 if pred_num == 0 else float(accum_hit_prec) / pred_num + + denom = recall + precision + + hmean = 0.0 if denom == 0 else (2.0 * precision * recall / denom) + + return recall, precision, hmean + + +def box2polygon(box): + """Convert box to polygon. + + Args: + box (ndarray or list): A ndarray or a list of shape (4) + that indicates 2 points. + + Returns: + polygon (Polygon): A polygon object. + """ + if isinstance(box, list): + box = np.array(box) + + assert isinstance(box, np.ndarray) + assert box.size == 4 + boundary = np.array( + [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]]) + + point_mat = boundary.reshape([-1, 2]) + return plg(point_mat) + + +def points2polygon(points): + """Convert k points to 1 polygon. + + Args: + points (ndarray or list): A ndarray or a list of shape (2k) + that indicates k points. + + Returns: + polygon (Polygon): A polygon object. + """ + if isinstance(points, list): + points = np.array(points) + + assert isinstance(points, np.ndarray) + assert (points.size % 2 == 0) and (points.size >= 8) + + point_mat = points.reshape([-1, 2]) + return plg(point_mat) + + +def poly_make_valid(poly): + """Convert a potentially invalid polygon to a valid one by eliminating + self-crossing or self-touching parts. + + Args: + poly (Polygon): A polygon needed to be converted. + + Returns: + A valid polygon. + """ + return poly if poly.is_valid else poly.buffer(0) + + +def poly_intersection(poly_det, poly_gt, invalid_ret=None, return_poly=False): + """Calculate the intersection area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. 
+ invalid_ret (None|float|int): The return value when the invalid polygon + exists. If it is not specified, the function allows the computation + to proceed with invalid polygons by cleaning the their + self-touching or self-crossing parts. + return_poly (bool): Whether to return the polygon of the intersection + area. + + Returns: + intersection_area (float): The intersection area between two polygons. + poly_obj (Polygon, optional): The Polygon object of the intersection + area. Set as `None` if the input is invalid. + """ + assert isinstance(poly_det, plg) + assert isinstance(poly_gt, plg) + assert invalid_ret is None or isinstance(invalid_ret, float) or \ + isinstance(invalid_ret, int) + + if invalid_ret is None: + poly_det = poly_make_valid(poly_det) + poly_gt = poly_make_valid(poly_gt) + + poly_obj = None + area = invalid_ret + if poly_det.is_valid and poly_gt.is_valid: + poly_obj = poly_det.intersection(poly_gt) + area = poly_obj.area + return (area, poly_obj) if return_poly else area + + +def poly_union(poly_det, poly_gt, invalid_ret=None, return_poly=False): + """Calculate the union area between two polygon. + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + invalid_ret (None|float|int): The return value when the invalid polygon + exists. If it is not specified, the function allows the computation + to proceed with invalid polygons by cleaning the their + self-touching or self-crossing parts. + return_poly (bool): Whether to return the polygon of the intersection + area. + + Returns: + union_area (float): The union area between two polygons. + poly_obj (Polygon|MultiPolygon, optional): The Polygon or MultiPolygon + object of the union of the inputs. The type of object depends on + whether they intersect or not. Set as `None` if the input is + invalid. + """ + assert isinstance(poly_det, plg) + assert isinstance(poly_gt, plg) + assert invalid_ret is None or isinstance(invalid_ret, float) or \ + isinstance(invalid_ret, int) + + if invalid_ret is None: + poly_det = poly_make_valid(poly_det) + poly_gt = poly_make_valid(poly_gt) + + poly_obj = None + area = invalid_ret + if poly_det.is_valid and poly_gt.is_valid: + poly_obj = poly_det.union(poly_gt) + area = poly_obj.area + return (area, poly_obj) if return_poly else area + + +def boundary_iou(src, target, zero_division=0): + """Calculate the IOU between two boundaries. + + Args: + src (list): Source boundary. + target (list): Target boundary. + zero_division (int|float): The return value when invalid + boundary exists. + + Returns: + iou (float): The iou between two boundaries. + """ + assert utils.valid_boundary(src, False) + assert utils.valid_boundary(target, False) + src_poly = points2polygon(src) + target_poly = points2polygon(target) + + return poly_iou(src_poly, target_poly, zero_division=zero_division) + + +def poly_iou(poly_det, poly_gt, zero_division=0): + """Calculate the IOU between two polygons. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + zero_division (int|float): The return value when invalid + polygon exists. + + Returns: + iou (float): The IOU between two polygons. 
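+
+    Example:
+        An illustrative sketch with two made-up axis-aligned squares whose
+        overlap is half of either square:
+
+        >>> a = points2polygon([0, 0, 2, 0, 2, 2, 0, 2])
+        >>> b = points2polygon([1, 0, 3, 0, 3, 2, 1, 2])
+        >>> iou = poly_iou(a, b)  # intersection 2.0 / union 6.0 = 1 / 3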
+ """ + assert isinstance(poly_det, plg) + assert isinstance(poly_gt, plg) + area_inters = poly_intersection(poly_det, poly_gt) + area_union = poly_union(poly_det, poly_gt) + return area_inters / area_union if area_union != 0 else zero_division + + +def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr, + precision_thr): + """One-to-One match gt and det with icdar2013 standards. + + Args: + gt_id (int): The ground truth id index. + det_id (int): The detection result id index. + recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the recall ratio of gt i to det j. + precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the precision ratio of gt i to det j. + recall_thr (float): The recall threshold. + precision_thr (float): The precision threshold. + Returns: + True|False: Whether the gt and det are matched. + """ + assert isinstance(gt_id, int) + assert isinstance(det_id, int) + assert isinstance(recall_mat, np.ndarray) + assert isinstance(precision_mat, np.ndarray) + assert 0 <= recall_thr <= 1 + assert 0 <= precision_thr <= 1 + + cont = 0 + for i in range(recall_mat.shape[1]): + if recall_mat[gt_id, + i] > recall_thr and precision_mat[gt_id, + i] > precision_thr: + cont += 1 + if cont != 1: + return False + + cont = 0 + for i in range(recall_mat.shape[0]): + if recall_mat[i, det_id] > recall_thr and precision_mat[ + i, det_id] > precision_thr: + cont += 1 + if cont != 1: + return False + + if recall_mat[gt_id, det_id] > recall_thr and precision_mat[ + gt_id, det_id] > precision_thr: + return True + + return False + + +def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_match_flag, det_match_flag, + det_ignored_index): + """One-to-Many match gt and detections with icdar2013 standards. + + Args: + gt_id (int): gt index. + recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the recall ratio of gt i to det j. + precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the precision ratio of gt i to det j. + recall_thr (float): The recall threshold. + precision_thr (float): The precision threshold. + gt_match_flag (ndarray): An array indicates each gt matched already. + det_match_flag (ndarray): An array indicates each box has been + matched already or not. + det_ignored_index (list): A list indicates each detection box can be + ignored or not. + + Returns: + tuple (True|False, list): The first indicates the gt is matched or not; + the second is the matched detection ids. + """ + assert isinstance(gt_id, int) + assert isinstance(recall_mat, np.ndarray) + assert isinstance(precision_mat, np.ndarray) + assert 0 <= recall_thr <= 1 + assert 0 <= precision_thr <= 1 + + assert isinstance(gt_match_flag, list) + assert isinstance(det_match_flag, list) + assert isinstance(det_ignored_index, list) + + many_sum = 0. + det_ids = [] + for det_id in range(recall_mat.shape[1]): + if gt_match_flag[gt_id] == 0 and det_match_flag[ + det_id] == 0 and det_id not in det_ignored_index: + if precision_mat[gt_id, det_id] >= precision_thr: + many_sum += recall_mat[gt_id, det_id] + det_ids.append(det_id) + if many_sum >= recall_thr: + return True, det_ids + return False, [] + + +def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_match_flag, det_match_flag, + gt_ignored_index): + """Many-to-One match gt and detections with icdar2013 standards. + + Args: + det_id (int): Detection index. 
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the recall ratio of gt i to det j. + precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the precision ratio of gt i to det j. + recall_thr (float): The recall threshold. + precision_thr (float): The precision threshold. + gt_match_flag (ndarray): An array indicates each gt has been matched + already. + det_match_flag (ndarray): An array indicates each detection box has + been matched already or not. + gt_ignored_index (list): A list indicates each gt box can be ignored + or not. + + Returns: + tuple (True|False, list): The first indicates the detection is matched + or not; the second is the matched gt ids. + """ + assert isinstance(det_id, int) + assert isinstance(recall_mat, np.ndarray) + assert isinstance(precision_mat, np.ndarray) + assert 0 <= recall_thr <= 1 + assert 0 <= precision_thr <= 1 + + assert isinstance(gt_match_flag, list) + assert isinstance(det_match_flag, list) + assert isinstance(gt_ignored_index, list) + many_sum = 0. + gt_ids = [] + for gt_id in range(recall_mat.shape[0]): + if gt_match_flag[gt_id] == 0 and det_match_flag[ + det_id] == 0 and gt_id not in gt_ignored_index: + if recall_mat[gt_id, det_id] >= recall_thr: + many_sum += precision_mat[gt_id, det_id] + gt_ids.append(gt_id) + if many_sum >= precision_thr: + return True, gt_ids + return False, [] + + +def points_center(points): + + assert isinstance(points, np.ndarray) + assert points.size % 2 == 0 + + points = points.reshape([-1, 2]) + return np.mean(points, axis=0) + + +def point_distance(p1, p2): + assert isinstance(p1, np.ndarray) + assert isinstance(p2, np.ndarray) + + assert p1.size == 2 + assert p2.size == 2 + + dist = np.square(p2 - p1) + dist = np.sum(dist) + dist = np.sqrt(dist) + return dist + + +def box_center_distance(b1, b2): + assert isinstance(b1, np.ndarray) + assert isinstance(b2, np.ndarray) + return point_distance(points_center(b1), points_center(b2)) + + +def box_diag(box): + assert isinstance(box, np.ndarray) + assert box.size == 8 + + return point_distance(box[0:2], box[4:6]) + + +def filter_2dlist_result(results, scores, score_thr): + """Find out detected results whose score > score_thr. + + Args: + results (list[list[float]]): The result list. + score (list): The score list. + score_thr (float): The score threshold. + Returns: + valid_results (list[list[float]]): The valid results. + valid_score (list[float]): The scores which correspond to the valid + results. + """ + assert isinstance(results, list) + assert len(results) == len(scores) + assert isinstance(score_thr, float) + assert 0 <= score_thr <= 1 + + inds = np.array(scores) > score_thr + valid_results = [results[idx] for idx in np.where(inds)[0].tolist()] + valid_scores = [scores[idx] for idx in np.where(inds)[0].tolist()] + return valid_results, valid_scores + + +def filter_result(results, scores, score_thr): + """Find out detected results whose score > score_thr. + + Args: + results (ndarray): The results matrix of shape (n, k). + score (ndarray): The score vector of shape (n,). + score_thr (float): The score threshold. + Returns: + valid_results (ndarray): The valid results of shape (m,k) with m<=n. + valid_score (ndarray): The scores which correspond to the + valid results. 
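+
+    Example:
+        A minimal sketch with two made-up boxes; only rows whose score is
+        strictly greater than ``score_thr`` are kept:
+
+        >>> boxes = np.array([[0, 0, 1, 1], [2, 2, 3, 3]])
+        >>> scores = np.array([0.9, 0.2])
+        >>> kept, kept_scores = filter_result(boxes, scores, 0.5)
+        >>> # kept == boxes[:1], kept_scores == scores[:1]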
+ """ + assert results.ndim == 2 + assert scores.shape[0] == results.shape[0] + assert isinstance(score_thr, float) + assert 0 <= score_thr <= 1 + + inds = scores > score_thr + valid_results = results[inds, :] + valid_scores = scores[inds] + return valid_results, valid_scores + + +def select_top_boundary(boundaries_list, scores_list, score_thr): + """Select poly boundaries with scores >= score_thr. + + Args: + boundaries_list (list[list[list[float]]]): List of boundaries. + The 1st, 2nd, and 3rd indices are for image, text and + vertice, respectively. + scores_list (list(list[float])): List of lists of scores. + score_thr (float): The score threshold to filter out bboxes. + + Returns: + selected_bboxes (list[list[list[float]]]): List of boundaries. + The 1st, 2nd, and 3rd indices are for image, text and vertice, + respectively. + """ + assert isinstance(boundaries_list, list) + assert isinstance(scores_list, list) + assert isinstance(score_thr, float) + assert len(boundaries_list) == len(scores_list) + assert 0 <= score_thr <= 1 + + selected_boundaries = [] + for boundary, scores in zip(boundaries_list, scores_list): + if len(scores) > 0: + assert len(scores) == len(boundary) + inds = [ + iter for iter in range(len(scores)) + if scores[iter] >= score_thr + ] + selected_boundaries.append([boundary[i] for i in inds]) + else: + selected_boundaries.append(boundary) + return selected_boundaries + + +def select_bboxes_via_score(bboxes_list, scores_list, score_thr): + """Select bboxes with scores >= score_thr. + + Args: + bboxes_list (list[ndarray]): List of bboxes. Each element is ndarray of + shape (n,8) + scores_list (list(list[float])): List of lists of scores. + score_thr (float): The score threshold to filter out bboxes. + + Returns: + selected_bboxes (list[ndarray]): List of bboxes. Each element is + ndarray of shape (m,8) with m<=n. + """ + assert isinstance(bboxes_list, list) + assert isinstance(scores_list, list) + assert isinstance(score_thr, float) + assert len(bboxes_list) == len(scores_list) + assert 0 <= score_thr <= 1 + + selected_bboxes = [] + for bboxes, scores in zip(bboxes_list, scores_list): + if len(scores) > 0: + assert len(scores) == bboxes.shape[0] + inds = [ + iter for iter in range(len(scores)) + if scores[iter] >= score_thr + ] + selected_bboxes.append(bboxes[inds, :]) + else: + selected_bboxes.append(bboxes) + return selected_bboxes diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/mask.py b/cv/ocr/dbnet/pytorch/dbnet/core/mask.py new file mode 100755 index 00000000..dc575362 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/mask.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + +import dbnet.utils as utils + + +def points2boundary(points, text_repr_type, text_score=None, min_width=-1): + """Convert a text mask represented by point coordinates sequence into a + text boundary. + + Args: + points (ndarray): Mask index of size (n, 2). + text_repr_type (str): Text instance encoding type + ('quad' for quadrangle or 'poly' for polygon). + text_score (float): Text score. + + Returns: + boundary (list[float]): The text boundary point coordinates (x, y) + list. Return None if no text boundary found. 
+ """ + assert isinstance(points, np.ndarray) + assert points.shape[1] == 2 + assert text_repr_type in ['quad', 'poly'] + assert text_score is None or 0 <= text_score <= 1 + + if text_repr_type == 'quad': + rect = cv2.minAreaRect(points) + vertices = cv2.boxPoints(rect) + boundary = [] + if min(rect[1]) > min_width: + boundary = [p for p in vertices.flatten().tolist()] + + elif text_repr_type == 'poly': + + height = np.max(points[:, 1]) + 10 + width = np.max(points[:, 0]) + 10 + + mask = np.zeros((height, width), np.uint8) + mask[points[:, 1], points[:, 0]] = 255 + + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE) + boundary = list(contours[0].flatten().tolist()) + + if text_score is not None: + boundary = boundary + [text_score] + if len(boundary) < 8: + return None + + return boundary + + +def seg2boundary(seg, text_repr_type, text_score=None): + """Convert a segmentation mask to a text boundary. + + Args: + seg (ndarray): The segmentation mask. + text_repr_type (str): Text instance encoding type + ('quad' for quadrangle or 'poly' for polygon). + text_score (float): The text score. + + Returns: + boundary (list): The text boundary. Return None if no text found. + """ + assert isinstance(seg, np.ndarray) + assert isinstance(text_repr_type, str) + assert text_score is None or 0 <= text_score <= 1 + + points = np.where(seg) + # x, y order + points = np.concatenate([points[1], points[0]]).reshape(2, -1).transpose() + boundary = None + if len(points) != 0: + boundary = points2boundary(points, text_repr_type, text_score) + + return boundary + + +def extract_boundary(result): + """Extract boundaries and their scores from result. + + Args: + result (dict): The detection result with the key 'boundary_result' + of one image. + + Returns: + boundaries_with_scores (list[list[float]]): The boundary and score + list. + boundaries (list[list[float]]): The boundary list. + scores (list[float]): The boundary score list. + """ + assert isinstance(result, dict) + assert 'boundary_result' in result.keys() + + boundaries_with_scores = result['boundary_result'] + assert utils.is_2dlist(boundaries_with_scores) + + boundaries = [b[:-1] for b in boundaries_with_scores] + scores = [b[-1] for b in boundaries_with_scores] + + return (boundaries_with_scores, boundaries, scores) diff --git a/cv/ocr/dbnet/pytorch/dbnet/core/visualize.py b/cv/ocr/dbnet/pytorch/dbnet/core/visualize.py new file mode 100755 index 00000000..2934255b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/core/visualize.py @@ -0,0 +1,888 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import os +import shutil +import urllib +import warnings + +import cv2 +import dbnet_cv +import numpy as np +import torch +from matplotlib import pyplot as plt +from PIL import Image, ImageDraw, ImageFont + +import dbnet.utils as utils + + +def overlay_mask_img(img, mask): + """Draw mask boundaries on image for visualization. + + Args: + img (ndarray): The input image. + mask (ndarray): The instance mask. + + Returns: + img (ndarray): The output image with instance boundaries on it. + """ + assert isinstance(img, np.ndarray) + assert isinstance(mask, np.ndarray) + + contours, _ = cv2.findContours( + mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + cv2.drawContours(img, contours, -1, (0, 255, 0), 1) + + return img + + +def show_feature(features, names, to_uint8, out_file=None): + """Visualize a list of feature maps. + + Args: + features (list(ndarray)): The feature map list. 
+ names (list(str)): The visualized title list. + to_uint8 (list(1|0)): The list indicating whether to convent + feature maps to uint8. + out_file (str): The output file name. If set to None, + the output image will be shown without saving. + """ + assert utils.is_type_list(features, np.ndarray) + assert utils.is_type_list(names, str) + assert utils.is_type_list(to_uint8, int) + assert utils.is_none_or_type(out_file, str) + assert utils.equal_len(features, names, to_uint8) + + num = len(features) + row = col = math.ceil(math.sqrt(num)) + + for i, (f, n) in enumerate(zip(features, names)): + plt.subplot(row, col, i + 1) + plt.title(n) + if to_uint8[i]: + f = f.astype(np.uint8) + plt.imshow(f) + if out_file is None: + plt.show() + else: + plt.savefig(out_file) + + +def show_img_boundary(img, boundary): + """Show image and instance boundaires. + + Args: + img (ndarray): The input image. + boundary (list[float or int]): The input boundary. + """ + assert isinstance(img, np.ndarray) + assert utils.is_type_list(boundary, (int, float)) + + cv2.polylines( + img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)], + True, + color=(0, 255, 0), + thickness=1) + plt.imshow(img) + plt.show() + + +def show_pred_gt(preds, + gts, + show=False, + win_name='', + wait_time=0, + out_file=None): + """Show detection and ground truth for one image. + + Args: + preds (list[list[float]]): The detection boundary list. + gts (list[list[float]]): The ground truth boundary list. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): The value of waitKey param. + out_file (str): The filename of the output. + """ + assert utils.is_2dlist(preds) + assert utils.is_2dlist(gts) + assert isinstance(show, bool) + assert isinstance(win_name, str) + assert isinstance(wait_time, int) + assert utils.is_none_or_type(out_file, str) + + p_xy = [p for boundary in preds for p in boundary] + gt_xy = [g for gt in gts for g in gt] + + max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0) + + width = int(max_xy[0]) + 100 + height = int(max_xy[1]) + 100 + + img = np.ones((height, width, 3), np.int8) * 255 + pred_color = dbnet_cv.color_val('red') + gt_color = dbnet_cv.color_val('blue') + thickness = 1 + + for boundary in preds: + cv2.polylines( + img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)], + True, + color=pred_color, + thickness=thickness) + for gt in gts: + cv2.polylines( + img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)], + True, + color=gt_color, + thickness=thickness) + if show: + dbnet_cv.imshow(img, win_name, wait_time) + if out_file is not None: + dbnet_cv.imwrite(img, out_file) + + return img + + +def imshow_pred_boundary(img, + boundaries_with_scores, + labels, + score_thr=0, + boundary_color='blue', + text_color='blue', + thickness=1, + font_scale=0.5, + show=True, + win_name='', + wait_time=0, + out_file=None, + show_score=False): + """Draw boundaries and class labels (with scores) on an image. + + Args: + img (str or ndarray): The image to be displayed. + boundaries_with_scores (list[list[float]]): Boundaries with scores. + labels (list[int]): Labels of boundaries. + score_thr (float): Minimum score of boundaries to be shown. + boundary_color (str or tuple or :obj:`Color`): Color of boundaries. + text_color (str or tuple or :obj:`Color`): Color of texts. + thickness (int): Thickness of lines. + font_scale (float): Font scales of texts. + show (bool): Whether to show the image. + win_name (str): The window name. 
+ wait_time (int): Value of waitKey param. + out_file (str or None): The filename of the output. + show_score (bool): Whether to show text instance score. + """ + assert isinstance(img, (str, np.ndarray)) + assert utils.is_2dlist(boundaries_with_scores) + assert utils.is_type_list(labels, int) + assert utils.equal_len(boundaries_with_scores, labels) + if len(boundaries_with_scores) == 0: + warnings.warn('0 text found in ' + out_file) + return None + + utils.valid_boundary(boundaries_with_scores[0]) + img = dbnet_cv.imread(img) + + scores = np.array([b[-1] for b in boundaries_with_scores]) + inds = scores > score_thr + boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]] + scores = [scores[i] for i in np.where(inds)[0]] + labels = [labels[i] for i in np.where(inds)[0]] + + boundary_color = dbnet_cv.color_val(boundary_color) + text_color = dbnet_cv.color_val(text_color) + font_scale = 0.5 + + for boundary, score in zip(boundaries, scores): + boundary_int = np.array(boundary).astype(np.int32) + + cv2.polylines( + img, [boundary_int.reshape(-1, 1, 2)], + True, + color=boundary_color, + thickness=thickness) + + if show_score: + label_text = f'{score:.02f}' + cv2.putText(img, label_text, + (boundary_int[0], boundary_int[1] - 2), + cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) + if show: + dbnet_cv.imshow(img, win_name, wait_time) + if out_file is not None: + dbnet_cv.imwrite(img, out_file) + + return img + + +def imshow_text_char_boundary(img, + text_quads, + boundaries, + char_quads, + chars, + show=False, + thickness=1, + font_scale=0.5, + win_name='', + wait_time=-1, + out_file=None): + """Draw text boxes and char boxes on img. + + Args: + img (str or ndarray): The img to be displayed. + text_quads (list[list[int|float]]): The text boxes. + boundaries (list[list[int|float]]): The boundary list. + char_quads (list[list[list[int|float]]]): A 2d list of char boxes. + char_quads[i] is for the ith text, and char_quads[i][j] is the jth + char of the ith text. + chars (list[list[char]]). The string for each text box. + thickness (int): Thickness of lines. + font_scale (float): Font scales of texts. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + out_file (str or None): The filename of the output. 
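+
+    Returns:
+        np.ndarray: The image with text boxes, boundaries, char boxes and
+        recognized strings drawn on it.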
+ """ + assert isinstance(img, (np.ndarray, str)) + assert utils.is_2dlist(text_quads) + assert utils.is_2dlist(boundaries) + assert utils.is_3dlist(char_quads) + assert utils.is_2dlist(chars) + assert utils.equal_len(text_quads, char_quads, boundaries) + + img = dbnet_cv.imread(img) + char_color = [dbnet_cv.color_val('blue'), dbnet_cv.color_val('green')] + text_color = dbnet_cv.color_val('red') + text_inx = 0 + for text_box, boundary, char_box, txt in zip(text_quads, boundaries, + char_quads, chars): + text_box = np.array(text_box) + boundary = np.array(boundary) + + text_box = text_box.reshape(-1, 2).astype(np.int32) + cv2.polylines( + img, [text_box.reshape(-1, 1, 2)], + True, + color=text_color, + thickness=thickness) + if boundary.shape[0] > 0: + cv2.polylines( + img, [boundary.reshape(-1, 1, 2)], + True, + color=text_color, + thickness=thickness) + + for b in char_box: + b = np.array(b) + c = char_color[text_inx % 2] + b = b.astype(np.int32) + cv2.polylines( + img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness) + + label_text = ''.join(txt) + cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2), + cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) + text_inx = text_inx + 1 + + if show: + dbnet_cv.imshow(img, win_name, wait_time) + if out_file is not None: + dbnet_cv.imwrite(img, out_file) + + return img + + +def tile_image(images): + """Combined multiple images to one vertically. + + Args: + images (list[np.ndarray]): Images to be combined. + """ + assert isinstance(images, list) + assert len(images) > 0 + + for i, _ in enumerate(images): + if len(images[i].shape) == 2: + images[i] = cv2.cvtColor(images[i], cv2.COLOR_GRAY2BGR) + + widths = [img.shape[1] for img in images] + heights = [img.shape[0] for img in images] + h, w = sum(heights), max(widths) + vis_img = np.zeros((h, w, 3), dtype=np.uint8) + + offset_y = 0 + for image in images: + img_h, img_w = image.shape[:2] + vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image + offset_y += img_h + + return vis_img + + +def imshow_text_label(img, + pred_label, + gt_label, + show=False, + win_name='', + wait_time=-1, + out_file=None): + """Draw predicted texts and ground truth texts on images. + + Args: + img (str or np.ndarray): Image filename or loaded image. + pred_label (str): Predicted texts. + gt_label (str): Ground truth texts. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + out_file (str): The filename of the output. 
+ """ + assert isinstance(img, (np.ndarray, str)) + assert isinstance(pred_label, str) + assert isinstance(gt_label, str) + assert isinstance(show, bool) + assert isinstance(win_name, str) + assert isinstance(wait_time, int) + + img = dbnet_cv.imread(img) + + src_h, src_w = img.shape[:2] + resize_height = 64 + resize_width = int(1.0 * src_w / src_h * resize_height) + img = cv2.resize(img, (resize_width, resize_height)) + h, w = img.shape[:2] + + if is_contain_chinese(pred_label): + pred_img = draw_texts_by_pil(img, [pred_label], None) + else: + pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255 + cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, + 0.9, (0, 0, 255), 2) + images = [pred_img, img] + + if gt_label != '': + if is_contain_chinese(gt_label): + gt_img = draw_texts_by_pil(img, [gt_label], None) + else: + gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255 + cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, + 0.9, (255, 0, 0), 2) + images.append(gt_img) + + img = tile_image(images) + + if show: + dbnet_cv.imshow(img, win_name, wait_time) + if out_file is not None: + dbnet_cv.imwrite(img, out_file) + + return img + + +def imshow_node(img, + result, + boxes, + idx_to_cls={}, + show=False, + win_name='', + wait_time=-1, + out_file=None): + + img = dbnet_cv.imread(img) + h, w = img.shape[:2] + + max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1) + node_pred_label = max_idx.numpy().tolist() + node_pred_score = max_value.numpy().tolist() + + texts, text_boxes = [], [] + for i, box in enumerate(boxes): + new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], + [box[0], box[3]]] + Pts = np.array([new_box], np.int32) + cv2.polylines( + img, [Pts.reshape((-1, 1, 2))], + True, + color=(255, 255, 0), + thickness=1) + x_min = int(min([point[0] for point in new_box])) + y_min = int(min([point[1] for point in new_box])) + + # text + pred_label = str(node_pred_label[i]) + if pred_label in idx_to_cls: + pred_label = idx_to_cls[pred_label] + pred_score = '{:.2f}'.format(node_pred_score[i]) + text = pred_label + '(' + pred_score + ')' + texts.append(text) + + # text box + font_size = int( + min( + abs(new_box[3][1] - new_box[0][1]), + abs(new_box[1][0] - new_box[0][0]))) + char_num = len(text) + text_box = [ + x_min * 2, y_min, x_min * 2 + font_size * char_num, y_min, + x_min * 2 + font_size * char_num, y_min + font_size, x_min * 2, + y_min + font_size + ] + text_boxes.append(text_box) + + pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 + pred_img = draw_texts_by_pil( + pred_img, texts, text_boxes, draw_box=False, on_ori_img=True) + + vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 + vis_img[:, :w] = img + vis_img[:, w:] = pred_img + + if show: + dbnet_cv.imshow(vis_img, win_name, wait_time) + if out_file is not None: + dbnet_cv.imwrite(vis_img, out_file) + + return vis_img + + +def gen_color(): + """Generate BGR color schemes.""" + color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249), + (123, 151, 138), (187, 200, 178), (148, 137, 69), + (169, 200, 200), (155, 175, 131), (154, 194, 182), + (178, 190, 137), (140, 211, 222), (83, 156, 222)] + return color_list + + +def draw_polygons(img, polys): + """Draw polygons on image. + + Args: + img (np.ndarray): The original image. + polys (list[list[float]]): Detected polygons. + Return: + out_img (np.ndarray): Visualized image. 
+ """ + dst_img = img.copy() + color_list = gen_color() + out_img = dst_img + for idx, poly in enumerate(polys): + poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32) + cv2.drawContours( + img, + np.array([poly]), + -1, + color_list[idx % len(color_list)], + thickness=cv2.FILLED) + out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0) + return out_img + + +def get_optimal_font_scale(text, width): + """Get optimal font scale for cv2.putText. + + Args: + text (str): Text in one box. + width (int): The box width. + """ + for scale in reversed(range(0, 60, 1)): + textSize = cv2.getTextSize( + text, + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=scale / 10, + thickness=1) + new_width = textSize[0][0] + if new_width <= width: + return scale / 10 + return 1 + + +def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False): + """Draw boxes and texts on empty img. + + Args: + img (np.ndarray): The original image. + texts (list[str]): Recognized texts. + boxes (list[list[float]]): Detected bounding boxes. + draw_box (bool): Whether draw box or not. If False, draw text only. + on_ori_img (bool): If True, draw box and text on input image, + else, on a new empty image. + Return: + out_img (np.ndarray): Visualized image. + """ + color_list = gen_color() + h, w = img.shape[:2] + if boxes is None: + boxes = [[0, 0, w, 0, w, h, 0, h]] + assert len(texts) == len(boxes) + + if on_ori_img: + out_img = img + else: + out_img = np.ones((h, w, 3), dtype=np.uint8) * 255 + for idx, (box, text) in enumerate(zip(boxes, texts)): + if draw_box: + new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])] + Pts = np.array([new_box], np.int32) + cv2.polylines( + out_img, [Pts.reshape((-1, 1, 2))], + True, + color=color_list[idx % len(color_list)], + thickness=1) + min_x = int(min(box[0::2])) + max_y = int( + np.mean(np.array(box[1::2])) + 0.2 * + (max(box[1::2]) - min(box[1::2]))) + font_scale = get_optimal_font_scale( + text, int(max(box[0::2]) - min(box[0::2]))) + cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX, + font_scale, (0, 0, 0), 1) + + return out_img + + +def draw_texts_by_pil(img, + texts, + boxes=None, + draw_box=True, + on_ori_img=False, + font_size=None, + fill_color=None, + draw_pos=None, + return_text_size=False): + """Draw boxes and texts on empty image, especially for Chinese. + + Args: + img (np.ndarray): The original image. + texts (list[str]): Recognized texts. + boxes (list[list[float]]): Detected bounding boxes. + draw_box (bool): Whether draw box or not. If False, draw text only. + on_ori_img (bool): If True, draw box and text on input image, + else on a new empty image. + font_size (int, optional): Size to create a font object for a font. + fill_color (tuple(int), optional): Fill color for text. + draw_pos (list[tuple(int)], optional): Start point to draw each text. + return_text_size (bool): If True, return the list of text size. + + Returns: + (np.ndarray, list[tuple]) or np.ndarray: Return a tuple + ``(out_img, text_sizes)``, where ``out_img`` is the output image + with texts drawn on it and ``text_sizes`` are the size of drawing + texts. If ``return_text_size`` is False, only the output image will be + returned. 
+ """ + + color_list = gen_color() + h, w = img.shape[:2] + if boxes is None: + boxes = [[0, 0, w, 0, w, h, 0, h]] + if draw_pos is None: + draw_pos = [None for _ in texts] + assert len(boxes) == len(texts) == len(draw_pos) + + if fill_color is None: + fill_color = (0, 0, 0) + + if on_ori_img: + out_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + else: + out_img = Image.new('RGB', (w, h), color=(255, 255, 255)) + out_draw = ImageDraw.Draw(out_img) + + text_sizes = [] + for idx, (box, text, ori_point) in enumerate(zip(boxes, texts, draw_pos)): + if len(text) == 0: + continue + min_x, max_x = min(box[0::2]), max(box[0::2]) + min_y, max_y = min(box[1::2]), max(box[1::2]) + color = tuple(list(color_list[idx % len(color_list)])[::-1]) + if draw_box: + out_draw.line(box, fill=color, width=1) + dirname, _ = os.path.split(os.path.abspath(__file__)) + font_path = os.path.join(dirname, 'font.TTF') + if not os.path.exists(font_path): + url = ('https://download.openmmlab.com/mmocr/data/font.TTF') + print(f'Downloading {url} ...') + local_filename, _ = urllib.request.urlretrieve(url) + shutil.move(local_filename, font_path) + tmp_font_size = font_size + if tmp_font_size is None: + box_width = max(max_x - min_x, max_y - min_y) + tmp_font_size = int(0.9 * box_width / len(text)) + fnt = ImageFont.truetype(font_path, tmp_font_size) + if ori_point is None: + ori_point = (min_x + 1, min_y + 1) + out_draw.text(ori_point, text, font=fnt, fill=fill_color) + text_sizes.append(fnt.getsize(text)) + + del out_draw + + out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR) + + if return_text_size: + return out_img, text_sizes + + return out_img + + +def is_contain_chinese(check_str): + """Check whether string contains Chinese or not. + + Args: + check_str (str): String to be checked. + + Return True if contains Chinese, else False. + """ + for ch in check_str: + if u'\u4e00' <= ch <= u'\u9fff': + return True + return False + + +def det_recog_show_result(img, end2end_res, out_file=None): + """Draw `result`(boxes and texts) on `img`. + + Args: + img (str or np.ndarray): The image to be displayed. + end2end_res (dict): Text detect and recognize results. + out_file (str): Image path where the visualized image should be saved. + Return: + out_img (np.ndarray): Visualized image. + """ + img = dbnet_cv.imread(img) + boxes, texts = [], [] + for res in end2end_res['result']: + boxes.append(res['box']) + texts.append(res['text']) + box_vis_img = draw_polygons(img, boxes) + + if is_contain_chinese(''.join(texts)): + text_vis_img = draw_texts_by_pil(img, texts, boxes) + else: + text_vis_img = draw_texts(img, texts, boxes) + + h, w = img.shape[:2] + out_img = np.ones((h, w * 2, 3), dtype=np.uint8) + out_img[:, :w, :] = box_vis_img + out_img[:, w:, :] = text_vis_img + + if out_file: + dbnet_cv.imwrite(out_img, out_file) + + return out_img + + +def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5): + """Draw text and their relationship on empty images. + + Args: + img (np.ndarray): The original image. + result (dict): The result of model forward_test, including: + - img_metas (list[dict]): List of meta information dictionary. + - nodes (Tensor): Node prediction with size: + number_node * node_classes. + - edges (Tensor): Edge prediction with size: number_edge * 2. + edge_thresh (float): Score threshold for edge classification. + keynode_thresh (float): Score threshold for node + (``key``) classification. + + Returns: + np.ndarray: The image with key, value and relation drawn on it. 
+ """ + + h, w = img.shape[:2] + + vis_area_width = w // 3 * 2 + vis_area_height = h + dist_key_to_value = vis_area_width // 2 + dist_pair_to_pair = 30 + + bbox_x1 = dist_pair_to_pair + bbox_y1 = 0 + + new_w = vis_area_width + new_h = vis_area_height + pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255 + + nodes = result['nodes'].detach().cpu() + texts = result['img_metas'][0]['ori_texts'] + num_nodes = result['nodes'].size(0) + edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes) + + # (i, j) will be a valid pair + # either edge_score(node_i->node_j) > edge_thresh + # or edge_score(node_j->node_i) > edge_thresh + pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True) + pairs = (pairs[0].numpy().tolist(), pairs[1].numpy().tolist()) + + # 1. "for n1, n2 in zip(*pairs) if n1 < n2": + # Only (n1, n2) will be included if n1 < n2 but not (n2, n1), to + # avoid duplication. + # 2. "(n1, n2) if nodes[n1, 1] > nodes[n1, 2]": + # nodes[n1, 1] is the score that this node is predicted as key, + # nodes[n1, 2] is the score that this node is predicted as value. + # If nodes[n1, 1] > nodes[n1, 2], n1 will be the index of key, + # so that n2 will be the index of value. + result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1) + for n1, n2 in zip(*pairs) if n1 < n2] + + result_pairs.sort() + result_pairs_score = [ + torch.max(edges[n1, n2], edges[n2, n1]) for n1, n2 in result_pairs + ] + + key_current_idx = -1 + pos_current = (-1, -1) + newline_flag = False + + key_font_size = 15 + value_font_size = 15 + key_font_color = (0, 0, 0) + value_font_color = (0, 0, 255) + arrow_color = (0, 0, 255) + score_color = (0, 255, 0) + for pair, pair_score in zip(result_pairs, result_pairs_score): + key_idx = pair[0] + if nodes[key_idx, 1] < keynode_thresh: + continue + if key_idx != key_current_idx: + # move y-coords down for a new key + bbox_y1 += 10 + # enlarge blank area to show key-value info + if newline_flag: + bbox_x1 += vis_area_width + tmp_img = np.ones( + (new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255 + tmp_img[:new_h, :new_w] = pred_edge_img + pred_edge_img = tmp_img + new_w += vis_area_width + newline_flag = False + bbox_y1 = 10 + key_text = texts[key_idx] + key_pos = (bbox_x1, bbox_y1) + value_idx = pair[1] + value_text = texts[value_idx] + value_pos = (bbox_x1 + dist_key_to_value, bbox_y1) + if key_idx != key_current_idx: + # draw text for a new key + key_current_idx = key_idx + pred_edge_img, text_sizes = draw_texts_by_pil( + pred_edge_img, [key_text], + draw_box=False, + on_ori_img=True, + font_size=key_font_size, + fill_color=key_font_color, + draw_pos=[key_pos], + return_text_size=True) + pos_right_bottom = (key_pos[0] + text_sizes[0][0], + key_pos[1] + text_sizes[0][1]) + pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10) + pred_edge_img = cv2.arrowedLine( + pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10), + (bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color, + 1) + score_pos_x = int( + (pos_right_bottom[0] + bbox_x1 + dist_key_to_value) / 2.) 
+ score_pos_y = bbox_y1 + 10 - int(key_font_size * 0.3) + else: + # draw arrow from key to value + if newline_flag: + tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3), + dtype=np.uint8) * 255 + tmp_img[:new_h, :new_w] = pred_edge_img + pred_edge_img = tmp_img + new_h += dist_pair_to_pair + pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current, + (bbox_x1 + dist_key_to_value - 5, + bbox_y1 + 10), arrow_color, 1) + score_pos_x = int( + (pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.) + score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.) + # draw edge score + cv2.putText(pred_edge_img, '{:.2f}'.format(pair_score), + (score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4, + score_color) + # draw text for value + pred_edge_img = draw_texts_by_pil( + pred_edge_img, [value_text], + draw_box=False, + on_ori_img=True, + font_size=value_font_size, + fill_color=value_font_color, + draw_pos=[value_pos], + return_text_size=False) + bbox_y1 += dist_pair_to_pair + if bbox_y1 + dist_pair_to_pair >= new_h: + newline_flag = True + + return pred_edge_img + + +def imshow_edge(img, + result, + boxes, + show=False, + win_name='', + wait_time=-1, + out_file=None): + """Display the prediction results of the nodes and edges of the KIE model. + + Args: + img (np.ndarray): The original image. + result (dict): The result of model forward_test, including: + - img_metas (list[dict]): List of meta information dictionary. + - nodes (Tensor): Node prediction with size: \ + number_node * node_classes. + - edges (Tensor): Edge prediction with size: number_edge * 2. + boxes (list): The text boxes corresponding to the nodes. + show (bool): Whether to show the image. Default: False. + win_name (str): The window name. Default: '' + wait_time (float): Value of waitKey param. Default: 0. + out_file (str or None): The filename to write the image. + Default: None. + + Returns: + np.ndarray: The image with key, value and relation drawn on it. + """ + img = dbnet_cv.imread(img) + h, w = img.shape[:2] + color_list = gen_color() + + for i, box in enumerate(boxes): + new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], + [box[0], box[3]]] + Pts = np.array([new_box], np.int32) + cv2.polylines( + img, [Pts.reshape((-1, 1, 2))], + True, + color=color_list[i % len(color_list)], + thickness=1) + + pred_img_h = h + pred_img_w = w + + pred_edge_img = draw_edge_result(img, result) + pred_img_h = max(pred_img_h, pred_edge_img.shape[0]) + pred_img_w += pred_edge_img.shape[1] + + vis_img = np.zeros((pred_img_h, pred_img_w, 3), dtype=np.uint8) + vis_img[:h, :w] = img + vis_img[:, w:] = 255 + + height_t, width_t = pred_edge_img.shape[:2] + vis_img[:height_t, w:(w + width_t)] = pred_edge_img + + if show: + dbnet_cv.imshow(vis_img, win_name, wait_time) + if out_file is not None: + dbnet_cv.imwrite(vis_img, out_file) + res_dic = { + 'boxes': boxes, + 'nodes': result['nodes'].detach().cpu(), + 'edges': result['edges'].detach().cpu(), + 'metas': result['img_metas'][0] + } + dbnet_cv.dump(res_dic, f'{out_file}_res.pkl') + + return vis_img diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/__init__.py new file mode 100755 index 00000000..53e86ff5 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dbnet_det.datasets.builder import DATASETS, build_dataloader, build_dataset + +from . 
import utils +from .base_dataset import BaseDataset +from .icdar_dataset import IcdarDataset +# from .kie_dataset import KIEDataset +# from .ner_dataset import NerDataset +# from .ocr_dataset import OCRDataset +# from .ocr_seg_dataset import OCRSegDataset +# from .openset_kie_dataset import OpensetKIEDataset +from .pipelines import CustomFormatBundle, DBNetTargets +# from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets +from .text_det_dataset import TextDetDataset +from .uniform_concat_dataset import UniformConcatDataset +from .utils import * # NOQA + +# __all__ = [ +# 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset', +# 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle', +# 'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets', +# 'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset' +# ] + +# __all__ += utils.__all__ diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/base_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/base_dataset.py new file mode 100755 index 00000000..879454da --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/base_dataset.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from dbnet_cv.utils import print_log +from dbnet_det.datasets.builder import DATASETS +from dbnet_det.datasets.pipelines import Compose +from torch.utils.data import Dataset + +from dbnet.datasets.builder import build_loader + + +@DATASETS.register_module() +class BaseDataset(Dataset): + """Custom dataset for text detection, text recognition, and their + downstream tasks. + + 1. The text detection annotation format is as follows: + The `annotations` field is optional for testing + (this is one line of anno_file, with line-json-str + converted to dict for visualizing only). + + .. code-block:: json + + { + "file_name": "sample.jpg", + "height": 1080, + "width": 960, + "annotations": + [ + { + "iscrowd": 0, + "category_id": 1, + "bbox": [357.0, 667.0, 804.0, 100.0], + "segmentation": [[361, 667, 710, 670, + 72, 767, 357, 763]] + } + ] + } + + 2. The two text recognition annotation formats are as follows: + The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop + augmentation during training. + + format1: sample.jpg hello + format2: sample.jpg 20 20 100 20 100 40 20 40 hello + + Args: + ann_file (str): Annotation file path. + pipeline (list[dict]): Processing pipeline. + loader (dict): Dictionary to construct loader + to load annotation infos. + img_prefix (str, optional): Image prefix to generate full + image path. + test_mode (bool, optional): If set True, try...except will + be turned off in __getitem__. + """ + + def __init__(self, + ann_file, + loader, + pipeline, + img_prefix='', + test_mode=False): + super().__init__() + self.test_mode = test_mode + self.img_prefix = img_prefix + self.ann_file = ann_file + # load annotations + loader.update(ann_file=ann_file) + self.data_infos = build_loader(loader) + # processing pipeline + self.pipeline = Compose(pipeline) + # set group flag and class, no meaning + # for text detect and recognize + self._set_group_flag() + self.CLASSES = 0 + + def __len__(self): + return len(self.data_infos) + + def _set_group_flag(self): + """Set flag.""" + self.flag = np.zeros(len(self), dtype=np.uint8) + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['img_prefix'] = self.img_prefix + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. 
+ + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. + """ + img_info = self.data_infos[index] + results = dict(img_info=img_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, img_info): + """Get testing data from pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by + pipeline. + """ + return self.prepare_train_img(img_info) + + def _log_error_index(self, index): + """Logging data info of bad index.""" + try: + data_info = self.data_infos[index] + img_prefix = self.img_prefix + print_log(f'Warning: skip broken file {data_info} ' + f'with img_prefix {img_prefix}') + except Exception as e: + print_log(f'load index {index} with error {e}') + + def _get_next_index(self, index): + """Get next index from dataset.""" + self._log_error_index(index) + index = (index + 1) % len(self) + return index + + def __getitem__(self, index): + """Get training/test data from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training/test data. + """ + if self.test_mode: + return self.prepare_test_img(index) + + while True: + try: + data = self.prepare_train_img(index) + if data is None: + raise Exception('prepared train data empty') + break + except Exception as e: + print_log(f'prepare index {index} with error {e}') + index = self._get_next_index(index) + return data + + def format_results(self, results, **kwargs): + """Placeholder to format result to dataset-specific output.""" + pass + + def evaluate(self, results, metric=None, logger=None, **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + Returns: + dict[str: float] + """ + raise NotImplementedError diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/builder.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/builder.py new file mode 100755 index 00000000..a4dcfbc7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/builder.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dbnet_cv.utils import Registry, build_from_cfg + +LOADERS = Registry('loader') +PARSERS = Registry('parser') + + +def build_loader(cfg): + """Build anno file loader.""" + return build_from_cfg(cfg, LOADERS) + + +def build_parser(cfg): + """Build anno file parser.""" + return build_from_cfg(cfg, PARSERS) diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/icdar_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/icdar_dataset.py new file mode 100755 index 00000000..f18cece8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/icdar_dataset.py @@ -0,0 +1,191 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dbnet_cv +import numpy as np +from dbnet_det.datasets.api_wrappers import COCO +from dbnet_det.datasets.builder import DATASETS +from dbnet_det.datasets.coco import CocoDataset + +import dbnet.utils as utils +from dbnet import digit_version +from dbnet.core.evaluation.hmean import eval_hmean + + +@DATASETS.register_module() +class IcdarDataset(CocoDataset): + """Dataset for text detection while ann_file in coco format. + + Args: + ann_file_backend (str): Storage backend for annotation file, + should be one in ['disk', 'petrel', 'http']. Default to 'disk'. 
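+        select_first_k (int): If positive, only load roughly the first k
+            images from the annotation file, which is handy for fast
+            debugging. Defaults to -1, i.e. keep all images.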
+ """ + CLASSES = ('text') + + def __init__(self, + ann_file, + pipeline, + classes=None, + data_root=None, + img_prefix='', + seg_prefix=None, + proposal_file=None, + test_mode=False, + filter_empty_gt=True, + select_first_k=-1, + ann_file_backend='disk'): + # select first k images for fast debugging. + self.select_first_k = select_first_k + assert ann_file_backend in ['disk', 'petrel', 'http'] + self.ann_file_backend = ann_file_backend + + super().__init__(ann_file, pipeline, classes, data_root, img_prefix, + seg_prefix, proposal_file, test_mode, filter_empty_gt) + + # Set dummy flags just to be compatible with dbnet_det + self.flag = np.zeros(len(self), dtype=np.uint8) + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + + Returns: + list[dict]: Annotation info from COCO api. + """ + if self.ann_file_backend == 'disk': + self.coco = COCO(ann_file) + else: + dbnet_cv_version = digit_version(dbnet_cv.__version__) + if dbnet_cv_version < digit_version('1.3.16'): + raise Exception('Please update dbnet_cv to 1.3.16 or higher ' + 'to enable "get_local_path" of "FileClient".') + file_client = dbnet_cv.FileClient(backend=self.ann_file_backend) + with file_client.get_local_path(ann_file) as local_path: + self.coco = COCO(local_path) + self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.img_ids = self.coco.get_img_ids() + data_infos = [] + + count = 0 + for i in self.img_ids: + info = self.coco.load_imgs([i])[0] + info['filename'] = info['file_name'] + data_infos.append(info) + count = count + 1 + if count > self.select_first_k and self.select_first_k > 0: + break + return data_infos + + def _parse_ann_info(self, img_info, ann_info): + """Parse bbox and mask annotation. + + Args: + ann_info (list[dict]): Annotation info of an image. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore, + labels, masks, masks_ignore, seg_map. "masks" and + "masks_ignore" are represented by polygon boundary + point sequences. 
+ """ + gt_bboxes = [] + gt_labels = [] + gt_bboxes_ignore = [] + gt_masks_ignore = [] + gt_masks_ann = [] + + for ann in ann_info: + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(bbox) + gt_masks_ignore.append(ann.get( + 'segmentation', None)) # to float32 for latter processing + + else: + gt_bboxes.append(bbox) + gt_labels.append(self.cat2label[ann['category_id']]) + gt_masks_ann.append(ann.get('segmentation', None)) + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + seg_map = img_info['filename'].replace('jpg', 'png') + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks_ignore=gt_masks_ignore, + masks=gt_masks_ann, + seg_map=seg_map) + + return ann + + def evaluate(self, + results, + metric='hmean-iou', + logger=None, + score_thr=None, + min_score_thr=0.3, + max_score_thr=0.9, + step=0.1, + rank_list=None, + **kwargs): + """Evaluate the hmean metric. + + Args: + results (list[dict]): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + score_thr (float): Deprecated. Please use min_score_thr instead. + min_score_thr (float): Minimum score threshold of prediction map. + max_score_thr (float): Maximum score threshold of prediction map. + step (float): The spacing between score thresholds. + rank_list (str): json file used to save eval result + of each image after ranking. + Returns: + dict[dict[str: float]]: The evaluation results. + """ + assert utils.is_type_list(results, dict) + + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['hmean-iou', 'hmean-ic13'] + metrics = set(metrics) & set(allowed_metrics) + + img_infos = [] + ann_infos = [] + for i in range(len(self)): + img_info = {'filename': self.data_infos[i]['file_name']} + img_infos.append(img_info) + ann_infos.append(self.get_ann_info(i)) + + eval_results = eval_hmean( + results, + img_infos, + ann_infos, + metrics=metrics, + score_thr=score_thr, + min_score_thr=min_score_thr, + max_score_thr=max_score_thr, + step=step, + logger=logger, + rank_list=rank_list) + + return eval_results diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/__init__.py new file mode 100755 index 00000000..ac360637 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .box_utils import sort_vertex, sort_vertex8 +from .custom_format_bundle import CustomFormatBundle +from .dbnet_transforms import EastRandomCrop, ImgAug +# from .kie_transforms import KIEFormatBundle, ResizeNoImg +from .loading import (LoadImageFromLMDB, LoadImageFromNdarray, + LoadTextAnnotations) +# from .ner_transforms import NerTransform, ToTensorNER +# from .ocr_seg_targets import OCRSegTargets +# from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR, +# OpencvToPil, PilToOpencv, RandomPaddingOCR, +# RandomRotateImageBox, ResizeOCR, ToTensorOCR) +# from .test_time_aug import MultiRotateAugOCR +from .textdet_targets import DBNetTargets +# from .textdet_targets import (DBNetTargets +# # FCENetTargets, PANetTargets, +# # TextSnakeTargets +# ) +from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper +from .transforms import (ColorJitter, PyramidRescale, RandomCropFlip, + RandomCropInstances, RandomCropPolyInstances, + RandomRotatePolyInstances, RandomRotateTextDet, + RandomScaling, ScaleAspectJitter, SquareResizePad) + +# __all__ = [ +# 'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR', +# 'ToTensorOCR', 'CustomFormatBundle', 'DBNetTargets', 'PANetTargets', +# 'ColorJitter', 'RandomCropInstances', 'RandomRotateTextDet', +# 'ScaleAspectJitter', 'MultiRotateAugOCR', 'OCRSegTargets', 'FancyPCA', +# 'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR', +# 'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil', +# 'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets', +# 'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets', +# 'RandomScaling', 'RandomCropFlip', 'NerTransform', 'ToTensorNER', +# 'ResizeNoImg', 'PyramidRescale', 'OneOfWrapper', 'RandomWrapper', +# 'TorchVisionWrapper', 'LoadImageFromLMDB' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/box_utils.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/box_utils.py new file mode 100755 index 00000000..9e887e3e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/box_utils.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +import dbnet.utils as utils + + +def sort_vertex(points_x, points_y): + """Sort box vertices in clockwise order from left-top first. + + Args: + points_x (list[float]): x of four vertices. + points_y (list[float]): y of four vertices. + Returns: + sorted_points_x (list[float]): x of sorted four vertices. + sorted_points_y (list[float]): y of sorted four vertices. 
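+
+    Example:
+        An illustrative call (the exact float values returned are omitted;
+        the four vertices come back reordered clockwise, starting from the
+        one closest to the left-top corner):
+
+        >>> xs = [1., 1., 0., 0.]
+        >>> ys = [0., 1., 1., 0.]
+        >>> sorted_xs, sorted_ys = sort_vertex(xs, ys)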
+ """ + assert utils.is_type_list(points_x, (float, int)) + assert utils.is_type_list(points_y, (float, int)) + assert len(points_x) == 4 + assert len(points_y) == 4 + vertices = np.stack((points_x, points_y), axis=-1).astype(np.float32) + vertices = _sort_vertex(vertices) + sorted_points_x = list(vertices[:, 0]) + sorted_points_y = list(vertices[:, 1]) + return sorted_points_x, sorted_points_y + + +def _sort_vertex(vertices): + assert vertices.ndim == 2 + assert vertices.shape[-1] == 2 + N = vertices.shape[0] + if N == 0: + return vertices + + center = np.mean(vertices, axis=0) + directions = vertices - center + angles = np.arctan2(directions[:, 1], directions[:, 0]) + sort_idx = np.argsort(angles) + vertices = vertices[sort_idx] + + left_top = np.min(vertices, axis=0) + dists = np.linalg.norm(left_top - vertices, axis=-1, ord=2) + lefttop_idx = np.argmin(dists) + indexes = (np.arange(N, dtype=np.int) + lefttop_idx) % N + return vertices[indexes] + + +def sort_vertex8(points): + """Sort vertex with 8 points [x1 y1 x2 y2 x3 y3 x4 y4]""" + assert len(points) == 8 + vertices = _sort_vertex(np.array(points, dtype=np.float32).reshape(-1, 2)) + sorted_box = list(vertices.flatten()) + return sorted_box diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/crop.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/crop.py new file mode 100755 index 00000000..153365f5 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/crop.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np +from shapely.geometry import LineString, Point + +import dbnet.utils as utils +from .box_utils import sort_vertex + + +def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1): + """Jitter on the coordinates of bounding box. + + Args: + points_x (list[float | int]): List of y for four vertices. + points_y (list[float | int]): List of x for four vertices. + jitter_ratio_x (float): Horizontal jitter ratio relative to the height. + jitter_ratio_y (float): Vertical jitter ratio relative to the height. + """ + assert len(points_x) == 4 + assert len(points_y) == 4 + assert isinstance(jitter_ratio_x, float) + assert isinstance(jitter_ratio_y, float) + assert 0 <= jitter_ratio_x < 1 + assert 0 <= jitter_ratio_y < 1 + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + line_list = [ + LineString([points[i], points[i + 1 if i < 3 else 0]]) + for i in range(4) + ] + + tmp_h = max(line_list[1].length, line_list[3].length) + + for i in range(4): + jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h + jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h + points_x[i] += jitter_pixel_x + points_y[i] += jitter_pixel_y + + +def warp_img(src_img, + box, + jitter_flag=False, + jitter_ratio_x=0.5, + jitter_ratio_y=0.1): + """Crop box area from image using opencv warpPerspective w/o box jitter. + + Args: + src_img (np.array): Image before cropping. + box (list[float | int]): Coordinates of quadrangle. 
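+
+    Example:
+        A minimal usage sketch; ``src_img`` is any HxWx3 image and the box
+        coordinates below are made up:
+
+        >>> box = [20., 10., 120., 12., 118., 40., 18., 38.]
+        >>> patch = warp_img(src_img, box)  # perspective-rectified crop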
+ """ + assert utils.is_type_list(box, (float, int)) + assert len(box) == 8 + + h, w = src_img.shape[:2] + points_x = [min(max(x, 0), w) for x in box[0:8:2]] + points_y = [min(max(y, 0), h) for y in box[1:9:2]] + + points_x, points_y = sort_vertex(points_x, points_y) + + if jitter_flag: + box_jitter( + points_x, + points_y, + jitter_ratio_x=jitter_ratio_x, + jitter_ratio_y=jitter_ratio_y) + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + edges = [ + LineString([points[i], points[i + 1 if i < 3 else 0]]) + for i in range(4) + ] + + pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)]) + box_width = max(edges[0].length, edges[2].length) + box_height = max(edges[1].length, edges[3].length) + + pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], + [0, box_height]]) + M = cv2.getPerspectiveTransform(pts1, pts2) + dst_img = cv2.warpPerspective(src_img, M, + (int(box_width), int(box_height))) + + return dst_img + + +def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2): + """Crop text region with their bounding box. + + Args: + src_img (np.array): The original image. + box (list[float | int]): Points of quadrangle. + long_edge_pad_ratio (float): Box pad ratio for long edge + corresponding to font size. + short_edge_pad_ratio (float): Box pad ratio for short edge + corresponding to font size. + """ + assert utils.is_type_list(box, (float, int)) + assert len(box) == 8 + assert 0. <= long_edge_pad_ratio < 1.0 + assert 0. <= short_edge_pad_ratio < 1.0 + + h, w = src_img.shape[:2] + points_x = np.clip(np.array(box[0::2]), 0, w) + points_y = np.clip(np.array(box[1::2]), 0, h) + + box_width = np.max(points_x) - np.min(points_x) + box_height = np.max(points_y) - np.min(points_y) + font_size = min(box_height, box_width) + + if box_height < box_width: + horizontal_pad = long_edge_pad_ratio * font_size + vertical_pad = short_edge_pad_ratio * font_size + else: + horizontal_pad = short_edge_pad_ratio * font_size + vertical_pad = long_edge_pad_ratio * font_size + + left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w) + top = np.clip(int(np.min(points_y) - vertical_pad), 0, h) + right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w) + bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h) + + dst_img = src_img[top:bottom, left:right] + + return dst_img diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/custom_format_bundle.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/custom_format_bundle.py new file mode 100755 index 00000000..e64938b4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/custom_format_bundle.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from dbnet_cv.parallel import DataContainer as DC +from dbnet_det.datasets.builder import PIPELINES +from dbnet_det.datasets.pipelines.formatting import DefaultFormatBundle + +from dbnet.core.visualize import overlay_mask_img, show_feature + + +@PIPELINES.register_module() +class CustomFormatBundle(DefaultFormatBundle): + """Custom formatting bundle. + + It formats common fields such as 'img' and 'proposals' as done in + DefaultFormatBundle, while other fields such as 'gt_kernels' and + 'gt_effective_region_mask' will be formatted to DC as follows: + + - gt_kernels: to DataContainer (cpu_only=True) + - gt_effective_mask: to DataContainer (cpu_only=True) + + Args: + keys (list[str]): Fields to be formatted to DC only. 
+ call_super (bool): If True, format common fields + by DefaultFormatBundle, else format fields in keys above only. + visualize (dict): If flag=True, visualize gt mask for debugging. + """ + + def __init__(self, + keys=[], + call_super=True, + visualize=dict(flag=False, boundary_key=None)): + + super().__init__() + self.visualize = visualize + self.keys = keys + self.call_super = call_super + + def __call__(self, results): + + if self.visualize['flag']: + img = results['img'].astype(np.uint8) + boundary_key = self.visualize['boundary_key'] + if boundary_key is not None: + img = overlay_mask_img(img, results[boundary_key].masks[0]) + + features = [img] + names = ['img'] + to_uint8 = [1] + + for k in results['mask_fields']: + for iter in range(len(results[k].masks)): + features.append(results[k].masks[iter]) + names.append(k + str(iter)) + to_uint8.append(0) + show_feature(features, names, to_uint8) + + if self.call_super: + results = super().__call__(results) + + for k in self.keys: + results[k] = DC(results[k], cpu_only=True) + + return results + + def __repr__(self): + return self.__class__.__name__ diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/dbnet_transforms.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/dbnet_transforms.py new file mode 100755 index 00000000..006fc12a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/dbnet_transforms.py @@ -0,0 +1,325 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import imgaug +import imgaug.augmenters as iaa +import dbnet_cv +import numpy as np +from dbnet_det.core.mask import PolygonMasks +from dbnet_det.datasets.builder import PIPELINES + + +class AugmenterBuilder: + """Build imgaug object according ImgAug argmentations.""" + + def __init__(self): + pass + + def build(self, args, root=True): + if args is None: + return None + if isinstance(args, (int, float, str)): + return args + if isinstance(args, list): + if root: + sequence = [self.build(value, root=False) for value in args] + return iaa.Sequential(sequence) + arg_list = [self.to_tuple_if_list(a) for a in args[1:]] + return getattr(iaa, args[0])(*arg_list) + if isinstance(args, dict): + if 'cls' in args: + cls = getattr(iaa, args['cls']) + return cls( + **{ + k: self.to_tuple_if_list(v) + for k, v in args.items() if not k == 'cls' + }) + else: + return { + key: self.build(value, root=False) + for key, value in args.items() + } + raise RuntimeError('unknown augmenter arg: ' + str(args)) + + def to_tuple_if_list(self, obj): + if isinstance(obj, list): + return tuple(obj) + return obj + + +@PIPELINES.register_module() +class ImgAug: + """A wrapper to use imgaug https://github.com/aleju/imgaug. + + Args: + args ([list[list|dict]]): The argumentation list. For details, please + refer to imgaug document. Take args=[['Fliplr', 0.5], + dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]] as an + example. The args horizontally flip images with probability 0.5, + followed by random rotation with angles in range [-10, 10], and + resize with an independent scale in range [0.5, 3.0] for each + side of images. + clip_invalid_polys (bool): Whether to clip invalid polygons after + transformation. False persists to the behavior in DBNet. 
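+
+    Example:
+        An illustrative pipeline entry, mirroring the ``args`` described
+        above (the concrete augmentations are only an example):
+
+        >>> imgaug_cfg = dict(
+        ...     type='ImgAug',
+        ...     args=[['Fliplr', 0.5],
+        ...           dict(cls='Affine', rotate=[-10, 10]),
+        ...           ['Resize', [0.5, 3.0]]])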
+ """ + + def __init__(self, args=None, clip_invalid_ploys=True): + self.augmenter_args = args + self.augmenter = AugmenterBuilder().build(self.augmenter_args) + self.clip_invalid_polys = clip_invalid_ploys + + def __call__(self, results): + # img is bgr + image = results['img'] + aug = None + shape = image.shape + + if self.augmenter: + aug = self.augmenter.to_deterministic() + results['img'] = aug.augment_image(image) + results['img_shape'] = results['img'].shape + results['flip'] = 'unknown' # it's unknown + results['flip_direction'] = 'unknown' # it's unknown + target_shape = results['img_shape'] + + self.may_augment_annotation(aug, shape, target_shape, results) + + return results + + def may_augment_annotation(self, aug, shape, target_shape, results): + if aug is None: + return results + + # augment polygon mask + for key in results['mask_fields']: + if self.clip_invalid_polys: + masks = self.may_augment_poly(aug, shape, results[key]) + results[key] = PolygonMasks(masks, *target_shape[:2]) + else: + masks = self.may_augment_poly_legacy(aug, shape, results[key]) + if len(masks) > 0: + results[key] = PolygonMasks(masks, *target_shape[:2]) + + # augment bbox + for key in results['bbox_fields']: + bboxes = self.may_augment_bbox(aug, shape, results[key]) + results[key] = np.zeros(0) + if len(bboxes) > 0: + results[key] = np.stack(bboxes) + + return results + + def may_augment_bbox(self, aug, ori_shape, bboxes): + imgaug_bboxes = [] + for bbox in bboxes: + x1, y1, x2, y2 = bbox + imgaug_bboxes.append( + imgaug.BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2)) + imgaug_bboxes = aug.augment_bounding_boxes([ + imgaug.BoundingBoxesOnImage(imgaug_bboxes, shape=ori_shape) + ])[0].clip_out_of_image() + + new_bboxes = [] + for box in imgaug_bboxes.bounding_boxes: + new_bboxes.append( + np.array([box.x1, box.y1, box.x2, box.y2], dtype=np.float32)) + + return new_bboxes + + def may_augment_poly(self, aug, img_shape, polys): + imgaug_polys = [] + for poly in polys: + poly = poly[0] + poly = poly.reshape(-1, 2) + imgaug_polys.append(imgaug.Polygon(poly)) + imgaug_polys = aug.augment_polygons( + [imgaug.PolygonsOnImage(imgaug_polys, + shape=img_shape)])[0].clip_out_of_image() + + new_polys = [] + for poly in imgaug_polys.polygons: + new_poly = [] + for point in poly: + new_poly.append(np.array(point, dtype=np.float32)) + new_poly = np.array(new_poly, dtype=np.float32).flatten() + new_polys.append([new_poly]) + + return new_polys + + def may_augment_poly_legacy(self, aug, img_shape, polys): + key_points, poly_point_nums = [], [] + for poly in polys: + poly = poly[0] + poly = poly.reshape(-1, 2) + key_points.extend([imgaug.Keypoint(p[0], p[1]) for p in poly]) + poly_point_nums.append(poly.shape[0]) + # Warning: we do not clip the out-of-boudnary polygons + key_points = aug.augment_keypoints( + [imgaug.KeypointsOnImage(keypoints=key_points, + shape=img_shape)])[0].keypoints + + new_polys = [] + start_idx = 0 + for poly_point_num in poly_point_nums: + new_poly = [] + for key_point in key_points[start_idx:(start_idx + + poly_point_num)]: + new_poly.append([key_point.x, key_point.y]) + start_idx += poly_point_num + new_poly = np.array(new_poly).flatten() + new_polys.append([new_poly]) + + return new_polys + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class EastRandomCrop: + + def __init__(self, + target_size=(640, 640), + max_tries=10, + min_crop_side_ratio=0.1): + self.target_size = target_size + self.max_tries = max_tries + self.min_crop_side_ratio 
= min_crop_side_ratio + + def __call__(self, results): + # sampling crop + # crop image, boxes, masks + img = results['img'] + crop_x, crop_y, crop_w, crop_h = self.crop_area( + img, results['gt_masks']) + scale_w = self.target_size[0] / crop_w + scale_h = self.target_size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + padded_img = np.zeros( + (self.target_size[1], self.target_size[0], img.shape[2]), + img.dtype) + padded_img[:h, :w] = dbnet_cv.imresize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + + # for bboxes + for key in results['bbox_fields']: + lines = [] + for box in results[key]: + box = box.reshape(2, 2) + poly = ((box - (crop_x, crop_y)) * scale) + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + lines.append(poly.flatten()) + results[key] = np.array(lines) + # for masks + for key in results['mask_fields']: + polys = [] + polys_label = [] + for poly in results[key]: + poly = np.array(poly).reshape(-1, 2) + poly = ((poly - (crop_x, crop_y)) * scale) + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + polys.append([poly]) + polys_label.append(0) + results[key] = PolygonMasks(polys, *self.target_size) + if key == 'gt_masks': + results['gt_labels'] = polys_label + + results['img'] = padded_img + results['img_shape'] = padded_img.shape + + return results + + def is_poly_in_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].min() < x or poly[:, 0].max() > x + w: + return False + if poly[:, 1].min() < y or poly[:, 1].max() > y + h: + return False + return True + + def is_poly_outside_rect(self, poly, x, y, w, h): + poly = np.array(poly).reshape(-1, 2) + if poly[:, 0].max() < x or poly[:, 0].min() > x + w: + return True + if poly[:, 1].max() < y or poly[:, 1].min() > y + h: + return True + return False + + def split_regions(self, axis): + regions = [] + min_axis = 0 + for i in range(1, axis.shape[0]): + if axis[i] != axis[i - 1] + 1: + region = axis[min_axis:i] + min_axis = i + regions.append(region) + return regions + + def random_select(self, axis, max_size): + xx = np.random.choice(axis, size=2) + xmin = np.min(xx) + xmax = np.max(xx) + xmin = np.clip(xmin, 0, max_size - 1) + xmax = np.clip(xmax, 0, max_size - 1) + return xmin, xmax + + def region_wise_random_select(self, regions): + selected_index = list(np.random.choice(len(regions), 2)) + selected_values = [] + for index in selected_index: + axis = regions[index] + xx = int(np.random.choice(axis, size=1)) + selected_values.append(xx) + xmin = min(selected_values) + xmax = max(selected_values) + return xmin, xmax + + def crop_area(self, img, polys): + h, w, _ = img.shape + h_array = np.zeros(h, dtype=np.int32) + w_array = np.zeros(w, dtype=np.int32) + for points in polys: + points = np.round( + points, decimals=0).astype(np.int32).reshape(-1, 2) + min_x = np.min(points[:, 0]) + max_x = np.max(points[:, 0]) + w_array[min_x:max_x] = 1 + min_y = np.min(points[:, 1]) + max_y = np.max(points[:, 1]) + h_array[min_y:max_y] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + + if len(h_axis) == 0 or len(w_axis) == 0: + return 0, 0, w, h + + h_regions = self.split_regions(h_axis) + w_regions = self.split_regions(w_axis) + + for i in range(self.max_tries): + if len(w_regions) > 1: + xmin, xmax = self.region_wise_random_select(w_regions) + else: + xmin, xmax = self.random_select(w_axis, w) + if len(h_regions) > 1: + ymin, ymax = self.region_wise_random_select(h_regions) + 
else: + ymin, ymax = self.random_select(h_axis, h) + + if (xmax - xmin < self.min_crop_side_ratio * w + or ymax - ymin < self.min_crop_side_ratio * h): + # area too small + continue + num_poly_in_rect = 0 + for poly in polys: + if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, + ymax - ymin): + num_poly_in_rect += 1 + break + + if num_poly_in_rect > 0: + return xmin, ymin, xmax - xmin, ymax - ymin + + return 0, 0, w, h diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/loading.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/loading.py new file mode 100755 index 00000000..d22aa0d8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/loading.py @@ -0,0 +1,189 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import lmdb +import dbnet_cv +import numpy as np +from dbnet_det.core import BitmapMasks, PolygonMasks +from dbnet_det.datasets.builder import PIPELINES +from dbnet_det.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile + + +@PIPELINES.register_module() +class LoadTextAnnotations(LoadAnnotations): + """Load annotations for text detection. + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Default: True. + with_label (bool): Whether to parse and load the label annotation. + Default: True. + with_mask (bool): Whether to parse and load the mask annotation. + Default: False. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Default: False. + poly2mask (bool): Whether to convert the instance masks from polygons + to bitmaps. Default: True. + use_img_shape (bool): Use the shape of loaded image from + previous pipeline ``LoadImageFromFile`` to generate mask. + """ + + def __init__(self, + with_bbox=True, + with_label=True, + with_mask=False, + with_seg=False, + poly2mask=True, + use_img_shape=False): + super().__init__( + with_bbox=with_bbox, + with_label=with_label, + with_mask=with_mask, + with_seg=with_seg, + poly2mask=poly2mask) + + self.use_img_shape = use_img_shape + + def process_polygons(self, polygons): + """Convert polygons to list of ndarray and filter invalid polygons. + + Args: + polygons (list[list]): Polygons of one instance. + + Returns: + list[numpy.ndarray]: Processed polygons. 
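+
+        Example:
+            An illustrative call (``self`` refers to an instantiated
+            ``LoadTextAnnotations``; the coordinates are made up). The
+            second polygon is dropped because it has fewer than three
+            points:
+
+            >>> polys = [[0, 0, 10, 0, 10, 10], [0, 0, 1, 1]]
+            >>> valid = self.process_polygons(polys)  # keeps only the first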
+ """ + + polygons = [np.array(p).astype(np.float32) for p in polygons] + valid_polygons = [] + for polygon in polygons: + if len(polygon) % 2 == 0 and len(polygon) >= 6: + valid_polygons.append(polygon) + return valid_polygons + + def _load_masks(self, results): + ann_info = results['ann_info'] + h, w = results['img_info']['height'], results['img_info']['width'] + if self.use_img_shape: + if results.get('ori_shape', None): + h, w = results['ori_shape'][:2] + results['img_info']['height'] = h + results['img_info']['width'] = w + else: + warnings.warn('"ori_shape" not in results, use the shape ' + 'in "img_info" instead.') + gt_masks = ann_info['masks'] + if self.poly2mask: + gt_masks = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) + else: + gt_masks = PolygonMasks( + [self.process_polygons(polygons) for polygons in gt_masks], h, + w) + gt_masks_ignore = ann_info.get('masks_ignore', None) + if gt_masks_ignore is not None: + if self.poly2mask: + gt_masks_ignore = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks_ignore], + h, w) + else: + gt_masks_ignore = PolygonMasks([ + self.process_polygons(polygons) + for polygons in gt_masks_ignore + ], h, w) + results['gt_masks_ignore'] = gt_masks_ignore + results['mask_fields'].append('gt_masks_ignore') + + results['gt_masks'] = gt_masks + results['mask_fields'].append('gt_masks') + return results + + +@PIPELINES.register_module() +class LoadImageFromNdarray(LoadImageFromFile): + """Load an image from np.ndarray. + + Similar with :obj:`LoadImageFromFile`, but the image read from + ``results['img']``, which is np.ndarray. + """ + + def __call__(self, results): + """Call functions to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + assert results['img'].dtype == 'uint8' + + img = results['img'] + if self.color_type == 'grayscale' and img.shape[2] == 3: + img = dbnet_cv.bgr2gray(img, keepdim=True) + if self.color_type == 'color' and img.shape[2] == 1: + img = dbnet_cv.gray2bgr(img) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = None + results['ori_filename'] = None + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['img_fields'] = ['img'] + return results + + +@PIPELINES.register_module() +class LoadImageFromLMDB(object): + """Load an image from lmdb file. + + Similar with :obj:'LoadImageFromFile', but the image read from + "results['img_info']['filename']", which is a data index of lmdb file. 
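+
+    Example:
+        An illustrative pipeline entry. With this loader, ``img_prefix`` in
+        the data config is expected to point at the lmdb directory, and
+        ``img_info['filename']`` is used as the lmdb key:
+
+        >>> loader_cfg = dict(type='LoadImageFromLMDB', color_type='color')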
+ """ + + def __init__(self, color_type='color'): + self.color_type = color_type + self.env = None + self.txn = None + + def __call__(self, results): + img_key = results['img_info']['filename'] + lmdb_path = results['img_prefix'] + + # lmdb env + if self.env is None: + self.env = lmdb.open( + lmdb_path, + max_readers=1, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + # read image + with self.env.begin(write=False) as txn: + imgbuf = txn.get(img_key.encode('utf-8')) + try: + img = dbnet_cv.imfrombytes(imgbuf, flag=self.color_type) + except IOError: + print('Corrupted image for {}'.format(img_key)) + return None + + results['filename'] = img_key + results['ori_filename'] = img_key + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['img_fields'] = ['img'] + return results + + def __repr__(self): + return '{} (color_type={})'.format(self.__class__.__name__, + self.color_type) + + def __del__(self): + if self.env is not None: + self.env.close() diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/__init__.py new file mode 100755 index 00000000..ddbf10ed --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_textdet_targets import BaseTextDetTargets +from .dbnet_targets import DBNetTargets +# from .drrg_targets import DRRGTargets +# from .fcenet_targets import FCENetTargets +# from .panet_targets import PANetTargets +# from .psenet_targets import PSENetTargets +# from .textsnake_targets import TextSnakeTargets + +# __all__ = [ +# 'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets', +# 'FCENetTargets', 'TextSnakeTargets', 'DRRGTargets' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/base_textdet_targets.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/base_textdet_targets.py new file mode 100755 index 00000000..a034ce77 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/textdet_targets/base_textdet_targets.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys + +import cv2 +import numpy as np +import pyclipper +from dbnet_cv.utils import print_log +from shapely.geometry import Polygon as plg + +import dbnet.utils.check_argument as check_argument + + +class BaseTextDetTargets: + """Generate text detector ground truths.""" + + def __init__(self): + pass + + def point2line(self, xs, ys, point_1, point_2): + """Compute the distance from point to a line. This is adapted from + https://github.com/MhLiao/DB. + + Args: + xs (ndarray): The x coordinates of size hxw. + ys (ndarray): The y coordinates of size hxw. + point_1 (ndarray): The first point with shape 1x2. + point_2 (ndarray): The second point with shape 1x2. + + Returns: + result (ndarray): The distance matrix of size hxw. 
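+
+        The distance is obtained from the triangle spanned by each grid
+        point and the two endpoints: with edge lengths ``a``, ``b`` (grid
+        point to ``point_1`` / ``point_2``) and ``c`` (between the two
+        endpoints), ``cos(C)`` follows from the law of cosines and the
+        height on ``c`` is ``a * b * sin(C) / c``, as computed below.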
+ """ + # suppose a triangle with three edge abc with c=point_1 point_2 + # a^2 + a_square = np.square(xs - point_1[0]) + np.square(ys - point_1[1]) + # b^2 + b_square = np.square(xs - point_2[0]) + np.square(ys - point_2[1]) + # c^2 + c_square = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - + point_2[1]) + # -cosC=(c^2-a^2-b^2)/2(ab) + neg_cos_c = ( + (c_square - a_square - b_square) / + (np.finfo(np.float32).eps + 2 * np.sqrt(a_square * b_square))) + # sinC^2=1-cosC^2 + square_sin = 1 - np.square(neg_cos_c) + square_sin = np.nan_to_num(square_sin) + # distance=a*b*sinC/c=a*h/c=2*area/c + result = np.sqrt(a_square * b_square * square_sin / + (np.finfo(np.float32).eps + c_square)) + # set result to minimum edge if C 0: + padded_polygon = np.array(padded_polygon[0]) + else: + print(f'padding {polygon} with {distance} gets {padded_polygon}') + padded_polygon = polygon.copy().astype(np.int32) + + x_min = padded_polygon[:, 0].min() + x_max = padded_polygon[:, 0].max() + y_min = padded_polygon[:, 1].min() + y_max = padded_polygon[:, 1].max() + + width = x_max - x_min + 1 + height = y_max - y_min + 1 + + polygon[:, 0] = polygon[:, 0] - x_min + polygon[:, 1] = polygon[:, 1] - y_min + + xs = np.broadcast_to( + np.linspace(0, width - 1, num=width).reshape(1, width), + (height, width)) + ys = np.broadcast_to( + np.linspace(0, height - 1, num=height).reshape(height, 1), + (height, width)) + + distance_map = np.zeros((polygon.shape[0], height, width), + dtype=np.float32) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = distance_map.min(axis=0) + + x_min_valid = min(max(0, x_min), canvas.shape[1] - 1) + x_max_valid = min(max(0, x_max), canvas.shape[1] - 1) + y_min_valid = min(max(0, y_min), canvas.shape[0] - 1) + y_max_valid = min(max(0, y_max), canvas.shape[0] - 1) + + if x_min_valid - x_min >= width or y_min_valid - y_min >= height: + return + + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + canvas[y_min_valid:y_max_valid + 1, + x_min_valid:x_max_valid + 1] = np.fmax( + 1 - distance_map[y_min_valid - y_min:y_max_valid - y_max + + height, x_min_valid - x_min:x_max_valid - + x_max + width], + canvas[y_min_valid:y_max_valid + 1, + x_min_valid:x_max_valid + 1]) + + def generate_targets(self, results): + """Generate the gt targets for DBNet. + + Args: + results (dict): The input result dictionary. + + Returns: + results (dict): The output result dictionary. 
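+
+            The keys added by this step (each wrapped in ``BitmapMasks``),
+            mirroring the mapping built below:
+
+            - ``gt_shrink``: shrunk text kernel map.
+            - ``gt_shrink_mask``: effective mask for the shrink map.
+            - ``gt_thr``: threshold (border) map.
+            - ``gt_thr_mask``: effective mask for the threshold map.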
+ """ + assert isinstance(results, dict) + + if 'bbox_fields' in results: + results['bbox_fields'].clear() + + ignore_tags = self.find_invalid(results) + results, ignore_tags = self.ignore_texts(results, ignore_tags) + + h, w, _ = results['img_shape'] + polygons = results['gt_masks'].masks + + # generate gt_shrink_kernel + gt_shrink, ignore_tags = self.generate_kernels((h, w), + polygons, + self.shrink_ratio, + ignore_tags=ignore_tags) + + results, ignore_tags = self.ignore_texts(results, ignore_tags) + # genenrate gt_shrink_mask + polygons_ignore = results['gt_masks_ignore'].masks + gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore) + + # generate gt_threshold and gt_threshold_mask + polygons = results['gt_masks'].masks + gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons) + + results['mask_fields'].clear() # rm gt_masks encoded by polygons + results.pop('gt_labels', None) + results.pop('gt_masks', None) + results.pop('gt_bboxes', None) + results.pop('gt_bboxes_ignore', None) + + mapping = { + 'gt_shrink': gt_shrink, + 'gt_shrink_mask': gt_shrink_mask, + 'gt_thr': gt_thr, + 'gt_thr_mask': gt_thr_mask + } + for key, value in mapping.items(): + value = value if isinstance(value, list) else [value] + results[key] = BitmapMasks(value, h, w) + results['mask_fields'].append(key) + + return results diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transform_wrappers.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transform_wrappers.py new file mode 100755 index 00000000..dc5af00e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transform_wrappers.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import random + +import dbnet_cv +import numpy as np +import torchvision.transforms as torchvision_transforms +from dbnet_cv.utils import build_from_cfg +from dbnet_det.datasets.builder import PIPELINES +from dbnet_det.datasets.pipelines import Compose +from PIL import Image + + +@PIPELINES.register_module() +class OneOfWrapper: + """Randomly select and apply one of the transforms, each with the equal + chance. + + Warning: + Different from albumentations, this wrapper only runs the selected + transform, but doesn't guarantee the transform can always be applied to + the input if the transform comes with a probability to run. + + Args: + transforms (list[dict|callable]): Candidate transforms to be applied. + """ + + def __init__(self, transforms): + assert isinstance(transforms, list) or isinstance(transforms, tuple) + assert len(transforms) > 0, 'Need at least one transform.' + self.transforms = [] + for t in transforms: + if isinstance(t, dict): + self.transforms.append(build_from_cfg(t, PIPELINES)) + elif callable(t): + self.transforms.append(t) + else: + raise TypeError('transform must be callable or a dict') + + def __call__(self, results): + return random.choice(self.transforms)(results) + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms})' + return repr_str + + +@PIPELINES.register_module() +class RandomWrapper: + """Run a transform or a sequence of transforms with probability p. + + Args: + transforms (list[dict|callable]): Transform(s) to be applied. + p (int|float): Probability of running transform(s). 
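+
+    Example:
+        An illustrative pipeline entry (the wrapped transform and the
+        probability are arbitrary choices):
+
+        >>> wrapper_cfg = dict(
+        ...     type='RandomWrapper',
+        ...     p=0.25,
+        ...     transforms=[dict(type='ColorJitter', brightness=0.5)])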
+ """ + + def __init__(self, transforms, p): + assert 0 <= p <= 1 + self.transforms = Compose(transforms) + self.p = p + + def __call__(self, results): + return results if np.random.uniform() > self.p else self.transforms( + results) + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'p={self.p})' + return repr_str + + +@PIPELINES.register_module() +class TorchVisionWrapper: + """A wrapper of torchvision trasnforms. It applies specific transform to + ``img`` and updates ``img_shape`` accordingly. + + Warning: + This transform only affects the image but not its associated + annotations, such as word bounding boxes and polygon masks. Therefore, + it may only be applicable to text recognition tasks. + + Args: + op (str): The name of any transform class in + :func:`torchvision.transforms`. + **kwargs: Arguments that will be passed to initializer of torchvision + transform. + + :Required Keys: + - | ``img`` (ndarray): The input image. + + :Affected Keys: + :Modified: + - | ``img`` (ndarray): The modified image. + :Added: + - | ``img_shape`` (tuple(int)): Size of the modified image. + """ + + def __init__(self, op, **kwargs): + assert type(op) is str + + if dbnet_cv.is_str(op): + obj_cls = getattr(torchvision_transforms, op) + elif inspect.isclass(op): + obj_cls = op + else: + raise TypeError( + f'type must be a str or valid type, but got {type(type)}') + self.transform = obj_cls(**kwargs) + self.kwargs = kwargs + + def __call__(self, results): + assert 'img' in results + # BGR -> RGB + img = results['img'][..., ::-1] + img = Image.fromarray(img) + img = self.transform(img) + img = np.asarray(img) + img = img[..., ::-1] + results['img'] = img + results['img_shape'] = img.shape + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transform={self.transform})' + return repr_str diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transforms.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transforms.py new file mode 100755 index 00000000..e4a7a08f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/pipelines/transforms.py @@ -0,0 +1,1020 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import cv2 +import dbnet_cv +import numpy as np +import torchvision.transforms as transforms +from dbnet_det.core import BitmapMasks, PolygonMasks +from dbnet_det.datasets.builder import PIPELINES +from dbnet_det.datasets.pipelines.transforms import Resize +from PIL import Image +from shapely.geometry import Polygon as plg + +import dbnet.core.evaluation.utils as eval_utils +from dbnet.utils import check_argument + + +@PIPELINES.register_module() +class RandomCropInstances: + """Randomly crop images and make sure to contain text instances. + + Args: + target_size (tuple or int): (height, width) + positive_sample_ratio (float): The probability of sampling regions + that go through positive regions. 
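+
+    Example:
+        An illustrative pipeline entry (the key and size are examples only;
+        ``instance_key`` should name a mask field present in ``results``):
+
+        >>> crop_cfg = dict(
+        ...     type='RandomCropInstances',
+        ...     instance_key='gt_kernels',
+        ...     target_size=(640, 640))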
+ """ + + def __init__( + self, + target_size, + instance_key, + mask_type='inx0', # 'inx0' or 'union_all' + positive_sample_ratio=5.0 / 8.0): + + assert mask_type in ['inx0', 'union_all'] + + self.mask_type = mask_type + self.instance_key = instance_key + self.positive_sample_ratio = positive_sample_ratio + self.target_size = target_size if (target_size is None or isinstance( + target_size, tuple)) else (target_size, target_size) + + def sample_offset(self, img_gt, img_size): + h, w = img_size + t_h, t_w = self.target_size + + # target size is bigger than origin size + t_h = t_h if t_h < h else h + t_w = t_w if t_w < w else w + if (img_gt is not None + and np.random.random_sample() < self.positive_sample_ratio + and np.max(img_gt) > 0): + + # make sure to crop the positive region + + # the minimum top left to crop positive region (h,w) + tl = np.min(np.where(img_gt > 0), axis=1) - (t_h, t_w) + tl[tl < 0] = 0 + # the maximum top left to crop positive region + br = np.max(np.where(img_gt > 0), axis=1) - (t_h, t_w) + br[br < 0] = 0 + # if br is too big so that crop the outside region of img + br[0] = min(br[0], h - t_h) + br[1] = min(br[1], w - t_w) + # + h = np.random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + w = np.random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 + else: + # make sure not to crop outside of img + + h = np.random.randint(0, h - t_h) if h - t_h > 0 else 0 + w = np.random.randint(0, w - t_w) if w - t_w > 0 else 0 + + return (h, w) + + @staticmethod + def crop_img(img, offset, target_size): + h, w = img.shape[:2] + br = np.min( + np.stack((np.array(offset) + np.array(target_size), np.array( + (h, w)))), + axis=0) + return img[offset[0]:br[0], offset[1]:br[1]], np.array( + [offset[1], offset[0], br[1], br[0]]) + + def crop_bboxes(self, bboxes, canvas_bbox): + kept_bboxes = [] + kept_inx = [] + canvas_poly = eval_utils.box2polygon(canvas_bbox) + tl = canvas_bbox[0:2] + + for idx, bbox in enumerate(bboxes): + poly = eval_utils.box2polygon(bbox) + area, inters = eval_utils.poly_intersection( + poly, canvas_poly, return_poly=True) + if area == 0: + continue + xmin, ymin, xmax, ymax = inters.bounds + kept_bboxes += [ + np.array( + [xmin - tl[0], ymin - tl[1], xmax - tl[0], ymax - tl[1]], + dtype=np.float32) + ] + kept_inx += [idx] + + if len(kept_inx) == 0: + return np.array([]).astype(np.float32).reshape(0, 4), kept_inx + + return np.stack(kept_bboxes), kept_inx + + @staticmethod + def generate_mask(gt_mask, type): + + if type == 'inx0': + return gt_mask.masks[0] + if type == 'union_all': + mask = gt_mask.masks[0].copy() + for idx in range(1, len(gt_mask.masks)): + mask = np.logical_or(mask, gt_mask.masks[idx]) + return mask + + raise NotImplementedError + + def __call__(self, results): + + gt_mask = results[self.instance_key] + mask = None + if len(gt_mask.masks) > 0: + mask = self.generate_mask(gt_mask, self.mask_type) + results['crop_offset'] = self.sample_offset(mask, + results['img'].shape[:2]) + + # crop img. 
bbox = [x1,y1,x2,y2] + img, bbox = self.crop_img(results['img'], results['crop_offset'], + self.target_size) + results['img'] = img + img_shape = img.shape + results['img_shape'] = img_shape + + # crop masks + for key in results.get('mask_fields', []): + results[key] = results[key].crop(bbox) + + # for mask rcnn + for key in results.get('bbox_fields', []): + results[key], kept_inx = self.crop_bboxes(results[key], bbox) + if key == 'gt_bboxes': + # ignore gt_labels accordingly + if 'gt_labels' in results: + ori_labels = results['gt_labels'] + ori_inst_num = len(ori_labels) + results['gt_labels'] = [ + ori_labels[idx] for idx in range(ori_inst_num) + if idx in kept_inx + ] + # ignore g_masks accordingly + if 'gt_masks' in results: + ori_mask = results['gt_masks'].masks + kept_mask = [ + ori_mask[idx] for idx in range(ori_inst_num) + if idx in kept_inx + ] + target_h, target_w = bbox[3] - bbox[1], bbox[2] - bbox[0] + if len(kept_inx) > 0: + kept_mask = np.stack(kept_mask) + else: + kept_mask = np.empty((0, target_h, target_w), + dtype=np.float32) + results['gt_masks'] = BitmapMasks(kept_mask, target_h, + target_w) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomRotateTextDet: + """Randomly rotate images.""" + + def __init__(self, rotate_ratio=1.0, max_angle=10): + self.rotate_ratio = rotate_ratio + self.max_angle = max_angle + + @staticmethod + def sample_angle(max_angle): + angle = np.random.random_sample() * 2 * max_angle - max_angle + return angle + + @staticmethod + def rotate_img(img, angle): + h, w = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + img_target = cv2.warpAffine( + img, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST) + assert img_target.shape == img.shape + return img_target + + def __call__(self, results): + if np.random.random_sample() < self.rotate_ratio: + # rotate imgs + results['rotated_angle'] = self.sample_angle(self.max_angle) + img = self.rotate_img(results['img'], results['rotated_angle']) + results['img'] = img + img_shape = img.shape + results['img_shape'] = img_shape + + # rotate masks + for key in results.get('mask_fields', []): + masks = results[key].masks + mask_list = [] + for m in masks: + rotated_m = self.rotate_img(m, results['rotated_angle']) + mask_list.append(rotated_m) + results[key] = BitmapMasks(mask_list, *(img_shape[:2])) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class ColorJitter: + """An interface for torch color jitter so that it can be invoked in + dbnet_detection pipeline.""" + + def __init__(self, **kwargs): + self.transform = transforms.ColorJitter(**kwargs) + + def __call__(self, results): + # img is bgr + img = results['img'][..., ::-1] + img = Image.fromarray(img) + img = self.transform(img) + img = np.asarray(img) + img = img[..., ::-1] + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class ScaleAspectJitter(Resize): + """Resize image and segmentation mask encoded by coordinates. + + Allowed resize types are `around_min_img_scale`, `long_short_bound`, and + `indep_sample_in_range`. 
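+
+    Example:
+        An illustrative pipeline entry for the default
+        ``around_min_img_scale`` resize type (the concrete values are only
+        an example):
+
+        >>> jitter_cfg = dict(
+        ...     type='ScaleAspectJitter',
+        ...     img_scale=[(3000, 640)],
+        ...     ratio_range=(0.7, 3),
+        ...     aspect_ratio_range=(0.9, 1.1),
+        ...     multiscale_mode='value',
+        ...     keep_ratio=False)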
+ """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=False, + resize_type='around_min_img_scale', + aspect_ratio_range=None, + long_size_bound=None, + short_size_bound=None, + scale_range=None): + super().__init__( + img_scale=img_scale, + multiscale_mode=multiscale_mode, + ratio_range=ratio_range, + keep_ratio=keep_ratio) + assert not keep_ratio + assert resize_type in [ + 'around_min_img_scale', 'long_short_bound', 'indep_sample_in_range' + ] + self.resize_type = resize_type + + if resize_type == 'indep_sample_in_range': + assert ratio_range is None + assert aspect_ratio_range is None + assert short_size_bound is None + assert long_size_bound is None + assert scale_range is not None + else: + assert scale_range is None + assert isinstance(ratio_range, tuple) + assert isinstance(aspect_ratio_range, tuple) + assert check_argument.equal_len(ratio_range, aspect_ratio_range) + + if resize_type in ['long_short_bound']: + assert short_size_bound is not None + assert long_size_bound is not None + + self.aspect_ratio_range = aspect_ratio_range + self.long_size_bound = long_size_bound + self.short_size_bound = short_size_bound + self.scale_range = scale_range + + @staticmethod + def sample_from_range(range): + assert len(range) == 2 + min_value, max_value = min(range), max(range) + value = np.random.random_sample() * (max_value - min_value) + min_value + + return value + + def _random_scale(self, results): + + if self.resize_type == 'indep_sample_in_range': + w = self.sample_from_range(self.scale_range) + h = self.sample_from_range(self.scale_range) + results['scale'] = (int(w), int(h)) # (w,h) + results['scale_idx'] = None + return + h, w = results['img'].shape[0:2] + if self.resize_type == 'long_short_bound': + scale1 = 1 + if max(h, w) > self.long_size_bound: + scale1 = self.long_size_bound / max(h, w) + scale2 = self.sample_from_range(self.ratio_range) + scale = scale1 * scale2 + if min(h, w) * scale <= self.short_size_bound: + scale = (self.short_size_bound + 10) * 1.0 / min(h, w) + elif self.resize_type == 'around_min_img_scale': + short_size = min(self.img_scale[0]) + ratio = self.sample_from_range(self.ratio_range) + scale = (ratio * short_size) / min(h, w) + else: + raise NotImplementedError + + aspect = self.sample_from_range(self.aspect_ratio_range) + h_scale = scale * math.sqrt(aspect) + w_scale = scale / math.sqrt(aspect) + results['scale'] = (int(w * w_scale), int(h * h_scale)) # (w,h) + results['scale_idx'] = None + + +@PIPELINES.register_module() +class AffineJitter: + """An interface for torchvision random affine so that it can be invoked in + dbnet_det pipeline.""" + + def __init__(self, + degrees=4, + translate=(0.02, 0.04), + scale=(0.9, 1.1), + shear=None, + resample=False, + fillcolor=0): + self.transform = transforms.RandomAffine( + degrees=degrees, + translate=translate, + scale=scale, + shear=shear, + resample=resample, + fillcolor=fillcolor) + + def __call__(self, results): + # img is bgr + img = results['img'][..., ::-1] + img = Image.fromarray(img) + img = self.transform(img) + img = np.asarray(img) + img = img[..., ::-1] + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomCropPolyInstances: + """Randomly crop images and make sure to contain at least one intact + instance.""" + + def __init__(self, + instance_key='gt_masks', + crop_ratio=5.0 / 8.0, + min_side_ratio=0.4): + super().__init__() + 
self.instance_key = instance_key + self.crop_ratio = crop_ratio + self.min_side_ratio = min_side_ratio + + def sample_valid_start_end(self, valid_array, min_len, max_start, min_end): + + assert isinstance(min_len, int) + assert len(valid_array) > min_len + + start_array = valid_array.copy() + max_start = min(len(start_array) - min_len, max_start) + start_array[max_start:] = 0 + start_array[0] = 1 + diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + start = np.random.randint(region_starts[region_ind], + region_ends[region_ind]) + + end_array = valid_array.copy() + min_end = max(start + min_len, min_end) + end_array[:min_end] = 0 + end_array[-1] = 1 + diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + end = np.random.randint(region_starts[region_ind], + region_ends[region_ind]) + return start, end + + def sample_crop_box(self, img_size, results): + """Generate crop box and make sure not to crop the polygon instances. + + Args: + img_size (tuple(int)): The image size (h, w). + results (dict): The results dict. + """ + + assert isinstance(img_size, tuple) + h, w = img_size[:2] + + key_masks = results[self.instance_key].masks + x_valid_array = np.ones(w, dtype=np.int32) + y_valid_array = np.ones(h, dtype=np.int32) + + selected_mask = key_masks[np.random.randint(0, len(key_masks))] + selected_mask = selected_mask[0].reshape((-1, 2)).astype(np.int32) + max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0) + min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1) + max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0) + min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1) + + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + masks = results[key].masks + for mask in masks: + assert len(mask) == 1 + mask = mask[0].reshape((-1, 2)).astype(np.int32) + clip_x = np.clip(mask[:, 0], 0, w - 1) + clip_y = np.clip(mask[:, 1], 0, h - 1) + min_x, max_x = np.min(clip_x), np.max(clip_x) + min_y, max_y = np.min(clip_y), np.max(clip_y) + + x_valid_array[min_x - 2:max_x + 3] = 0 + y_valid_array[min_y - 2:max_y + 3] = 0 + + min_w = int(w * self.min_side_ratio) + min_h = int(h * self.min_side_ratio) + + x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start, + min_x_end) + y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start, + min_y_end) + + return np.array([x1, y1, x2, y2]) + + def crop_img(self, img, bbox): + assert img.ndim == 3 + h, w, _ = img.shape + assert 0 <= bbox[1] < bbox[3] <= h + assert 0 <= bbox[0] < bbox[2] <= w + return img[bbox[1]:bbox[3], bbox[0]:bbox[2]] + + def __call__(self, results): + if len(results[self.instance_key].masks) < 1: + return results + if np.random.random_sample() < self.crop_ratio: + crop_box = self.sample_crop_box(results['img'].shape, results) + results['crop_region'] = crop_box + img = self.crop_img(results['img'], crop_box) + results['img'] = img + results['img_shape'] = img.shape + + # crop and filter masks + x1, y1, x2, y2 = crop_box + w = max(x2 - x1, 1) + h = max(y2 - y1, 1) + labels = results['gt_labels'] + valid_labels = [] + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + results[key] = results[key].crop(crop_box) + # filter 
out polygons beyond crop box. + masks = results[key].masks + valid_masks_list = [] + + for ind, mask in enumerate(masks): + assert len(mask) == 1 + polygon = mask[0].reshape((-1, 2)) + if (polygon[:, 0] > + -4).all() and (polygon[:, 0] < w + 4).all() and ( + polygon[:, 1] > -4).all() and (polygon[:, 1] < + h + 4).all(): + mask[0][::2] = np.clip(mask[0][::2], 0, w) + mask[0][1::2] = np.clip(mask[0][1::2], 0, h) + if key == self.instance_key: + valid_labels.append(labels[ind]) + valid_masks_list.append(mask) + + results[key] = PolygonMasks(valid_masks_list, h, w) + results['gt_labels'] = np.array(valid_labels) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomRotatePolyInstances: + + def __init__(self, + rotate_ratio=0.5, + max_angle=10, + pad_with_fixed_color=False, + pad_value=(0, 0, 0)): + """Randomly rotate images and polygon masks. + + Args: + rotate_ratio (float): The ratio of samples to operate rotation. + max_angle (int): The maximum rotation angle. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rotated image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. + """ + self.rotate_ratio = rotate_ratio + self.max_angle = max_angle + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def rotate(self, center, points, theta, center_shift=(0, 0)): + # rotate points. + (center_x, center_y) = center + center_y = -center_y + x, y = points[::2], points[1::2] + y = -y + + theta = theta / 180 * math.pi + cos = math.cos(theta) + sin = math.sin(theta) + + x = (x - center_x) + y = (y - center_y) + + _x = center_x + x * cos - y * sin + center_shift[0] + _y = -(center_y + x * sin + y * cos) + center_shift[1] + + points[::2], points[1::2] = _x, _y + return points + + def cal_canvas_size(self, ori_size, degree): + assert isinstance(ori_size, tuple) + angle = degree * math.pi / 180.0 + h, w = ori_size[:2] + + cos = math.cos(angle) + sin = math.sin(angle) + canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos)) + canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin)) + + canvas_size = (canvas_h, canvas_w) + return canvas_size + + def sample_angle(self, max_angle): + angle = np.random.random_sample() * 2 * max_angle - max_angle + return angle + + def rotate_img(self, img, angle, canvas_size): + h, w = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2) + rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2) + + if self.pad_with_fixed_color: + target_img = cv2.warpAffine( + img, + rotation_matrix, (canvas_size[1], canvas_size[0]), + flags=cv2.INTER_NEAREST, + borderValue=self.pad_value) + else: + mask = np.zeros_like(img) + (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8)) + img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] + img_cut = dbnet_cv.imresize(img_cut, (canvas_size[1], canvas_size[0])) + mask = cv2.warpAffine( + mask, + rotation_matrix, (canvas_size[1], canvas_size[0]), + borderValue=[1, 1, 1]) + target_img = cv2.warpAffine( + img, + rotation_matrix, (canvas_size[1], canvas_size[0]), + borderValue=[0, 0, 0]) + target_img = target_img + img_cut * mask + + return target_img + + def __call__(self, results): + if np.random.random_sample() < self.rotate_ratio: + img = results['img'] + h, w = img.shape[:2] + angle = 
self.sample_angle(self.max_angle) + canvas_size = self.cal_canvas_size((h, w), angle) + center_shift = (int( + (canvas_size[1] - w) / 2), int((canvas_size[0] - h) / 2)) + + # rotate image + results['rotated_poly_angle'] = angle + img = self.rotate_img(img, angle, canvas_size) + results['img'] = img + img_shape = img.shape + results['img_shape'] = img_shape + + # rotate polygons + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + masks = results[key].masks + rotated_masks = [] + for mask in masks: + rotated_mask = self.rotate((w / 2, h / 2), mask[0], angle, + center_shift) + rotated_masks.append([rotated_mask]) + + results[key] = PolygonMasks(rotated_masks, *(img_shape[:2])) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class SquareResizePad: + + def __init__(self, + target_size, + pad_ratio=0.6, + pad_with_fixed_color=False, + pad_value=(0, 0, 0)): + """Resize or pad images to be square shape. + + Args: + target_size (int): The target size of square shaped image. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rescales image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. + """ + assert isinstance(target_size, int) + assert isinstance(pad_ratio, float) + assert isinstance(pad_with_fixed_color, bool) + assert isinstance(pad_value, tuple) + + self.target_size = target_size + self.pad_ratio = pad_ratio + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def resize_img(self, img, keep_ratio=True): + h, w, _ = img.shape + if keep_ratio: + t_h = self.target_size if h >= w else int(h * self.target_size / w) + t_w = self.target_size if h <= w else int(w * self.target_size / h) + else: + t_h = t_w = self.target_size + img = dbnet_cv.imresize(img, (t_w, t_h)) + return img, (t_h, t_w) + + def square_pad(self, img): + h, w = img.shape[:2] + if h == w: + return img, (0, 0) + pad_size = max(h, w) + if self.pad_with_fixed_color: + expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8) + expand_img[:] = self.pad_value + else: + (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8)) + img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] + expand_img = dbnet_cv.imresize(img_cut, (pad_size, pad_size)) + if h > w: + y0, x0 = 0, (h - w) // 2 + else: + y0, x0 = (w - h) // 2, 0 + expand_img[y0:y0 + h, x0:x0 + w] = img + offset = (x0, y0) + + return expand_img, offset + + def square_pad_mask(self, points, offset): + x0, y0 = offset + pad_points = points.copy() + pad_points[::2] = pad_points[::2] + x0 + pad_points[1::2] = pad_points[1::2] + y0 + return pad_points + + def __call__(self, results): + img = results['img'] + + if np.random.random_sample() < self.pad_ratio: + img, out_size = self.resize_img(img, keep_ratio=True) + img, offset = self.square_pad(img) + else: + img, out_size = self.resize_img(img, keep_ratio=False) + offset = (0, 0) + + results['img'] = img + results['img_shape'] = img.shape + + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + results[key] = results[key].resize(out_size) + masks = results[key].masks + processed_masks = [] + for mask in masks: + square_pad_mask = self.square_pad_mask(mask[0], offset) + processed_masks.append([square_pad_mask]) + + results[key] = PolygonMasks(processed_masks, *(img.shape[:2])) + + return 
results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomScaling: + + def __init__(self, size=800, scale=(3. / 4, 5. / 2)): + """Random scale the image while keeping aspect. + + Args: + size (int) : Base size before scaling. + scale (tuple(float)) : The range of scaling. + """ + assert isinstance(size, int) + assert isinstance(scale, float) or isinstance(scale, tuple) + self.size = size + self.scale = scale if isinstance(scale, tuple) \ + else (1 - scale, 1 + scale) + + def __call__(self, results): + image = results['img'] + h, w, _ = results['img_shape'] + + aspect_ratio = np.random.uniform(min(self.scale), max(self.scale)) + scales = self.size * 1.0 / max(h, w) * aspect_ratio + scales = np.array([scales, scales]) + out_size = (int(h * scales[1]), int(w * scales[0])) + image = dbnet_cv.imresize(image, out_size[::-1]) + + results['img'] = image + results['img_shape'] = image.shape + + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + results[key] = results[key].resize(out_size) + + return results + + +@PIPELINES.register_module() +class RandomCropFlip: + + def __init__(self, + pad_ratio=0.1, + crop_ratio=0.5, + iter_num=1, + min_area_ratio=0.2): + """Random crop and flip a patch of the image. + + Args: + crop_ratio (float): The ratio of cropping. + iter_num (int): Number of operations. + min_area_ratio (float): Minimal area ratio between cropped patch + and original image. + """ + assert isinstance(crop_ratio, float) + assert isinstance(iter_num, int) + assert isinstance(min_area_ratio, float) + + self.pad_ratio = pad_ratio + self.epsilon = 1e-2 + self.crop_ratio = crop_ratio + self.iter_num = iter_num + self.min_area_ratio = min_area_ratio + + def __call__(self, results): + for i in range(self.iter_num): + results = self.random_crop_flip(results) + return results + + def random_crop_flip(self, results): + image = results['img'] + polygons = results['gt_masks'].masks + ignore_polygons = results['gt_masks_ignore'].masks + all_polygons = polygons + ignore_polygons + if len(polygons) == 0: + return results + + if np.random.random() >= self.crop_ratio: + return results + + h, w, _ = results['img_shape'] + area = h * w + pad_h = int(h * self.pad_ratio) + pad_w = int(w * self.pad_ratio) + h_axis, w_axis = self.generate_crop_target(image, all_polygons, pad_h, + pad_w) + if len(h_axis) == 0 or len(w_axis) == 0: + return results + + attempt = 0 + while attempt < 10: + attempt += 1 + polys_keep = [] + polys_new = [] + ign_polys_keep = [] + ign_polys_new = [] + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio: + # area too small + continue + + pts = np.stack([[xmin, xmax, xmax, xmin], + [ymin, ymin, ymax, ymax]]).T.astype(np.int32) + pp = plg(pts) + fail_flag = False + for polygon in polygons: + ppi = plg(polygon[0].reshape(-1, 2)) + ppiou = eval_utils.poly_intersection(ppi, pp) + if np.abs(ppiou - float(ppi.area)) > self.epsilon and \ + np.abs(ppiou) > self.epsilon: + fail_flag = True + break + elif np.abs(ppiou - float(ppi.area)) < self.epsilon: + polys_new.append(polygon) + else: + polys_keep.append(polygon) + + for polygon in ignore_polygons: + ppi = 
plg(polygon[0].reshape(-1, 2)) + ppiou = eval_utils.poly_intersection(ppi, pp) + if np.abs(ppiou - float(ppi.area)) > self.epsilon and \ + np.abs(ppiou) > self.epsilon: + fail_flag = True + break + elif np.abs(ppiou - float(ppi.area)) < self.epsilon: + ign_polys_new.append(polygon) + else: + ign_polys_keep.append(polygon) + + if fail_flag: + continue + else: + break + + cropped = image[ymin:ymax, xmin:xmax, :] + select_type = np.random.randint(3) + if select_type == 0: + img = np.ascontiguousarray(cropped[:, ::-1]) + elif select_type == 1: + img = np.ascontiguousarray(cropped[::-1, :]) + else: + img = np.ascontiguousarray(cropped[::-1, ::-1]) + image[ymin:ymax, xmin:xmax, :] = img + results['img'] = image + + if len(polys_new) + len(ign_polys_new) != 0: + height, width, _ = cropped.shape + if select_type == 0: + for idx, polygon in enumerate(polys_new): + poly = polygon[0].reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + polys_new[idx] = [poly.reshape(-1, )] + for idx, polygon in enumerate(ign_polys_new): + poly = polygon[0].reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + ign_polys_new[idx] = [poly.reshape(-1, )] + elif select_type == 1: + for idx, polygon in enumerate(polys_new): + poly = polygon[0].reshape(-1, 2) + poly[:, 1] = height - poly[:, 1] + 2 * ymin + polys_new[idx] = [poly.reshape(-1, )] + for idx, polygon in enumerate(ign_polys_new): + poly = polygon[0].reshape(-1, 2) + poly[:, 1] = height - poly[:, 1] + 2 * ymin + ign_polys_new[idx] = [poly.reshape(-1, )] + else: + for idx, polygon in enumerate(polys_new): + poly = polygon[0].reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + poly[:, 1] = height - poly[:, 1] + 2 * ymin + polys_new[idx] = [poly.reshape(-1, )] + for idx, polygon in enumerate(ign_polys_new): + poly = polygon[0].reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + poly[:, 1] = height - poly[:, 1] + 2 * ymin + ign_polys_new[idx] = [poly.reshape(-1, )] + polygons = polys_keep + polys_new + ignore_polygons = ign_polys_keep + ign_polys_new + results['gt_masks'] = PolygonMasks(polygons, *(image.shape[:2])) + results['gt_masks_ignore'] = PolygonMasks(ignore_polygons, + *(image.shape[:2])) + + return results + + def generate_crop_target(self, image, all_polys, pad_h, pad_w): + """Generate crop target and make sure not to crop the polygon + instances. + + Args: + image (ndarray): The image waited to be crop. + all_polys (list[list[ndarray]]): All polygons including ground + truth polygons and ground truth ignored polygons. + pad_h (int): Padding length of height. + pad_w (int): Padding length of width. + Returns: + h_axis (ndarray): Vertical cropping range. + w_axis (ndarray): Horizontal cropping range. 
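+                Both are indices into the padded image whose row/column is not
+                covered by any text polygon, so crop boundaries sampled from
+                them do not cut through an instance.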
+ """ + h, w, _ = image.shape + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + + text_polys = [] + for polygon in all_polys: + rect = cv2.minAreaRect(polygon[0].astype(np.int32).reshape(-1, 2)) + box = cv2.boxPoints(rect) + box = np.int0(box) + text_polys.append([box[0], box[1], box[2], box[3]]) + + polys = np.array(text_polys, dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w:maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h:maxy + pad_h] = 1 + + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + return h_axis, w_axis + + +@PIPELINES.register_module() +class PyramidRescale: + """Resize the image to the base shape, downsample it with gaussian pyramid, + and rescale it back to original size. + + Adapted from https://github.com/FangShancheng/ABINet. + + Args: + factor (int): The decay factor from base size, or the number of + downsampling operations from the base layer. + base_shape (tuple(int)): The shape of the base layer of the pyramid. + randomize_factor (bool): If True, the final factor would be a random + integer in [0, factor]. + + :Required Keys: + - | ``img`` (ndarray): The input image. + + :Affected Keys: + :Modified: + - | ``img`` (ndarray): The modified image. + """ + + def __init__(self, factor=4, base_shape=(128, 512), randomize_factor=True): + assert isinstance(factor, int) + assert isinstance(base_shape, list) or isinstance(base_shape, tuple) + assert len(base_shape) == 2 + assert isinstance(randomize_factor, bool) + self.factor = factor if not randomize_factor else np.random.randint( + 0, factor + 1) + self.base_w, self.base_h = base_shape + + def __call__(self, results): + assert 'img' in results + if self.factor == 0: + return results + img = results['img'] + src_h, src_w = img.shape[:2] + scale_img = dbnet_cv.imresize(img, (self.base_w, self.base_h)) + for _ in range(self.factor): + scale_img = cv2.pyrDown(scale_img) + scale_img = dbnet_cv.imresize(scale_img, (src_w, src_h)) + results['img'] = scale_img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(factor={self.factor}, ' + repr_str += f'basew={self.basew}, baseh={self.baseh})' + return repr_str diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/text_det_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/text_det_dataset.py new file mode 100755 index 00000000..75fa273c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/text_det_dataset.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from dbnet_det.datasets.builder import DATASETS + +from dbnet.core.evaluation.hmean import eval_hmean +from dbnet.datasets.base_dataset import BaseDataset + + +@DATASETS.register_module() +class TextDetDataset(BaseDataset): + + def _parse_anno_info(self, annotations): + """Parse bbox and mask annotation. + Args: + annotations (dict): Annotations of one image. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore, + labels, masks, masks_ignore. "masks" and + "masks_ignore" are represented by polygon boundary + point sequences. 
+ """ + gt_bboxes, gt_bboxes_ignore = [], [] + gt_masks, gt_masks_ignore = [], [] + gt_labels = [] + for ann in annotations: + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(ann['bbox']) + gt_masks_ignore.append(ann.get('segmentation', None)) + else: + gt_bboxes.append(ann['bbox']) + gt_labels.append(ann['category_id']) + gt_masks.append(ann.get('segmentation', None)) + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks_ignore=gt_masks_ignore, + masks=gt_masks) + + return ann + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. + """ + img_ann_info = self.data_infos[index] + img_info = { + 'filename': img_ann_info['file_name'], + 'height': img_ann_info['height'], + 'width': img_ann_info['width'] + } + ann_info = self._parse_anno_info(img_ann_info['annotations']) + results = dict(img_info=img_info, ann_info=ann_info) + results['bbox_fields'] = [] + results['mask_fields'] = [] + results['seg_fields'] = [] + self.pre_pipeline(results) + + return self.pipeline(results) + + def evaluate(self, + results, + metric='hmean-iou', + score_thr=None, + min_score_thr=0.3, + max_score_thr=0.9, + step=0.1, + rank_list=None, + logger=None, + **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + score_thr (float): Deprecated. Please use min_score_thr instead. + min_score_thr (float): Minimum score threshold of prediction map. + max_score_thr (float): Maximum score threshold of prediction map. + step (float): The spacing between score thresholds. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + rank_list (str): json file used to save eval result + of each image after ranking. + Returns: + dict[str: float] + """ + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['hmean-iou', 'hmean-ic13'] + metrics = set(metrics) & set(allowed_metrics) + + img_infos = [] + ann_infos = [] + for i in range(len(self)): + img_ann_info = self.data_infos[i] + img_info = {'filename': img_ann_info['file_name']} + ann_info = self._parse_anno_info(img_ann_info['annotations']) + img_infos.append(img_info) + ann_infos.append(ann_info) + + eval_results = eval_hmean( + results, + img_infos, + ann_infos, + metrics=metrics, + score_thr=score_thr, + min_score_thr=min_score_thr, + max_score_thr=max_score_thr, + step=step, + logger=logger, + rank_list=rank_list) + + return eval_results diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/uniform_concat_dataset.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/uniform_concat_dataset.py new file mode 100755 index 00000000..104dd3fd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/uniform_concat_dataset.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from collections import defaultdict + +import numpy as np +from dbnet_cv.utils import print_log +from dbnet_det.datasets import DATASETS, ConcatDataset, build_dataset + +from dbnet.utils import is_2dlist, is_type_list + + +@DATASETS.register_module() +class UniformConcatDataset(ConcatDataset): + """A wrapper of ConcatDataset which support dataset pipeline assignment and + replacement. + + Args: + datasets (list[dict] | list[list[dict]]): A list of datasets cfgs. + separate_eval (bool): Whether to evaluate the results + separately if it is used as validation dataset. + Defaults to True. + show_mean_scores (str | bool): Whether to compute the mean evaluation + results, only applicable when ``separate_eval=True``. Options are + [True, False, ``auto``]. If ``True``, mean results will be added to + the result dictionary with keys in the form of + ``mean_{metric_name}``. If 'auto', mean results will be shown only + when more than 1 dataset is wrapped. + pipeline (None | list[dict] | list[list[dict]]): If ``None``, + each dataset in datasets use its own pipeline; + If ``list[dict]``, it will be assigned to the dataset whose + pipeline is None in datasets; + If ``list[list[dict]]``, pipeline of dataset which is None + in datasets will be replaced by the corresponding pipeline + in the list. + force_apply (bool): If True, apply pipeline above to each dataset + even if it have its own pipeline. Default: False. + """ + + def __init__(self, + datasets, + separate_eval=True, + show_mean_scores='auto', + pipeline=None, + force_apply=False, + **kwargs): + new_datasets = [] + if pipeline is not None: + assert isinstance( + pipeline, + list), 'pipeline must be list[dict] or list[list[dict]].' + if is_type_list(pipeline, dict): + self._apply_pipeline(datasets, pipeline, force_apply) + new_datasets = datasets + elif is_2dlist(pipeline): + assert is_2dlist(datasets) + assert len(datasets) == len(pipeline) + for sub_datasets, tmp_pipeline in zip(datasets, pipeline): + self._apply_pipeline(sub_datasets, tmp_pipeline, + force_apply) + new_datasets.extend(sub_datasets) + else: + if is_2dlist(datasets): + for sub_datasets in datasets: + new_datasets.extend(sub_datasets) + else: + new_datasets = datasets + datasets = [build_dataset(c, kwargs) for c in new_datasets] + super().__init__(datasets, separate_eval) + + if not separate_eval: + raise NotImplementedError( + 'Evaluating datasets as a whole is not' + ' supported yet. Please use "separate_eval=True"') + + assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto' + if show_mean_scores == 'auto': + show_mean_scores = len(self.datasets) > 1 + self.show_mean_scores = show_mean_scores + if show_mean_scores is True or show_mean_scores == 'auto' and len( + self.datasets) > 1: + if len(set([type(ds) for ds in self.datasets])) != 1: + raise NotImplementedError( + 'To compute mean evaluation scores, all datasets' + 'must have the same type') + + @staticmethod + def _apply_pipeline(datasets, pipeline, force_apply=False): + from_cfg = all(isinstance(x, dict) for x in datasets) + assert from_cfg, 'datasets should be config dicts' + assert all(isinstance(x, dict) for x in pipeline) + for dataset in datasets: + if dataset['pipeline'] is None or force_apply: + dataset['pipeline'] = copy.deepcopy(pipeline) + + def evaluate(self, results, logger=None, **kwargs): + """Evaluate the results. + + Args: + results (list[list | tuple]): Testing results of the dataset. 
+ logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str: float]: Results of each separate + dataset if `self.separate_eval=True`. + """ + assert len(results) == self.cumulative_sizes[-1], \ + ('Dataset and results have different sizes: ' + f'{self.cumulative_sizes[-1]} v.s. {len(results)}') + + # Check whether all the datasets support evaluation + for dataset in self.datasets: + assert hasattr(dataset, 'evaluate'), \ + f'{type(dataset)} does not implement evaluate function' + + if self.separate_eval: + dataset_idx = -1 + + total_eval_results = dict() + + if self.show_mean_scores: + mean_eval_results = defaultdict(list) + + for dataset in self.datasets: + start_idx = 0 if dataset_idx == -1 else \ + self.cumulative_sizes[dataset_idx] + end_idx = self.cumulative_sizes[dataset_idx + 1] + + results_per_dataset = results[start_idx:end_idx] + print_log( + f'\nEvaluating {dataset.ann_file} with ' + f'{len(results_per_dataset)} images now', + logger=logger) + + eval_results_per_dataset = dataset.evaluate( + results_per_dataset, logger=logger, **kwargs) + dataset_idx += 1 + for k, v in eval_results_per_dataset.items(): + total_eval_results.update({f'{dataset_idx}_{k}': v}) + if self.show_mean_scores: + mean_eval_results[k].append(v) + + if self.show_mean_scores: + for k, v in mean_eval_results.items(): + total_eval_results[f'mean_{k}'] = np.mean(v) + + return total_eval_results + else: + raise NotImplementedError( + 'Evaluating datasets as a whole is not' + ' supported yet. Please use "separate_eval=True"') diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/__init__.py new file mode 100755 index 00000000..f2fc30a5 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .loader import AnnFileLoader, HardDiskLoader, LmdbLoader +from .parser import LineJsonParser, LineStrParser + +__all__ = [ + 'HardDiskLoader', 'LmdbLoader', 'AnnFileLoader', 'LineStrParser', + 'LineJsonParser' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/backend.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/backend.py new file mode 100755 index 00000000..f72b6f63 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/backend.py @@ -0,0 +1,198 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import os.path as osp +import shutil +import warnings + +import dbnet_cv + +from dbnet import digit_version +from dbnet.utils import list_from_file + + +class LmdbAnnFileBackend: + """Lmdb storage backend for annotation file. + + Args: + lmdb_path (str): Lmdb file path. + """ + + def __init__(self, lmdb_path, encoding='utf8'): + """Currently we support two lmdb formats, one is the lmdb file with + only labels generated by txt2lmdb (deprecated), and one is the lmdb + file generated by recog2lmdb. + + The former stores string in 'filename text' format directly in lmdb, + while the latter uses a more reasonable image_key as well as label_key + for querying. 
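+
+        Illustrative key layouts (derived from the reading code below):
+
+        - deprecated txt2lmdb format: ``total_number`` holds the line count and
+          each line is stored under its stringified index as a space-separated
+          string ('filename text' or 'filename height width annotations').
+        - recog2lmdb format: ``num-samples`` holds the sample count, with
+          per-sample keys such as ``label-000000001`` and, when images are
+          stored as well, ``image-000000001``.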
+ """ + self.lmdb_path = lmdb_path + self.encoding = encoding + self.deprecated_format = False + env = self._get_env() + with env.begin(write=False) as txn: + try: + self.total_number = int( + txn.get('num-samples'.encode('utf-8')).decode( + self.encoding)) + except AttributeError: + warnings.warn( + 'DeprecationWarning: The lmdb dataset generated with ' + 'txt2lmdb will be deprecate, please use the latest ' + 'tools/data/utils/recog2lmdb to generate lmdb dataset. ' + 'See https://mmocr.readthedocs.io/en/latest/tools.html#' + 'convert-text-recognition-dataset-to-lmdb-format for ' + 'details.') + self.total_number = int( + txn.get('total_number'.encode('utf-8')).decode( + self.encoding)) + self.deprecated_format = True + # The lmdb file may contain only the label, or it may contain both + # the label and the image, so we use image_key here for probing. + image_key = f'image-{1:09d}' + if txn.get(image_key.encode(encoding)) is None: + self.label_only = True + else: + self.label_only = False + + def __getitem__(self, index): + """Retrieve one line from lmdb file by index. + + In order to support space + reading, the returned lines are in the form of json, such as + '{'filename': 'image1.jpg' ,'text':'HELLO'}' + """ + if not hasattr(self, 'env'): + self.env = self._get_env() + + with self.env.begin(write=False) as txn: + if self.deprecated_format: + line = txn.get(str(index).encode('utf-8')).decode( + self.encoding) + keys = line.strip('/n').split(' ') + if len(keys) == 4: + filename, height, width, annotations = keys + line = json.dumps( + dict( + filename=filename, + height=height, + width=width, + annotations=annotations), + ensure_ascii=False) + elif len(keys) == 2: + filename, text = keys + line = json.dumps( + dict(filename=filename, text=text), ensure_ascii=False) + else: + index = index + 1 + label_key = f'label-{index:09d}' + if self.label_only: + line = txn.get(label_key.encode('utf-8')).decode( + self.encoding) + else: + img_key = f'image-{index:09d}' + text = txn.get(label_key.encode('utf-8')).decode( + self.encoding) + line = json.dumps( + dict(filename=img_key, text=text), ensure_ascii=False) + return line + + def __len__(self): + return self.total_number + + def _get_env(self): + try: + import lmdb + except ImportError: + raise ImportError( + 'Please install lmdb to enable LmdbAnnFileBackend.') + return lmdb.open( + self.lmdb_path, + max_readers=1, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + + def close(self): + self.env.close() + + +class HardDiskAnnFileBackend: + """Load annotation file with raw hard disks storage backend.""" + + def __init__(self, file_format='txt'): + assert file_format in ['txt', 'lmdb'] + self.file_format = file_format + + def __call__(self, ann_file): + if self.file_format == 'lmdb': + return LmdbAnnFileBackend(ann_file) + + return list_from_file(ann_file) + + +class PetrelAnnFileBackend: + """Load annotation file with petrel storage backend.""" + + def __init__(self, file_format='txt', save_dir='tmp_dir'): + assert file_format in ['txt', 'lmdb'] + self.file_format = file_format + self.save_dir = save_dir + + def __call__(self, ann_file): + file_client = dbnet_cv.FileClient(backend='petrel') + + if self.file_format == 'lmdb': + dbnet_cv_version = digit_version(dbnet_cv.__version__) + if dbnet_cv_version < digit_version('1.3.16'): + raise Exception('Please update dbnet_cv to 1.3.16 or higher ' + 'to enable "get_local_path" of "FileClient".') + assert file_client.isdir(ann_file) + files = file_client.list_dir_or_file(ann_file) + 
+ ann_file_rel_path = ann_file.split('s3://')[-1] + ann_file_dir = osp.dirname(ann_file_rel_path) + ann_file_name = osp.basename(ann_file_rel_path) + local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name) + if osp.exists(local_dir): + warnings.warn( + f'local_ann_file: {local_dir} is already existed and ' + 'will be used. If it is not the correct ann_file ' + 'corresponding to {ann_file}, please remove it or ' + 'change "save_dir" first then try again.') + else: + os.makedirs(local_dir, exist_ok=True) + print(f'Fetching {ann_file} to {local_dir}...') + for each_file in files: + tmp_file_path = file_client.join_path(ann_file, each_file) + with file_client.get_local_path( + tmp_file_path) as local_path: + shutil.copy(local_path, osp.join(local_dir, each_file)) + + return LmdbAnnFileBackend(local_dir) + + lines = str(file_client.get(ann_file), encoding='utf-8').split('\n') + + return [x for x in lines if x.strip() != ''] + + +class HTTPAnnFileBackend: + """Load annotation file with http storage backend.""" + + def __init__(self, file_format='txt'): + assert file_format in ['txt', 'lmdb'] + self.file_format = file_format + + def __call__(self, ann_file): + file_client = dbnet_cv.FileClient(backend='http') + + if self.file_format == 'lmdb': + raise NotImplementedError( + 'Loading lmdb file on http is not supported yet.') + + lines = str(file_client.get(ann_file), encoding='utf-8').split('\n') + + return [x for x in lines if x.strip() != ''] diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/loader.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/loader.py new file mode 100755 index 00000000..dff51c6c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/loader.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from dbnet.datasets.builder import LOADERS, build_parser +from .backend import (HardDiskAnnFileBackend, HTTPAnnFileBackend, + PetrelAnnFileBackend) + + +@LOADERS.register_module() +class AnnFileLoader: + """Annotation file loader to load annotations from ann_file, and parse raw + annotation to dict format with certain parser. + + Args: + ann_file (str): Annotation file path. + parser (dict): Dictionary to construct parser + to parse original annotation infos. + repeat (int|float): Repeated times of dataset. + file_storage_backend (str): The storage backend type for annotation + file. Options are "disk", "http" and "petrel". Default: "disk". + file_format (str): The format of annotation file. Options are + "txt" and "lmdb". Default: "txt". + """ + + _backends = { + 'disk': HardDiskAnnFileBackend, + 'petrel': PetrelAnnFileBackend, + 'http': HTTPAnnFileBackend + } + + def __init__(self, + ann_file, + parser, + repeat=1, + file_storage_backend='disk', + file_format='txt', + **kwargs): + assert isinstance(ann_file, str) + assert isinstance(repeat, (int, float)) + assert isinstance(parser, dict) + assert repeat > 0 + assert file_storage_backend in ['disk', 'http', 'petrel'] + assert file_format in ['txt', 'lmdb'] + + if file_format == 'lmdb' and parser['type'] == 'LineStrParser': + raise ValueError('We only support using LineJsonParser ' + 'to parse lmdb file. 
Please use LineJsonParser ' + 'in the dataset config') + self.parser = build_parser(parser) + self.repeat = repeat + self.ann_file_backend = self._backends[file_storage_backend]( + file_format, **kwargs) + self.ori_data_infos = self._load(ann_file) + + def __len__(self): + return int(len(self.ori_data_infos) * self.repeat) + + def _load(self, ann_file): + """Load annotation file.""" + + return self.ann_file_backend(ann_file) + + def __getitem__(self, index): + """Retrieve anno info of one instance with dict format.""" + return self.parser.get_item(self.ori_data_infos, index) + + def __iter__(self): + self._n = 0 + return self + + def __next__(self): + if self._n < len(self): + data = self[self._n] + self._n += 1 + return data + raise StopIteration + + def close(self): + """For ann_file with lmdb format only.""" + self.ori_data_infos.close() + + +@LOADERS.register_module() +class HardDiskLoader(AnnFileLoader): + """Load txt format annotation file from hard disks.""" + + def __init__(self, ann_file, parser, repeat=1): + warnings.warn( + 'HardDiskLoader is deprecated, please use ' + 'AnnFileLoader instead.', UserWarning) + super().__init__( + ann_file, + parser, + repeat, + file_storage_backend='disk', + file_format='txt') + + +@LOADERS.register_module() +class LmdbLoader(AnnFileLoader): + """Load lmdb format annotation file from hard disks.""" + + def __init__(self, ann_file, parser, repeat=1): + warnings.warn( + 'LmdbLoader is deprecated, please use ' + 'AnnFileLoader instead.', UserWarning) + super().__init__( + ann_file, + parser, + repeat, + file_storage_backend='disk', + file_format='lmdb') diff --git a/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/parser.py b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/parser.py new file mode 100755 index 00000000..83abb3dc --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/datasets/utils/parser.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import warnings + +from dbnet.datasets.builder import PARSERS +from dbnet.utils import StringStrip + + +@PARSERS.register_module() +class LineStrParser: + """Parse string of one line in annotation file to dict format. + + Args: + keys (list[str]): Keys in result dict. + keys_idx (list[int]): Value index in sub-string list + for each key above. + separator (str): Separator to separate string to list of sub-string. + """ + + def __init__(self, + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ', + **kwargs): + assert isinstance(keys, list) + assert isinstance(keys_idx, list) + assert isinstance(separator, str) + assert len(keys) > 0 + assert len(keys) == len(keys_idx) + self.keys = keys + self.keys_idx = keys_idx + self.separator = separator + self.strip_cls = StringStrip(**kwargs) + + def get_item(self, data_ret, index): + map_index = index % len(data_ret) + line_str = data_ret[map_index] + line_str = self.strip_cls(line_str) + if len(line_str.split(' ')) > 2: + msg = 'More than two blank spaces were detected. ' + msg += 'Please use LineJsonParser to handle ' + msg += 'annotations with blanks. ' + msg += 'Check Doc ' + msg += 'https://mmocr.readthedocs.io/en/latest/' + msg += 'tutorials/blank_recog.html ' + msg += 'for details.' 
+ warnings.warn(msg) + line_str = line_str.split(self.separator) + if len(line_str) <= max(self.keys_idx): + raise Exception( + f'key index: {max(self.keys_idx)} out of range: {line_str}') + + line_info = {} + for i, key in enumerate(self.keys): + line_info[key] = line_str[self.keys_idx[i]] + return line_info + + +@PARSERS.register_module() +class LineJsonParser: + """Parse json-string of one line in annotation file to dict format. + + Args: + keys (list[str]): Keys in both json-string and result dict. + """ + + def __init__(self, keys=[]): + assert isinstance(keys, list) + assert len(keys) > 0 + self.keys = keys + + def get_item(self, data_ret, index): + map_index = index % len(data_ret) + json_str = data_ret[map_index] + line_json_obj = json.loads(json_str) + line_info = {} + for key in self.keys: + if key not in line_json_obj: + raise Exception(f'key {key} not in line json {line_json_obj}') + line_info[key] = line_json_obj[key] + + return line_info diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/__init__.py new file mode 100755 index 00000000..80a0426f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from . import common, textdet +from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS, + HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone, + build_convertor, build_decoder, build_detector, + build_encoder, build_loss, build_preprocessor) +from .common import * # NOQA +# from .kie import * # NOQA +# from .ner import * # NOQA +from .textdet import * # NOQA +# from .textrecog import * # NOQA + +# __all__ = [ +# 'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone', +# 'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS', +# 'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder', +# 'build_preprocessor' +# ] +# __all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__ diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/builder.py b/cv/ocr/dbnet/pytorch/dbnet/models/builder.py new file mode 100755 index 00000000..d5a19636 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/builder.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
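+#
+# Usage sketch (illustrative; `model_cfg` stands for the model dict of a
+# dbnet config file, e.g. one under configs/textdet/dbnet/):
+#
+#     from dbnet.models.builder import build_detector
+#     detector = build_detector(model_cfg)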
+import warnings + +import torch.nn as nn +from dbnet_cv.cnn import ACTIVATION_LAYERS as DBNET_CV_ACTIVATION_LAYERS +from dbnet_cv.cnn import UPSAMPLE_LAYERS as DBNET_CV_UPSAMPLE_LAYERS +from dbnet_cv.utils import Registry, build_from_cfg +from dbnet_det.models.builder import BACKBONES as dbnet_det_BACKBONES + +CONVERTORS = Registry('convertor') +ENCODERS = Registry('encoder') +DECODERS = Registry('decoder') +PREPROCESSOR = Registry('preprocessor') +POSTPROCESSOR = Registry('postprocessor') + +UPSAMPLE_LAYERS = Registry('upsample layer', parent=DBNET_CV_UPSAMPLE_LAYERS) +BACKBONES = Registry('models', parent=dbnet_det_BACKBONES) +LOSSES = BACKBONES +DETECTORS = BACKBONES +ROI_EXTRACTORS = BACKBONES +HEADS = BACKBONES +NECKS = BACKBONES +FUSERS = BACKBONES +RECOGNIZERS = BACKBONES + +ACTIVATION_LAYERS = Registry('activation layer', parent=DBNET_CV_ACTIVATION_LAYERS) + + +def build_recognizer(cfg, train_cfg=None, test_cfg=None): + """Build recognizer.""" + return build_from_cfg(cfg, RECOGNIZERS, + dict(train_cfg=train_cfg, test_cfg=test_cfg)) + + +def build_convertor(cfg): + """Build label convertor for scene text recognizer.""" + return build_from_cfg(cfg, CONVERTORS) + + +def build_encoder(cfg): + """Build encoder for scene text recognizer.""" + return build_from_cfg(cfg, ENCODERS) + + +def build_decoder(cfg): + """Build decoder for scene text recognizer.""" + return build_from_cfg(cfg, DECODERS) + + +def build_preprocessor(cfg): + """Build preprocessor for scene text recognizer.""" + return build_from_cfg(cfg, PREPROCESSOR) + + +def build_postprocessor(cfg): + """Build postprocessor for scene text detector.""" + return build_from_cfg(cfg, POSTPROCESSOR) + + +def build_roi_extractor(cfg): + """Build roi extractor.""" + return ROI_EXTRACTORS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_fuser(cfg): + """Build fuser.""" + return FUSERS.build(cfg) + + +def build_upsample_layer(cfg, *args, **kwargs): + """Build upsample layer. + + Args: + cfg (dict): The upsample layer config, which should contain: + + - type (str): Layer type. + - scale_factor (int): Upsample ratio, which is not applicable to + deconv. + - layer args: Args needed to instantiate a upsample layer. + args (argument list): Arguments passed to the ``__init__`` + method of the corresponding conv layer. + kwargs (keyword arguments): Keyword arguments passed to the + ``__init__`` method of the corresponding conv layer. + + Returns: + nn.Module: Created upsample layer. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + raise KeyError( + f'the cfg dict must contain the key "type", but got {cfg}') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in UPSAMPLE_LAYERS: + raise KeyError(f'Unrecognized upsample type {layer_type}') + else: + upsample = UPSAMPLE_LAYERS.get(layer_type) + + if upsample is nn.Upsample: + cfg_['mode'] = layer_type + layer = upsample(*args, **kwargs, **cfg_) + return layer + + +def build_activation_layer(cfg): + """Build activation layer. + + Args: + cfg (dict): The activation layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate an activation layer. 
+ + Returns: + nn.Module: Created activation layer. + """ + return build_from_cfg(cfg, ACTIVATION_LAYERS) + + +def build_detector(cfg, train_cfg=None, test_cfg=None): + """Build detector.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return DETECTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/__init__.py new file mode 100755 index 00000000..066ee876 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from . import backbones,losses +from .backbones import * # NOQA +from .losses import * # NOQA + +__all__ = backbones.__all__ + losses.__all__ \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/__init__.py new file mode 100755 index 00000000..3c384ba3 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .unet import UNet + +__all__ = ['UNet'] diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/unet.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/unet.py new file mode 100755 index 00000000..2a6d0f8c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/backbones/unet.py @@ -0,0 +1,516 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from dbnet_cv.cnn import ConvModule, build_norm_layer +from dbnet_cv.runner import BaseModule +from dbnet_cv.utils.parrots_wrapper import _BatchNorm + +from dbnet.models.builder import (BACKBONES, UPSAMPLE_LAYERS, + build_activation_layer, build_upsample_layer) + + +class UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. 
+ Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.conv_block = conv_block( + in_channels=2 * skip_channels, + out_channels=out_channels, + num_convs=num_convs, + stride=stride, + dilation=dilation, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None) + if upsample_cfg is not None: + self.upsample = build_upsample_layer( + cfg=upsample_cfg, + in_channels=in_channels, + out_channels=skip_channels, + with_cp=with_cp, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.upsample = ConvModule( + in_channels, + skip_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, skip, x): + """Forward function.""" + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + + return out + + +class BasicConvBlock(nn.Module): + """Basic convolutional block for UNet. + + This module consists of several plain convolutional layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' 
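+        # Build `num_convs` stacked ConvModules: only the first conv applies
+        # `stride` (optional downsampling) with padding/dilation of 1, while
+        # the remaining convs use the requested `dilation` (and matching
+        # padding).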
+ + self.with_cp = with_cp + convs = [] + for i in range(num_convs): + convs.append( + ConvModule( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + dilation=1 if i == 0 else dilation, + padding=1 if i == 0 else dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.convs = nn.Sequential(*convs) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.convs, x) + else: + out = self.convs(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class DeconvModule(nn.Module): + """Deconvolution upsample module in decoder for UNet (2X upsample). + + This module uses deconvolution to upsample feature map in the decoder + of UNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of the convolutional layer. Default: 4. + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + kernel_size=4, + scale_factor=2): + super().__init__() + + assert ( + kernel_size - scale_factor >= 0 + and (kernel_size - scale_factor) % 2 == 0), ( + f'kernel_size should be greater than or equal to scale_factor ' + f'and (kernel_size - scale_factor) should be even numbers, ' + f'while the kernel size is {kernel_size} and scale_factor is ' + f'{scale_factor}.') + + stride = scale_factor + padding = (kernel_size - scale_factor) // 2 + self.with_cp = with_cp + deconv = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + _, norm = build_norm_layer(norm_cfg, out_channels) + activate = build_activation_layer(act_cfg) + self.deconv_upsamping = nn.Sequential(deconv, norm, activate) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.deconv_upsamping, x) + else: + out = self.deconv_upsamping(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class InterpConv(nn.Module): + """Interpolation upsample module in decoder for UNet. + + This module uses interpolation to upsample feature map in the decoder + of UNet. It consists of one interpolation upsample layer and one + convolutional layer. It can be one interpolation upsample layer followed + by one convolutional layer (conv_first=False) or one convolutional layer + followed by one interpolation upsample layer (conv_first=True). + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + conv_first (bool): Whether convolutional layer or interpolation + upsample layer first. Default: False. 
It means interpolation + upsample layer followed by one convolutional layer. + kernel_size (int): Kernel size of the convolutional layer. Default: 1. + stride (int): Stride of the convolutional layer. Default: 1. + padding (int): Padding of the convolutional layer. Default: 1. + upsample_cfg (dict): Interpolation config of the upsample layer. + Default: dict( + scale_factor=2, mode='bilinear', align_corners=False). + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + conv_cfg=None, + conv_first=False, + kernel_size=1, + stride=1, + padding=0, + upsample_cfg=dict( + scale_factor=2, mode='bilinear', align_corners=False)): + super().__init__() + + self.with_cp = with_cp + conv = ConvModule( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + upsample = nn.Upsample(**upsample_cfg) + if conv_first: + self.interp_upsample = nn.Sequential(conv, upsample) + else: + self.interp_upsample = nn.Sequential(upsample, conv) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.interp_upsample, x) + else: + out = self.interp_upsample(x) + return out + + +@BACKBONES.register_module() +class UNet(BaseModule): + """UNet backbone. + U-Net: Convolutional Networks for Biomedical Image Segmentation. + https://arxiv.org/pdf/1505.04597.pdf + + Args: + in_channels (int): Number of input image channels. Default" 3. + base_channels (int): Number of base channels of each stage. + The output channels of the first stage. Default: 64. + num_stages (int): Number of stages in encoder, normally 5. Default: 5. + strides (Sequence[int 1 | 2]): Strides of each stage in encoder. + len(strides) is equal to num_stages. Normally the stride of the + first stage in encoder is 1. If strides[i]=2, it uses stride + convolution to downsample in the correspondence encoder stage. + Default: (1, 1, 1, 1, 1). + enc_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence encoder stage. + Default: (2, 2, 2, 2, 2). + dec_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence decoder stage. + Default: (2, 2, 2, 2). + downsamples (Sequence[int]): Whether use MaxPool to downsample the + feature map after the first stage of encoder + (stages: [1, num_stages)). If the correspondence encoder stage use + stride convolution (strides[i]=2), it will never use MaxPool to + downsample, even downsamples[i-1]=True. + Default: (True, True, True, True). + enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. + Default: (1, 1, 1, 1, 1). + dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. + Default: (1, 1, 1, 1). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). 
Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. + + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + layer=['_BatchNorm', 'GroupNorm'], + val=1) + ]): + super().__init__(init_cfg=init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert len(strides) == num_stages, ( + 'The length of strides should be equal to num_stages, ' + f'while the strides is {strides}, the length of ' + f'strides is {len(strides)}, and the num_stages is ' + f'{num_stages}.') + assert len(enc_num_convs) == num_stages, ( + 'The length of enc_num_convs should be equal to num_stages, ' + f'while the enc_num_convs is {enc_num_convs}, the length of ' + f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is ' + f'{num_stages}.') + assert len(dec_num_convs) == (num_stages - 1), ( + 'The length of dec_num_convs should be equal to (num_stages-1), ' + f'while the dec_num_convs is {dec_num_convs}, the length of ' + f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is ' + f'{num_stages}.') + assert len(downsamples) == (num_stages - 1), ( + 'The length of downsamples should be equal to (num_stages-1), ' + f'while the downsamples is {downsamples}, the length of ' + f'downsamples is {len(downsamples)}, and the num_stages is ' + f'{num_stages}.') + assert len(enc_dilations) == num_stages, ( + 'The length of enc_dilations should be equal to num_stages, ' + f'while the enc_dilations is {enc_dilations}, the length of ' + f'enc_dilations is {len(enc_dilations)}, and the num_stages is ' + f'{num_stages}.') + assert len(dec_dilations) == (num_stages - 1), ( + 'The length of dec_dilations should be equal to (num_stages-1), ' + f'while the dec_dilations is {dec_dilations}, the length of ' + f'dec_dilations is {len(dec_dilations)}, and the num_stages is ' + f'{num_stages}.') + self.num_stages = num_stages + self.strides = strides + self.downsamples = downsamples + self.norm_eval = norm_eval + self.base_channels = base_channels + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for i in range(num_stages): + enc_conv_block = [] + if i != 0: + if strides[i] == 1 and downsamples[i - 1]: + enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) + upsample = (strides[i] != 1 or downsamples[i - 1]) + self.decoder.append( + UpConvBlock( + conv_block=BasicConvBlock, + in_channels=base_channels * 2**i, + skip_channels=base_channels * 2**(i - 1), + out_channels=base_channels * 2**(i - 1), + num_convs=dec_num_convs[i - 1], + stride=1, + dilation=dec_dilations[i - 1], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg if upsample else None, + dcn=None, + plugins=None)) + + 
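+            # Encoder stage i stacks enc_num_convs[i] convolutions with
+            # strides[i]; the channel width doubles each stage
+            # (base_channels * 2**i), and `in_channels` is updated afterwards.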
enc_conv_block.append( + BasicConvBlock( + in_channels=in_channels, + out_channels=base_channels * 2**i, + num_convs=enc_num_convs[i], + stride=strides[i], + dilation=enc_dilations[i], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None)) + self.encoder.append((nn.Sequential(*enc_conv_block))) + in_channels = base_channels * 2**i + + def forward(self, x): + self._check_input_divisible(x) + enc_outs = [] + for enc in self.encoder: + x = enc(x) + enc_outs.append(x) + dec_outs = [x] + for i in reversed(range(len(self.decoder))): + x = self.decoder[i](enc_outs[i], x) + dec_outs.append(x) + + return dec_outs + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _check_input_divisible(self, x): + h, w = x.shape[-2:] + whole_downsample_rate = 1 + for i in range(1, self.num_stages): + if self.strides[i] == 2 or self.downsamples[i - 1]: + whole_downsample_rate *= 2 + assert ( + h % whole_downsample_rate == 0 and w % whole_downsample_rate == 0 + ), (f'The input image size {(h, w)} should be divisible by the whole ' + f'downsample rate {whole_downsample_rate}, when num_stages is ' + f'{self.num_stages}, strides is {self.strides}, and downsamples ' + f'is {self.downsamples}.') diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/__init__.py new file mode 100755 index 00000000..609824a1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .single_stage import SingleStageDetector + +__all__ = ['SingleStageDetector'] diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/single_stage.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/single_stage.py new file mode 100755 index 00000000..dad1d98b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/detectors/single_stage.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from dbnet_det.models.detectors import \ + SingleStageDetector as dbnet_det_SingleStageDetector + +from dbnet.models.builder import (DETECTORS, build_backbone, build_head, + build_neck) + + +@DETECTORS.register_module() +class SingleStageDetector(dbnet_det_SingleStageDetector): + """Base class for single-stage detectors. + + Single-stage detectors directly and densely predict bounding boxes on the + output features of the backbone+neck. 
+ """ + + def __init__(self, + backbone, + neck=None, + bbox_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(dbnet_det_SingleStageDetector, self).__init__(init_cfg=init_cfg) + if pretrained: + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + backbone.pretrained = pretrained + self.backbone = build_backbone(backbone) + if neck is not None: + self.neck = build_neck(neck) + bbox_head.update(train_cfg=train_cfg) + bbox_head.update(test_cfg=test_cfg) + self.bbox_head = build_head(bbox_head) + self.train_cfg = train_cfg + self.test_cfg = test_cfg diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/__init__.py new file mode 100755 index 00000000..b34608a8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dice_loss import DiceLoss +# from .focal_loss import FocalLoss + +# __all__ = ['DiceLoss', 'FocalLoss'] +__all__ = ['DiceLoss'] \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/dice_loss.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/dice_loss.py new file mode 100755 index 00000000..d2ace254 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/dice_loss.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from dbnet.models.builder import LOSSES + + +@LOSSES.register_module() +class DiceLoss(nn.Module): + + def __init__(self, eps=1e-6): + super().__init__() + assert isinstance(eps, float) + self.eps = eps + + def forward(self, pred, target, mask=None): + + pred = pred.contiguous().view(pred.size()[0], -1) + target = target.contiguous().view(target.size()[0], -1) + + if mask is not None: + mask = mask.contiguous().view(mask.size()[0], -1) + pred = pred * mask + target = target * mask + + a = torch.sum(pred * target) + b = torch.sum(pred) + c = torch.sum(target) + d = (2 * a) / (b + c + self.eps) + + return 1 - d diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/focal_loss.py b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/focal_loss.py new file mode 100755 index 00000000..1a42ab01 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/common/losses/focal_loss.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FocalLoss(nn.Module): + """Multi-class Focal loss implementation. + + Args: + gamma (float): The larger the gamma, the smaller + the loss weight of easier samples. + weight (float): A manual rescaling weight given to each + class. + ignore_index (int): Specifies a target value that is ignored + and does not contribute to the input gradient. 
+ """ + + def __init__(self, gamma=2, weight=None, ignore_index=-100): + super().__init__() + self.gamma = gamma + self.weight = weight + self.ignore_index = ignore_index + + def forward(self, input, target): + logit = F.log_softmax(input, dim=1) + pt = torch.exp(logit) + logit = (1 - pt)**self.gamma * logit + loss = F.nll_loss( + logit, target, self.weight, ignore_index=self.ignore_index) + return loss diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/__init__.py new file mode 100755 index 00000000..027e812f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from . import dense_heads, detectors, losses, necks, postprocess +from .dense_heads import * # NOQA +from .detectors import * # NOQA +from .losses import * # NOQA +from .necks import * # NOQA +from .postprocess import * # NOQA + +__all__ = ( + dense_heads.__all__ + detectors.__all__ + losses.__all__ + necks.__all__ + + postprocess.__all__) diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/__init__.py new file mode 100755 index 00000000..fd782a50 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .db_head import DBHead +# from .drrg_head import DRRGHead +# from .fce_head import FCEHead +from .head_mixin import HeadMixin +# from .pan_head import PANHead +# from .pse_head import PSEHead +# from .textsnake_head import TextSnakeHead + + +__all__ = [ + 'DBHead','HeadMixin' +] + +# __all__ = [ +# 'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'TextSnakeHead', 'DRRGHead', +# 'HeadMixin' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/db_head.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/db_head.py new file mode 100755 index 00000000..dc6a6899 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/db_head.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +from dbnet_cv.runner import BaseModule, Sequential + +from dbnet.models.builder import HEADS +from .head_mixin import HeadMixin + + +@HEADS.register_module() +class DBHead(HeadMixin, BaseModule): + """The class for DBNet head. + + This was partially adapted from https://github.com/MhLiao/DB + + Args: + in_channels (int): The number of input channels of the db head. + with_bias (bool): Whether add bias in Conv2d layer. + downsample_ratio (float): The downsample ratio of ground truths. + loss (dict): Config of loss for dbnet. + postprocessor (dict): Config of postprocessor for dbnet. + """ + + def __init__( + self, + in_channels, + with_bias=False, + downsample_ratio=1.0, + loss=dict(type='DBLoss'), + postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'), + init_cfg=[ + dict(type='Kaiming', layer='Conv'), + dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4) + ], + train_cfg=None, + test_cfg=None, + **kwargs): + old_keys = ['text_repr_type', 'decoding_type'] + for key in old_keys: + if kwargs.get(key, None): + postprocessor[key] = kwargs.get(key) + warnings.warn( + f'{key} is deprecated, please specify ' + 'it in postprocessor config dict. 
See ' + 'https://github.com/open-mmlab/mmocr/pull/640' + ' for details.', UserWarning) + BaseModule.__init__(self, init_cfg=init_cfg) + HeadMixin.__init__(self, loss, postprocessor) + + assert isinstance(in_channels, int) + + self.in_channels = in_channels + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.downsample_ratio = downsample_ratio + + self.binarize = Sequential( + nn.Conv2d( + in_channels, in_channels // 4, 3, bias=with_bias, padding=1), + nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2), + nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid()) + + self.threshold = self._init_thr(in_channels) + + def diff_binarize(self, prob_map, thr_map, k): + return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map))) + + def forward(self, inputs): + """ + Args: + inputs (Tensor): Shape (batch_size, hidden_size, h, w). + + Returns: + Tensor: A tensor of the same shape as input. + """ + prob_map = self.binarize(inputs) + thr_map = self.threshold(inputs) + binary_map = self.diff_binarize(prob_map, thr_map, k=50) + outputs = torch.cat((prob_map, thr_map, binary_map), dim=1) + return outputs + + def _init_thr(self, inner_channels, bias=False): + in_channels = inner_channels + + seq = Sequential( + nn.Conv2d( + in_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2), + nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid()) + return seq diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/head_mixin.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/head_mixin.py new file mode 100755 index 00000000..a25cce4a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/dense_heads/head_mixin.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +from dbnet.models.builder import HEADS, build_loss, build_postprocessor +from dbnet.utils import check_argument + + +@HEADS.register_module() +class HeadMixin: + """Base head class for text detection, including loss calcalation and + postprocess. + + Args: + loss (dict): Config to build loss. + postprocessor (dict): Config to build postprocessor. + """ + + def __init__(self, loss, postprocessor): + assert isinstance(loss, dict) + assert isinstance(postprocessor, dict) + + self.loss_module = build_loss(loss) + self.postprocessor = build_postprocessor(postprocessor) + + def resize_boundary(self, boundaries, scale_factor): + """Rescale boundaries via scale_factor. + + Args: + boundaries (list[list[float]]): The boundary list. Each boundary + has :math:`2k+1` elements with :math:`k>=4`. + scale_factor (ndarray): The scale factor of size :math:`(4,)`. + + Returns: + list[list[float]]: The scaled boundaries. + """ + assert check_argument.is_2dlist(boundaries) + assert isinstance(scale_factor, np.ndarray) + assert scale_factor.shape[0] == 4 + + for b in boundaries: + sz = len(b) + check_argument.valid_boundary(b, True) + b[:sz - + 1] = (np.array(b[:sz - 1]) * + (np.tile(scale_factor[:2], int( + (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist() + return boundaries + + def get_boundary(self, score_maps, img_metas, rescale): + """Compute text boundaries via post processing. + + Args: + score_maps (Tensor): The text score map. 
+ img_metas (dict): The image meta info. + rescale (bool): Rescale boundaries to the original image resolution + if true, and keep the score_maps resolution if false. + + Returns: + dict: A dict where boundary results are stored in + ``boundary_result``. + """ + + assert check_argument.is_type_list(img_metas, dict) + assert isinstance(rescale, bool) + + score_maps = score_maps.squeeze() + boundaries = self.postprocessor(score_maps) + + if rescale: + boundaries = self.resize_boundary( + boundaries, + 1.0 / self.downsample_ratio / img_metas[0]['scale_factor']) + + results = dict( + boundary_result=boundaries, filename=img_metas[0]['filename']) + + return results + + def loss(self, pred_maps, **kwargs): + """Compute the loss for scene text detection. + + Args: + pred_maps (Tensor): The input score maps of shape + :math:`(NxCxHxW)`. + + Returns: + dict: The dict for losses. + """ + losses = self.loss_module(pred_maps, self.downsample_ratio, **kwargs) + + return losses diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/__init__.py new file mode 100755 index 00000000..89f0f98e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dbnet import DBNet +# from .drrg import DRRG +# from .fcenet import FCENet +# from .ocr_mask_rcnn import OCRMaskRCNN +# from .panet import PANet +# from .psenet import PSENet +from .single_stage_text_detector import SingleStageTextDetector +from .text_detector_mixin import TextDetectorMixin +# from .textsnake import TextSnake + + +# __all__ = [ +# 'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet', +# 'PANet', 'PSENet', 'TextSnake', 'FCENet', 'DRRG' +# ] +__all__ = [ + 'TextDetectorMixin', 'SingleStageTextDetector', 'DBNet' + ] diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/dbnet.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/dbnet.py new file mode 100755 index 00000000..c598252e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/dbnet.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dbnet.models.builder import DETECTORS +from .single_stage_text_detector import SingleStageTextDetector +from .text_detector_mixin import TextDetectorMixin + + +@DETECTORS.register_module() +class DBNet(TextDetectorMixin, SingleStageTextDetector): + """The class for implementing DBNet text detector: Real-time Scene Text + Detection with Differentiable Binarization. + + [https://arxiv.org/abs/1911.08947]. + """ + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + show_score=False, + init_cfg=None): + SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, + train_cfg, test_cfg, pretrained, + init_cfg) + TextDetectorMixin.__init__(self, show_score) diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/single_stage_text_detector.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/single_stage_text_detector.py new file mode 100755 index 00000000..8b2b28a8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/single_stage_text_detector.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch + +from dbnet.models.builder import DETECTORS +from dbnet.models.common.detectors import SingleStageDetector + + +@DETECTORS.register_module() +class SingleStageTextDetector(SingleStageDetector): + """The class for implementing single stage text detector.""" + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + SingleStageDetector.__init__(self, backbone, neck, bbox_head, + train_cfg, test_cfg, pretrained, init_cfg) + + def forward_train(self, img, img_metas, **kwargs): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A list of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys, see + :class:`dbnet_det.datasets.pipelines.Collect`. + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + x = self.extract_feat(img) + preds = self.bbox_head(x) + losses = self.bbox_head.loss(preds, **kwargs) + return losses + + def simple_test(self, img, img_metas, rescale=False): + x = self.extract_feat(img) + outs = self.bbox_head(x) + + # early return to avoid post processing + if torch.onnx.is_in_onnx_export(): + return outs + + if len(img_metas) > 1: + boundaries = [ + self.bbox_head.get_boundary(*(outs[i].unsqueeze(0)), + [img_metas[i]], rescale) + for i in range(len(img_metas)) + ] + + else: + boundaries = [ + self.bbox_head.get_boundary(*outs, img_metas, rescale) + ] + + return boundaries diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/text_detector_mixin.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/text_detector_mixin.py new file mode 100755 index 00000000..376517a8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/detectors/text_detector_mixin.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import dbnet_cv + +from dbnet.core import imshow_pred_boundary + + +class TextDetectorMixin: + """Base class for text detector, only to show results. + + Args: + show_score (bool): Whether to show text instance score. + """ + + def __init__(self, show_score): + self.show_score = show_score + + def show_result(self, + img, + result, + score_thr=0.5, + bbox_color='green', + text_color='green', + thickness=1, + font_scale=0.5, + win_name='', + show=False, + wait_time=0, + out_file=None): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + result (dict): The results to draw over `img`. + score_thr (float, optional): Minimum score of bboxes to be shown. + Default: 0.3. + bbox_color (str or tuple or :obj:`Color`): Color of bbox lines. + text_color (str or tuple or :obj:`Color`): Color of texts. + thickness (int): Thickness of lines. + font_scale (float): Font scales of texts. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The filename to write the image. 
+ Default: None.imshow_pred_boundary` + """ + img = dbnet_cv.imread(img) + img = img.copy() + boundaries = None + labels = None + if 'boundary_result' in result.keys(): + boundaries = result['boundary_result'] + labels = [0] * len(boundaries) + + # if out_file specified, do not show image in window + if out_file is not None: + show = False + # draw bounding boxes + if boundaries is not None: + imshow_pred_boundary( + img, + boundaries, + labels, + score_thr=score_thr, + boundary_color=bbox_color, + text_color=text_color, + thickness=thickness, + font_scale=font_scale, + win_name=win_name, + show=show, + wait_time=wait_time, + out_file=out_file, + show_score=self.show_score) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, ' + 'result image will be returned') + return img diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/__init__.py new file mode 100755 index 00000000..14e6b5d1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .db_loss import DBLoss +# from .drrg_loss import DRRGLoss +# from .fce_loss import FCELoss +# from .pan_loss import PANLoss +# from .pse_loss import PSELoss +# from .textsnake_loss import TextSnakeLoss + +# __all__ = [ +# 'PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss', 'DRRGLoss' +# ] +__all__ = [ + 'DBLoss'] \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/db_loss.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/db_loss.py new file mode 100755 index 00000000..b9c32ef5 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/losses/db_loss.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from torch import nn + +from dbnet.models.builder import LOSSES +from dbnet.models.common.losses.dice_loss import DiceLoss + + +@LOSSES.register_module() +class DBLoss(nn.Module): + """The class for implementing DBNet loss. + + This is partially adapted from https://github.com/MhLiao/DB. + + Args: + alpha (float): The binary loss coef. + beta (float): The threshold loss coef. + reduction (str): The way to reduce the loss. + negative_ratio (float): The ratio of positives to negatives. + eps (float): Epsilon in the threshold loss function. + bbce_loss (bool): Whether to use balanced bce for probability loss. + If False, dice loss will be used instead. + """ + + def __init__(self, + alpha=1, + beta=1, + reduction='mean', + negative_ratio=3.0, + eps=1e-6, + bbce_loss=False): + super().__init__() + assert reduction in ['mean', + 'sum'], " reduction must in ['mean','sum']" + self.alpha = alpha + self.beta = beta + self.reduction = reduction + self.negative_ratio = negative_ratio + self.eps = eps + self.bbce_loss = bbce_loss + self.dice_loss = DiceLoss(eps=eps) + + def bitmasks2tensor(self, bitmasks, target_sz): + """Convert Bitmasks to tensor. + + Args: + bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is + for one img. + target_sz (tuple(int, int)): The target tensor of size + :math:`(H, W)`. + + Returns: + list[Tensor]: The list of kernel tensors. Each element stands for + one kernel level. 
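+
+ Example (illustrative; ``_FakeBitmapMasks`` below is a stand-in for the
+ real BitmapMasks container, used only to show the expected shapes):
+
+ >>> import numpy as np
+ >>> class _FakeBitmapMasks:
+ ...     def __init__(self, masks):
+ ...         self.masks = masks
+ ...     def __len__(self):
+ ...         return len(self.masks)
+ >>> bitmasks = [_FakeBitmapMasks([np.zeros((4, 6), dtype=np.uint8)])
+ ...             for _ in range(2)]  # batch of 2 images, 1 mask level each
+ >>> DBLoss().bitmasks2tensor(bitmasks, (8, 8))[0].shape
+ torch.Size([2, 8, 8])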
+ """ + assert isinstance(bitmasks, list) + assert isinstance(target_sz, tuple) + + batch_size = len(bitmasks) + num_levels = len(bitmasks[0]) + + result_tensors = [] + + for level_inx in range(num_levels): + kernel = [] + for batch_inx in range(batch_size): + mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx]) + mask_sz = mask.shape + pad = [ + 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0] + ] + mask = F.pad(mask, pad, mode='constant', value=0) + kernel.append(mask) + kernel = torch.stack(kernel) + result_tensors.append(kernel) + + return result_tensors + + def balance_bce_loss(self, pred, gt, mask): + + positive = (gt * mask) + negative = ((1 - gt) * mask) + positive_count = int(positive.float().sum()) + negative_count = min( + int(negative.float().sum()), + int(positive_count * self.negative_ratio)) + + assert gt.max() <= 1 and gt.min() >= 0 + assert pred.max() <= 1 and pred.min() >= 0 + + loss = F.binary_cross_entropy_with_logits(pred, gt, reduction='none') + positive_loss = loss * positive.float() + negative_loss = loss * negative.float() + + negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) + + balance_loss = (positive_loss.sum() + negative_loss.sum()) / ( + positive_count + negative_count + self.eps) + + return balance_loss + + def l1_thr_loss(self, pred, gt, mask): + thr_loss = torch.abs((pred - gt) * mask).sum() / ( + mask.sum() + self.eps) + return thr_loss + + def forward(self, preds, downsample_ratio, gt_shrink, gt_shrink_mask, + gt_thr, gt_thr_mask): + """Compute DBNet loss. + + Args: + preds (Tensor): The output tensor with size :math:`(N, 3, H, W)`. + downsample_ratio (float): The downsample ratio for the + ground truths. + gt_shrink (list[BitmapMasks]): The mask list with each element + being the shrunk text mask for one img. + gt_shrink_mask (list[BitmapMasks]): The effective mask list with + each element being the shrunk effective mask for one img. + gt_thr (list[BitmapMasks]): The mask list with each element + being the threshold text mask for one img. + gt_thr_mask (list[BitmapMasks]): The effective mask list with + each element being the threshold effective mask for one img. + + Returns: + dict: The dict for dbnet losses with "loss_prob", "loss_db" and + "loss_thresh". 
+ """ + assert isinstance(downsample_ratio, float) + + assert isinstance(gt_shrink, list) + assert isinstance(gt_shrink_mask, list) + assert isinstance(gt_thr, list) + assert isinstance(gt_thr_mask, list) + + pred_prob = preds[:, 0, :, :] + pred_thr = preds[:, 1, :, :] + pred_db = preds[:, 2, :, :] + feature_sz = preds.size() + + keys = ['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'] + gt = {} + for k in keys: + gt[k] = eval(k) + gt[k] = [item.rescale(downsample_ratio) for item in gt[k]] + gt[k] = self.bitmasks2tensor(gt[k], feature_sz[2:]) + gt[k] = [item.to(preds.device) for item in gt[k]] + gt['gt_shrink'][0] = (gt['gt_shrink'][0] > 0).float() + if self.bbce_loss: + loss_prob = self.balance_bce_loss(pred_prob, gt['gt_shrink'][0], + gt['gt_shrink_mask'][0]) + else: + loss_prob = self.dice_loss(pred_prob, gt['gt_shrink'][0], + gt['gt_shrink_mask'][0]) + + loss_db = self.dice_loss(pred_db, gt['gt_shrink'][0], + gt['gt_shrink_mask'][0]) + loss_thr = self.l1_thr_loss(pred_thr, gt['gt_thr'][0], + gt['gt_thr_mask'][0]) + + results = dict( + loss_prob=self.alpha * loss_prob, + loss_db=loss_db, + loss_thr=self.beta * loss_thr) + + return results diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/__init__.py new file mode 100755 index 00000000..f0be483e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# from .fpem_ffm import FPEM_FFM +from .fpn_cat import FPNC +# from .fpn_unet import FPN_UNet +# from .fpnf import FPNF + +# __all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNet'] +__all__ = [ 'FPNC'] \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/fpn_cat.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/fpn_cat.py new file mode 100755 index 00000000..93285e1c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/necks/fpn_cat.py @@ -0,0 +1,271 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from dbnet_cv.cnn import ConvModule +from dbnet_cv.runner import BaseModule, ModuleList, Sequential, auto_fp16 + +from dbnet.models.builder import NECKS + + +@NECKS.register_module() +class FPNC(BaseModule): + """FPN-like fusion module in Real-time Scene Text Detection with + Differentiable Binarization. + + This was partially adapted from https://github.com/MhLiao/DB and + https://github.com/WenmuZhou/DBNet.pytorch. + + Args: + in_channels (list[int]): A list of numbers of input channels. + lateral_channels (int): Number of channels for lateral layers. + out_channels (int): Number of output channels. + bias_on_lateral (bool): Whether to use bias on lateral convolutional + layers. + bn_re_on_lateral (bool): Whether to use BatchNorm and ReLU + on lateral convolutional layers. + bias_on_smooth (bool): Whether to use bias on smoothing layer. + bn_re_on_smooth (bool): Whether to use BatchNorm and ReLU on smoothing + layer. + asf_cfg (dict): Adaptive Scale Fusion module configs. The + attention_type can be 'ScaleChannelSpatial'. + conv_after_concat (bool): Whether to add a convolution layer after + the concatenation of predictions. + init_cfg (dict or list[dict], optional): Initialization configs. 
+ """ + + def __init__(self, + in_channels, + lateral_channels=256, + out_channels=64, + bias_on_lateral=False, + bn_re_on_lateral=False, + bias_on_smooth=False, + bn_re_on_smooth=False, + asf_cfg=None, + conv_after_concat=False, + init_cfg=[ + dict(type='Kaiming', layer='Conv'), + dict( + type='Constant', layer='BatchNorm', val=1., bias=1e-4) + ]): + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.lateral_channels = lateral_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.bn_re_on_lateral = bn_re_on_lateral + self.bn_re_on_smooth = bn_re_on_smooth + self.asf_cfg = asf_cfg + self.conv_after_concat = conv_after_concat + self.lateral_convs = ModuleList() + self.smooth_convs = ModuleList() + self.num_outs = self.num_ins + + for i in range(self.num_ins): + norm_cfg = None + act_cfg = None + if self.bn_re_on_lateral: + norm_cfg = dict(type='BN') + act_cfg = dict(type='ReLU') + l_conv = ConvModule( + in_channels[i], + lateral_channels, + 1, + bias=bias_on_lateral, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + norm_cfg = None + act_cfg = None + if self.bn_re_on_smooth: + norm_cfg = dict(type='BN') + act_cfg = dict(type='ReLU') + + smooth_conv = ConvModule( + lateral_channels, + out_channels, + 3, + bias=bias_on_smooth, + padding=1, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.smooth_convs.append(smooth_conv) + + if self.asf_cfg is not None: + self.asf_conv = ConvModule( + out_channels * self.num_outs, + out_channels * self.num_outs, + 3, + padding=1, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + inplace=False) + if self.asf_cfg['attention_type'] == 'ScaleChannelSpatial': + self.asf_attn = ScaleChannelSpatialAttention( + self.out_channels * self.num_outs, + (self.out_channels * self.num_outs) // 4, self.num_outs) + else: + raise NotImplementedError + + if self.conv_after_concat: + norm_cfg = dict(type='BN') + act_cfg = dict(type='ReLU') + self.out_conv = ConvModule( + out_channels * self.num_outs, + out_channels * self.num_outs, + 3, + padding=1, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + @auto_fp16() + def forward(self, inputs): + """ + Args: + inputs (list[Tensor]): Each tensor has the shape of + :math:`(N, C_i, H_i, W_i)`. It usually expects 4 tensors + (C2-C5 features) from ResNet. + + Returns: + Tensor: A tensor of shape :math:`(N, C_{out}, H_0, W_0)` where + :math:`C_{out}` is ``out_channels``. 
+ """ + assert len(inputs) == len(self.in_channels) + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + used_backbone_levels = len(laterals) + # build top-down path + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, mode='nearest') + # build outputs + # part 1: from original levels + outs = [ + self.smooth_convs[i](laterals[i]) + for i in range(used_backbone_levels) + ] + + for i, out in enumerate(outs): + outs[i] = F.interpolate( + outs[i], size=outs[0].shape[2:], mode='nearest') + + out = torch.cat(outs, dim=1) + if self.asf_cfg is not None: + asf_feature = self.asf_conv(out) + attention = self.asf_attn(asf_feature) + enhanced_feature = [] + for i, out in enumerate(outs): + enhanced_feature.append(attention[:, i:i + 1] * outs[i]) + out = torch.cat(enhanced_feature, dim=1) + + if self.conv_after_concat: + out = self.out_conv(out) + + return out + + +class ScaleChannelSpatialAttention(BaseModule): + """Spatial Attention module in Real-Time Scene Text Detection with + Differentiable Binarization and Adaptive Scale Fusion. + + This was partially adapted from https://github.com/MhLiao/DB + + Args: + in_channels (int): A numbers of input channels. + c_wise_channels (int): Number of channel-wise attention channels. + out_channels (int): Number of output channels. + init_cfg (dict or list[dict], optional): Initialization configs. + """ + + def __init__(self, + in_channels, + c_wise_channels, + out_channels, + init_cfg=[dict(type='Kaiming', layer='Conv', bias=0)]): + super().__init__(init_cfg=init_cfg) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # Channel Wise + self.channel_wise = Sequential( + ConvModule( + in_channels, + c_wise_channels, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + inplace=False), + ConvModule( + c_wise_channels, + in_channels, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='Sigmoid'), + inplace=False)) + # Spatial Wise + self.spatial_wise = Sequential( + ConvModule( + 1, + 1, + 3, + padding=1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + inplace=False), + ConvModule( + 1, + 1, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='Sigmoid'), + inplace=False)) + # Attention Wise + self.attention_wise = ConvModule( + in_channels, + out_channels, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='Sigmoid'), + inplace=False) + + @auto_fp16() + def forward(self, inputs): + """ + Args: + inputs (Tensor): A concat FPN feature tensor that has the shape of + :math:`(N, C, H, W)`. + + Returns: + Tensor: An attention map of shape :math:`(N, C_{out}, H, W)` + where :math:`C_{out}` is ``out_channels``. + """ + out = self.avg_pool(inputs) + out = self.channel_wise(out) + out = out + inputs + inputs = torch.mean(out, dim=1, keepdim=True) + out = self.spatial_wise(inputs) + out + out = self.attention_wise(out) + + return out diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/__init__.py new file mode 100755 index 00000000..dcfc2e69 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .base_postprocessor import BasePostprocessor +from .db_postprocessor import DBPostprocessor +# from .drrg_postprocessor import DRRGPostprocessor +# from .fce_postprocessor import FCEPostprocessor +# from .pan_postprocessor import PANPostprocessor +# from .pse_postprocessor import PSEPostprocessor +# from .textsnake_postprocessor import TextSnakePostprocessor + + +# __all__ = [ +# 'BasePostprocessor', 'PSEPostprocessor', 'PANPostprocessor', +# 'DBPostprocessor', 'DRRGPostprocessor', 'FCEPostprocessor', +# 'TextSnakePostprocessor' +# ] +__all__ = [ + 'BasePostprocessor', + 'DBPostprocessor' ] diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/base_postprocessor.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/base_postprocessor.py new file mode 100755 index 00000000..734f87b6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/base_postprocessor.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. + + +class BasePostprocessor: + + def __init__(self, text_repr_type='poly'): + assert text_repr_type in ['poly', 'quad' + ], f'Invalid text repr type {text_repr_type}' + + self.text_repr_type = text_repr_type + + def is_valid_instance(self, area, confidence, area_thresh, + confidence_thresh): + + return bool(area >= area_thresh and confidence > confidence_thresh) diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/db_postprocessor.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/db_postprocessor.py new file mode 100755 index 00000000..3996facf --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/db_postprocessor.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + +from dbnet.core import points2boundary +from dbnet.models.builder import POSTPROCESSOR +from .base_postprocessor import BasePostprocessor +from .utils import box_score_fast, unclip + + +@POSTPROCESSOR.register_module() +class DBPostprocessor(BasePostprocessor): + """Decoding predictions of DbNet to instances. This is partially adapted + from https://github.com/MhLiao/DB. + + Args: + text_repr_type (str): The boundary encoding type 'poly' or 'quad'. + mask_thr (float): The mask threshold value for binarization. + min_text_score (float): The threshold value for converting binary map + to shrink text regions. + min_text_width (int): The minimum width of boundary polygon/box + predicted. + unclip_ratio (float): The unclip ratio for text regions dilation. + epsilon_ratio (float): The epsilon ratio for approximation accuracy. + max_candidates (int): The maximum candidate number. + """ + + def __init__(self, + text_repr_type='poly', + mask_thr=0.3, + min_text_score=0.3, + min_text_width=5, + unclip_ratio=1.5, + epsilon_ratio=0.01, + max_candidates=3000, + **kwargs): + super().__init__(text_repr_type) + self.mask_thr = mask_thr + self.min_text_score = min_text_score + self.min_text_width = min_text_width + self.unclip_ratio = unclip_ratio + self.epsilon_ratio = epsilon_ratio + self.max_candidates = max_candidates + + def __call__(self, preds): + """ + Args: + preds (Tensor): Prediction map with shape :math:`(C, H, W)`. + + Returns: + list[list[float]]: The predicted text boundaries. 
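+
+ Example (a rough sketch of the expected input/output types; a random
+ map will typically produce few or no boundaries):
+
+ >>> import torch
+ >>> postprocessor = DBPostprocessor(text_repr_type='poly')
+ >>> preds = torch.rand(3, 160, 160)  # (prob, thr, binary) maps
+ >>> boundaries = postprocessor(preds)
+ >>> # each boundary is a flattened polygon followed by its score:
+ >>> # [x1, y1, x2, y2, ..., xk, yk, score]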
+ """ + assert preds.dim() == 3 + + prob_map = preds[0, :, :] + text_mask = prob_map > self.mask_thr + + score_map = prob_map.data.cpu().numpy().astype(np.float32) + text_mask = text_mask.data.cpu().numpy().astype(np.uint8) # to numpy + + contours, _ = cv2.findContours((text_mask * 255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + boundaries = [] + for i, poly in enumerate(contours): + if i > self.max_candidates: + break + epsilon = self.epsilon_ratio * cv2.arcLength(poly, True) + approx = cv2.approxPolyDP(poly, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + score = box_score_fast(score_map, points) + if score < self.min_text_score: + continue + poly = unclip(points, unclip_ratio=self.unclip_ratio) + if len(poly) == 0 or isinstance(poly[0], list): + continue + poly = poly.reshape(-1, 2) + + if self.text_repr_type == 'quad': + poly = points2boundary(poly, self.text_repr_type, score, + self.min_text_width) + elif self.text_repr_type == 'poly': + poly = poly.flatten().tolist() + if score is not None: + poly = poly + [score] + if len(poly) < 8: + poly = None + + if poly is not None: + boundaries.append(poly) + + return boundaries diff --git a/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/utils.py b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/utils.py new file mode 100755 index 00000000..50fbaec1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/models/textdet/postprocess/utils.py @@ -0,0 +1,482 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import operator + +import cv2 +import numpy as np +import pyclipper +from numpy.fft import ifft +from numpy.linalg import norm +from shapely.geometry import Polygon + +from dbnet.core.evaluation.utils import boundary_iou + + +def filter_instance(area, confidence, min_area, min_confidence): + return bool(area < min_area or confidence < min_confidence) + + +def box_score_fast(bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + +def unclip(box, unclip_ratio=1.5): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + +def fill_hole(input_mask): + h, w = input_mask.shape + canvas = np.zeros((h + 2, w + 2), np.uint8) + canvas[1:h + 1, 1:w + 1] = input_mask.copy() + + mask = np.zeros((h + 4, w + 4), np.uint8) + + cv2.floodFill(canvas, mask, (0, 0), 1) + canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool) + + return ~canvas | input_mask + + +def centralize(points_yx, + normal_sin, + normal_cos, + radius, + contour_mask, + step_ratio=0.03): + + h, w = contour_mask.shape + top_yx = bot_yx = points_yx + step_flags = np.ones((len(points_yx), 1), dtype=np.bool) + step = step_ratio * radius * np.hstack([normal_sin, normal_cos]) + while np.any(step_flags): + next_yx = np.array(top_yx + step, dtype=np.int32) + next_y, next_x = 
next_yx[:, 0], next_yx[:, 1] + step_flags = (next_y >= 0) & (next_y < h) & (next_x > 0) & ( + next_x < w) & contour_mask[np.clip(next_y, 0, h - 1), + np.clip(next_x, 0, w - 1)] + top_yx = top_yx + step_flags.reshape((-1, 1)) * step + step_flags = np.ones((len(points_yx), 1), dtype=np.bool) + while np.any(step_flags): + next_yx = np.array(bot_yx - step, dtype=np.int32) + next_y, next_x = next_yx[:, 0], next_yx[:, 1] + step_flags = (next_y >= 0) & (next_y < h) & (next_x > 0) & ( + next_x < w) & contour_mask[np.clip(next_y, 0, h - 1), + np.clip(next_x, 0, w - 1)] + bot_yx = bot_yx - step_flags.reshape((-1, 1)) * step + centers = np.array((top_yx + bot_yx) * 0.5, dtype=np.int32) + return centers + + +def merge_disks(disks, disk_overlap_thr): + xy = disks[:, 0:2] + radius = disks[:, 2] + scores = disks[:, 3] + order = scores.argsort()[::-1] + + merged_disks = [] + while order.size > 0: + if order.size == 1: + merged_disks.append(disks[order]) + break + i = order[0] + d = norm(xy[i] - xy[order[1:]], axis=1) + ri = radius[i] + r = radius[order[1:]] + d_thr = (ri + r) * disk_overlap_thr + + merge_inds = np.where(d <= d_thr)[0] + 1 + if merge_inds.size > 0: + merge_order = np.hstack([i, order[merge_inds]]) + merged_disks.append(np.mean(disks[merge_order], axis=0)) + else: + merged_disks.append(disks[i]) + + inds = np.where(d > d_thr)[0] + 1 + order = order[inds] + merged_disks = np.vstack(merged_disks) + + return merged_disks + + +def poly_nms(polygons, threshold): + assert isinstance(polygons, list) + + polygons = np.array(sorted(polygons, key=lambda x: x[-1])) + + keep_poly = [] + index = [i for i in range(polygons.shape[0])] + + while len(index) > 0: + keep_poly.append(polygons[index[-1]].tolist()) + A = polygons[index[-1]][:-1] + index = np.delete(index, -1) + + iou_list = np.zeros((len(index), )) + for i in range(len(index)): + B = polygons[index[i]][:-1] + + iou_list[i] = boundary_iou(A, B, 1) + remove_index = np.where(iou_list > threshold) + index = np.delete(index, remove_index) + + return keep_poly + + +def fourier2poly(fourier_coeff, num_reconstr_points=50): + """ Inverse Fourier transform + Args: + fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1), + with n and k being candidates number and Fourier degree + respectively. + num_reconstr_points (int): Number of reconstructed polygon points. + Returns: + Polygons (ndarray): The reconstructed polygons shaped (n, n') + """ + + a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex') + k = (len(fourier_coeff[0]) - 1) // 2 + + a[:, 0:k + 1] = fourier_coeff[:, k:] + a[:, -k:] = fourier_coeff[:, :k] + + poly_complex = ifft(a) * num_reconstr_points + polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2)) + polygon[:, :, 0] = poly_complex.real + polygon[:, :, 1] = poly_complex.imag + return polygon.astype('int32').reshape((len(fourier_coeff), -1)) + + +class Node: + + def __init__(self, ind): + self.__ind = ind + self.__links = set() + + @property + def ind(self): + return self.__ind + + @property + def links(self): + return set(self.__links) + + def add_link(self, link_node): + self.__links.add(link_node) + link_node.__links.add(self) + + +def graph_propagation(edges, scores, text_comps, edge_len_thr=50.): + """Propagate edge score information and construct graph. This code was + partially adapted from https://github.com/GXYM/DRRG licensed under the MIT + license. + + Args: + edges (ndarray): The edge array of shape N * 2, each row is a node + index pair that makes up an edge in graph. 
+ scores (ndarray): The edge score array. + text_comps (ndarray): The text components. + edge_len_thr (float): The edge length threshold. + + Returns: + vertices (list[Node]): The Nodes in graph. + score_dict (dict): The edge score dict. + """ + assert edges.ndim == 2 + assert edges.shape[1] == 2 + assert edges.shape[0] == scores.shape[0] + assert text_comps.ndim == 2 + assert isinstance(edge_len_thr, float) + + edges = np.sort(edges, axis=1) + score_dict = {} + for i, edge in enumerate(edges): + if text_comps is not None: + box1 = text_comps[edge[0], :8].reshape(4, 2) + box2 = text_comps[edge[1], :8].reshape(4, 2) + center1 = np.mean(box1, axis=0) + center2 = np.mean(box2, axis=0) + distance = norm(center1 - center2) + if distance > edge_len_thr: + scores[i] = 0 + if (edge[0], edge[1]) in score_dict: + score_dict[edge[0], edge[1]] = 0.5 * ( + score_dict[edge[0], edge[1]] + scores[i]) + else: + score_dict[edge[0], edge[1]] = scores[i] + + nodes = np.sort(np.unique(edges.flatten())) + mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int) + mapping[nodes] = np.arange(nodes.shape[0]) + order_inds = mapping[edges] + vertices = [Node(node) for node in nodes] + for ind in order_inds: + vertices[ind[0]].add_link(vertices[ind[1]]) + + return vertices, score_dict + + +def connected_components(nodes, score_dict, link_thr): + """Conventional connected components searching. This code was partially + adapted from https://github.com/GXYM/DRRG licensed under the MIT license. + + Args: + nodes (list[Node]): The list of Node objects. + score_dict (dict): The edge score dict. + link_thr (float): The link threshold. + + Returns: + clusters (List[list[Node]]): The clustered Node objects. + """ + assert isinstance(nodes, list) + assert all([isinstance(node, Node) for node in nodes]) + assert isinstance(score_dict, dict) + assert isinstance(link_thr, float) + + clusters = [] + nodes = set(nodes) + while nodes: + node = nodes.pop() + cluster = {node} + node_queue = [node] + while node_queue: + node = node_queue.pop(0) + neighbors = set([ + neighbor for neighbor in node.links if + score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr + ]) + neighbors.difference_update(cluster) + nodes.difference_update(neighbors) + cluster.update(neighbors) + node_queue.extend(neighbors) + clusters.append(list(cluster)) + return clusters + + +def clusters2labels(clusters, num_nodes): + """Convert clusters of Node to text component labels. This code was + partially adapted from https://github.com/GXYM/DRRG licensed under the MIT + license. + + Args: + clusters (List[list[Node]]): The clusters of Node objects. + num_nodes (int): The total node number of graphs in an image. + + Returns: + node_labels (ndarray): The node label array. + """ + assert isinstance(clusters, list) + assert all([isinstance(cluster, list) for cluster in clusters]) + assert all( + [isinstance(node, Node) for cluster in clusters for node in cluster]) + assert isinstance(num_nodes, int) + + node_labels = np.zeros(num_nodes) + for cluster_ind, cluster in enumerate(clusters): + for node in cluster: + node_labels[node.ind] = cluster_ind + return node_labels + + +def remove_single(text_comps, comp_pred_labels): + """Remove isolated text components. This code was partially adapted from + https://github.com/GXYM/DRRG licensed under the MIT license. + + Args: + text_comps (ndarray): The text components. + comp_pred_labels (ndarray): The clustering labels of text components. 
+ + Returns: + filtered_text_comps (ndarray): The text components with isolated ones + removed. + comp_pred_labels (ndarray): The clustering labels with labels of + isolated text components removed. + """ + assert text_comps.ndim == 2 + assert text_comps.shape[0] == comp_pred_labels.shape[0] + + single_flags = np.zeros_like(comp_pred_labels) + pred_labels = np.unique(comp_pred_labels) + for label in pred_labels: + current_label_flag = (comp_pred_labels == label) + if np.sum(current_label_flag) == 1: + single_flags[np.where(current_label_flag)[0][0]] = 1 + keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]] + filtered_text_comps = text_comps[keep_ind, :] + filtered_labels = comp_pred_labels[keep_ind] + + return filtered_text_comps, filtered_labels + + +def norm2(point1, point2): + return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5 + + +def min_connect_path(points): + """Find the shortest path to traverse all points. This code was partially + adapted from https://github.com/GXYM/DRRG licensed under the MIT license. + + Args: + points(List[list[int]]): The point sequence [[x0, y0], [x1, y1], ...]. + + Returns: + shortest_path(List[list[int]]): The shortest index path. + """ + assert isinstance(points, list) + assert all([isinstance(point, list) for point in points]) + assert all([isinstance(coord, int) for point in points for coord in point]) + + points_queue = points.copy() + shortest_path = [] + current_edge = [[], []] + + edge_dict0 = {} + edge_dict1 = {} + current_edge[0] = points_queue[0] + current_edge[1] = points_queue[0] + points_queue.remove(points_queue[0]) + while points_queue: + for point in points_queue: + length0 = norm2(point, current_edge[0]) + edge_dict0[length0] = [point, current_edge[0]] + length1 = norm2(current_edge[1], point) + edge_dict1[length1] = [current_edge[1], point] + key0 = min(edge_dict0.keys()) + key1 = min(edge_dict1.keys()) + + if key0 <= key1: + start = edge_dict0[key0][0] + end = edge_dict0[key0][1] + shortest_path.insert(0, [points.index(start), points.index(end)]) + points_queue.remove(start) + current_edge[0] = start + else: + start = edge_dict1[key1][0] + end = edge_dict1[key1][1] + shortest_path.append([points.index(start), points.index(end)]) + points_queue.remove(end) + current_edge[1] = end + + edge_dict0 = {} + edge_dict1 = {} + + shortest_path = functools.reduce(operator.concat, shortest_path) + shortest_path = sorted(set(shortest_path), key=shortest_path.index) + + return shortest_path + + +def in_contour(cont, point): + x, y = point + is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5 + return is_inner + + +def fix_corner(top_line, bot_line, start_box, end_box): + """Add corner points to predicted side lines. This code was partially + adapted from https://github.com/GXYM/DRRG licensed under the MIT license. + + Args: + top_line (List[list[int]]): The predicted top sidelines of text + instance. + bot_line (List[list[int]]): The predicted bottom sidelines of text + instance. + start_box (ndarray): The first text component box. + end_box (ndarray): The last text component box. + + Returns: + top_line (List[list[int]]): The top sidelines with corner point added. + bot_line (List[list[int]]): The bottom sidelines with corner point + added. 
+ """ + assert isinstance(top_line, list) + assert all(isinstance(point, list) for point in top_line) + assert isinstance(bot_line, list) + assert all(isinstance(point, list) for point in bot_line) + assert start_box.shape == end_box.shape == (4, 2) + + contour = np.array(top_line + bot_line[::-1]) + start_left_mid = (start_box[0] + start_box[3]) / 2 + start_right_mid = (start_box[1] + start_box[2]) / 2 + end_left_mid = (end_box[0] + end_box[3]) / 2 + end_right_mid = (end_box[1] + end_box[2]) / 2 + if not in_contour(contour, start_left_mid): + top_line.insert(0, start_box[0].tolist()) + bot_line.insert(0, start_box[3].tolist()) + elif not in_contour(contour, start_right_mid): + top_line.insert(0, start_box[1].tolist()) + bot_line.insert(0, start_box[2].tolist()) + if not in_contour(contour, end_left_mid): + top_line.append(end_box[0].tolist()) + bot_line.append(end_box[3].tolist()) + elif not in_contour(contour, end_right_mid): + top_line.append(end_box[1].tolist()) + bot_line.append(end_box[2].tolist()) + return top_line, bot_line + + +def comps2boundaries(text_comps, comp_pred_labels): + """Construct text instance boundaries from clustered text components. This + code was partially adapted from https://github.com/GXYM/DRRG licensed under + the MIT license. + + Args: + text_comps (ndarray): The text components. + comp_pred_labels (ndarray): The clustering labels of text components. + + Returns: + boundaries (List[list[float]]): The predicted boundaries of text + instances. + """ + assert text_comps.ndim == 2 + assert len(text_comps) == len(comp_pred_labels) + boundaries = [] + if len(text_comps) < 1: + return boundaries + for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1): + cluster_comp_inds = np.where(comp_pred_labels == cluster_ind) + text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape( + (-1, 4, 2)).astype(np.int32) + score = np.mean(text_comps[cluster_comp_inds, -1]) + + if text_comp_boxes.shape[0] < 1: + continue + + elif text_comp_boxes.shape[0] > 1: + centers = np.mean( + text_comp_boxes, axis=1).astype(np.int32).tolist() + shortest_path = min_connect_path(centers) + text_comp_boxes = text_comp_boxes[shortest_path] + top_line = np.mean( + text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist() + bot_line = np.mean( + text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist() + top_line, bot_line = fix_corner(top_line, bot_line, + text_comp_boxes[0], + text_comp_boxes[-1]) + boundary_points = top_line + bot_line[::-1] + + else: + top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist() + bot_line = text_comp_boxes[0, 2:4:-1, :].astype(np.int32).tolist() + boundary_points = top_line + bot_line + + boundary = [p for coord in boundary_points for p in coord] + [score] + boundaries.append(boundary) + + return boundaries diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py new file mode 100755 index 00000000..66dbd294 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from dbnet_cv.utils import Registry, build_from_cfg + +from .box_util import (bezier_to_polygon, is_on_same_line, sort_points, + stitch_boxes_into_lines) +from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type, + is_type_list, valid_boundary) +from .collect_env import collect_env +# from .data_convert_util import convert_annotations +from .fileio import list_from_file, list_to_file +# from .img_util import drop_orientation, is_not_png +from .lmdb_util import recog2lmdb +from .logger import get_root_logger +from .model import revert_sync_batchnorm +from .setup_env import setup_multi_processes +from .string_util import StringStrip + +__all__ = [ + 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', + 'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist', + 'valid_boundary', 'lmdb_converter', 'list_to_file', 'list_from_file', + 'is_on_same_line', 'stitch_boxes_into_lines', 'StringStrip', + 'revert_sync_batchnorm', 'bezier_to_polygon', 'sort_points', + 'setup_multi_processes', 'recog2lmdb' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py b/cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py new file mode 100755 index 00000000..c0ef44e5 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/box_util.py @@ -0,0 +1,199 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import numpy as np + +from dbnet.utils.check_argument import is_2dlist, is_type_list + + +def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8): + """Check if two boxes are on the same line by their y-axis coordinates. + + Two boxes are on the same line if they overlap vertically, and the length + of the overlapping line segment is greater than min_y_overlap_ratio * the + height of either of the boxes. + + Args: + box_a (list), box_b (list): Two bounding boxes to be checked + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for boxes in the same line + + Returns: + The bool flag indicating if they are on the same line + """ + a_y_min = np.min(box_a[1::2]) + b_y_min = np.min(box_b[1::2]) + a_y_max = np.max(box_a[1::2]) + b_y_max = np.max(box_b[1::2]) + + # Make sure that box a is always the box above another + if a_y_min > b_y_min: + a_y_min, b_y_min = b_y_min, a_y_min + a_y_max, b_y_max = b_y_max, a_y_max + + if b_y_min <= a_y_max: + if min_y_overlap_ratio is not None: + sorted_y = sorted([b_y_min, b_y_max, a_y_max]) + overlap = sorted_y[1] - sorted_y[0] + min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio + min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio + return overlap >= min_a_overlap or \ + overlap >= min_b_overlap + else: + return True + return False + + +def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.8): + """Stitch fragmented boxes of words into lines. 
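+
+ Each box is expected to be a dict carrying a ``box`` key (8 flattened
+ corner coordinates) and a ``text`` key, e.g. (coordinates made up for
+ illustration):
+
+ >>> boxes = [
+ ...     dict(box=[0, 0, 20, 0, 20, 10, 0, 10], text='Hello'),
+ ...     dict(box=[22, 0, 50, 0, 50, 10, 22, 10], text='world'),
+ ... ]
+ >>> stitch_boxes_into_lines(boxes, max_x_dist=10)[0]['text']
+ 'Hello world'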
+ + Note: part of its logic is inspired by @Johndirr + (https://github.com/faustomorales/keras-ocr/issues/22) + + Args: + boxes (list): List of ocr results to be stitched + max_x_dist (int): The maximum horizontal distance between the closest + edges of neighboring boxes in the same line + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for any pairs of neighboring boxes in the same line + + Returns: + merged_boxes(list[dict]): List of merged boxes and texts + """ + + if len(boxes) <= 1: + return boxes + + merged_boxes = [] + + # sort groups based on the x_min coordinate of boxes + x_sorted_boxes = sorted(boxes, key=lambda x: np.min(x['box'][::2])) + # store indexes of boxes which are already parts of other lines + skip_idxs = set() + + i = 0 + # locate lines of boxes starting from the leftmost one + for i in range(len(x_sorted_boxes)): + if i in skip_idxs: + continue + # the rightmost box in the current line + rightmost_box_idx = i + line = [rightmost_box_idx] + for j in range(i + 1, len(x_sorted_boxes)): + if j in skip_idxs: + continue + if is_on_same_line(x_sorted_boxes[rightmost_box_idx]['box'], + x_sorted_boxes[j]['box'], min_y_overlap_ratio): + line.append(j) + skip_idxs.add(j) + rightmost_box_idx = j + + # split line into lines if the distance between two neighboring + # sub-lines' is greater than max_x_dist + lines = [] + line_idx = 0 + lines.append([line[0]]) + for k in range(1, len(line)): + curr_box = x_sorted_boxes[line[k]] + prev_box = x_sorted_boxes[line[k - 1]] + dist = np.min(curr_box['box'][::2]) - np.max(prev_box['box'][::2]) + if dist > max_x_dist: + line_idx += 1 + lines.append([]) + lines[line_idx].append(line[k]) + + # Get merged boxes + for box_group in lines: + merged_box = {} + merged_box['text'] = ' '.join( + [x_sorted_boxes[idx]['text'] for idx in box_group]) + x_min, y_min = float('inf'), float('inf') + x_max, y_max = float('-inf'), float('-inf') + for idx in box_group: + x_max = max(np.max(x_sorted_boxes[idx]['box'][::2]), x_max) + x_min = min(np.min(x_sorted_boxes[idx]['box'][::2]), x_min) + y_max = max(np.max(x_sorted_boxes[idx]['box'][1::2]), y_max) + y_min = min(np.min(x_sorted_boxes[idx]['box'][1::2]), y_min) + merged_box['box'] = [ + x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max + ] + merged_boxes.append(merged_box) + + return merged_boxes + + +def bezier_to_polygon(bezier_points, num_sample=20): + """Sample points from the boundary of a polygon enclosed by two Bezier + curves, which are controlled by ``bezier_points``. + + Args: + bezier_points (ndarray): A :math:`(2, 4, 2)` array of 8 Bezeir points + or its equalivance. The first 4 points control the curve at one + side and the last four control the other side. + num_sample (int): The number of sample points at each Bezeir curve. + + Returns: + list[ndarray]: A list of 2*num_sample points representing the polygon + extracted from Bezier curves. + + Warning: + The points are not guaranteed to be ordered. Please use + :func:`mmocr.utils.sort_points` to sort points if necessary. + """ + assert num_sample > 0 + + bezier_points = np.asarray(bezier_points) + assert np.prod( + bezier_points.shape) == 16, 'Need 8 Bezier control points to continue!' 
+ + bezier = bezier_points.reshape(2, 4, 2).transpose(0, 2, 1).reshape(4, 4) + u = np.linspace(0, 1, num_sample) + + points = np.outer((1 - u) ** 3, bezier[:, 0]) \ + + np.outer(3 * u * ((1 - u) ** 2), bezier[:, 1]) \ + + np.outer(3 * (u ** 2) * (1 - u), bezier[:, 2]) \ + + np.outer(u ** 3, bezier[:, 3]) + + # Convert points to polygon + points = np.concatenate((points[:, :2], points[:, 2:]), axis=0) + return points.tolist() + + +def sort_points(points): + """Sort arbitory points in clockwise order. Reference: + https://stackoverflow.com/a/6989383. + + Args: + points (list[ndarray] or ndarray or list[list]): A list of unsorted + boundary points. + + Returns: + list[ndarray]: A list of points sorted in clockwise order. + """ + + assert is_type_list(points, np.ndarray) or isinstance(points, np.ndarray) \ + or is_2dlist(points) + + points = np.array(points) + center = np.mean(points, axis=0) + + def cmp(a, b): + oa = a - center + ob = b - center + + # Some corner cases + if oa[0] >= 0 and ob[0] < 0: + return 1 + if oa[0] < 0 and ob[0] >= 0: + return -1 + + prod = np.cross(oa, ob) + if prod > 0: + return 1 + if prod < 0: + return -1 + + # a, b are on the same line from the center + return 1 if (oa**2).sum() < (ob**2).sum() else -1 + + return sorted(points, key=functools.cmp_to_key(cmp)) diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/check_argument.py b/cv/ocr/dbnet/pytorch/dbnet/utils/check_argument.py new file mode 100755 index 00000000..34cbe8dc --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/check_argument.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. + + +def is_3dlist(x): + """check x is 3d-list([[[1], []]]) or 2d empty list([[], []]) or 1d empty + list([]). + + Notice: + The reason that it contains 1d or 2d empty list is because + some arguments from gt annotation file or model prediction + may be empty, but usually, it should be 3d-list. + """ + if not isinstance(x, list): + return False + if len(x) == 0: + return True + for sub_x in x: + if not is_2dlist(sub_x): + return False + + return True + + +def is_2dlist(x): + """check x is 2d-list([[1], []]) or 1d empty list([]). + + Notice: + The reason that it contains 1d empty list is because + some arguments from gt annotation file or model prediction + may be empty, but usually, it should be 2d-list. + """ + if not isinstance(x, list): + return False + if len(x) == 0: + return True + + return all(isinstance(item, list) for item in x) + + +def is_type_list(x, type): + + if not isinstance(x, list): + return False + + return all(isinstance(item, type) for item in x) + + +def is_none_or_type(x, type): + + return isinstance(x, type) or x is None + + +def equal_len(*argv): + assert len(argv) > 0 + + num_arg = len(argv[0]) + for arg in argv: + if len(arg) != num_arg: + return False + return True + + +def valid_boundary(x, with_score=True): + num = len(x) + if num < 8: + return False + if num % 2 == 0 and (not with_score): + return True + if num % 2 == 1 and with_score: + return True + + return False diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/collect_env.py b/cv/ocr/dbnet/pytorch/dbnet/utils/collect_env.py new file mode 100755 index 00000000..ed6bc0cc --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/collect_env.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from dbnet_cv.utils import collect_env as collect_base_env +from dbnet_cv.utils import get_git_hash + +import dbnet + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + # env_info['MMOCR'] = mmocr.__version__ + '+' + get_git_hash()[:7] + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py b/cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py new file mode 100755 index 00000000..cf07dd05 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/fileio.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +import dbnet_cv + + +def list_to_file(filename, lines): + """Write a list of strings to a text file. + + Args: + filename (str): The output filename. It will be created/overwritten. + lines (list(str)): Data to be written. + """ + dbnet_cv.mkdir_or_exist(os.path.dirname(filename)) + with open(filename, 'w', encoding='utf-8') as fw: + for line in lines: + fw.write(f'{line}\n') + + +def list_from_file(filename, encoding='utf-8'): + """Load a text file and parse the content as a list of strings. The + trailing "\\r" and "\\n" of each line will be removed. + + Note: + This will be replaced by dbnet_cv's version after it supports encoding. + + Args: + filename (str): Filename. + encoding (str): Encoding used to open the file. Default utf-8. + + Returns: + list[str]: A list of strings. + """ + item_list = [] + with open(filename, 'r', encoding=encoding) as f: + for line in f: + item_list.append(line.rstrip('\n\r')) + return item_list diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py b/cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py new file mode 100755 index 00000000..618c836f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/lmdb_util.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import os.path as osp + +import cv2 +import lmdb +import numpy as np + +from dbnet.utils import list_from_file + + +def check_image_is_valid(imageBin): + if imageBin is None: + return False + imageBuf = np.frombuffer(imageBin, dtype=np.uint8) + img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) + imgH, imgW = img.shape[0], img.shape[1] + if imgH * imgW == 0: + return False + return True + + +def parse_line(line, format): + if format == 'txt': + img_name, text = line.split(' ') + else: + line = json.loads(line) + img_name = line['filename'] + text = line['text'] + return img_name, text + + +def write_cache(env, cache): + with env.begin(write=True) as txn: + cursor = txn.cursor() + cursor.putmulti(cache, dupdata=False, overwrite=True) + + +def recog2lmdb(img_root, + label_path, + output, + label_format='txt', + label_only=False, + batch_size=1000, + encoding='utf-8', + lmdb_map_size=1099511627776, + verify=True): + """Create text recognition dataset to LMDB format. + + Args: + img_root (str): Path to images. + label_path (str): Path to label file. + output (str): LMDB output path. + label_format (str): Format of the label file, either txt or jsonl. + label_only (bool): Only convert label to lmdb format. + batch_size (int): Number of files written to the cache each time. + encoding (str): Label encoding method. + lmdb_map_size (int): Maximum size database may grow to. + verify (bool): If true, check the validity of + every image.Defaults to True. + + E.g. 
+ This function supports MMOCR's recognition data format and the label file + can be txt or jsonl, as follows: + + ├──img_root + | |—— img1.jpg + | |—— img2.jpg + | |—— ... + |——label.txt (or label.jsonl) + + label.txt: img1.jpg HELLO + img2.jpg WORLD + ... + + label.jsonl: {'filename':'img1.jpg', 'text':'HELLO'} + {'filename':'img2.jpg', 'text':'WORLD'} + ... + """ + # check label format + assert osp.basename(label_path).split('.')[-1] == label_format + # create lmdb env + os.makedirs(output, exist_ok=True) + env = lmdb.open(output, map_size=lmdb_map_size) + # load label file + anno_list = list_from_file(label_path, encoding=encoding) + cache = [] + # index start from 1 + cnt = 1 + n_samples = len(anno_list) + for anno in anno_list: + label_key = 'label-%09d'.encode(encoding) % cnt + img_name, text = parse_line(anno, label_format) + if label_only: + # convert only labels to lmdb + line = json.dumps( + dict(filename=img_name, text=text), ensure_ascii=False) + cache.append((label_key, line.encode(encoding))) + else: + # convert both images and labels to lmdb + img_path = osp.join(img_root, img_name) + if not osp.exists(img_path): + print('%s does not exist' % img_path) + continue + with open(img_path, 'rb') as f: + image_bin = f.read() + if verify: + try: + if not check_image_is_valid(image_bin): + print('%s is not a valid image' % img_path) + continue + except Exception: + print('error occurred at ', img_name) + image_key = 'image-%09d'.encode(encoding) % cnt + cache.append((image_key, image_bin)) + cache.append((label_key, text.encode(encoding))) + + if cnt % batch_size == 0: + write_cache(env, cache) + cache = [] + print('Written %d / %d' % (cnt, n_samples)) + cnt += 1 + n_samples = cnt - 1 + cache.append( + ('num-samples'.encode(encoding), str(n_samples).encode(encoding))) + write_cache(env, cache) + print('Created lmdb dataset with %d samples' % n_samples) diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/logger.py b/cv/ocr/dbnet/pytorch/dbnet/utils/logger.py new file mode 100755 index 00000000..779ae99b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/logger.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +from dbnet_cv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Use `get_logger` method in dbnet_cv to get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmpose". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + return get_logger(__name__.split('.')[0], log_file, log_level) diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/model.py b/cv/ocr/dbnet/pytorch/dbnet/utils/model.py new file mode 100755 index 00000000..4a126006 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/model.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm): + """A general BatchNorm layer without input dimension check. 
+ + Reproduced from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc + is `_check_input_dim` that is designed for tensor sanity checks. + The check has been bypassed in this class for the convenience of converting + SyncBatchNorm. + """ + + def _check_input_dim(self, input): + return + + +def revert_sync_batchnorm(module): + """Helper function to convert all `SyncBatchNorm` layers in the model to + `BatchNormXd` layers. + + Adapted from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + + Args: + module (nn.Module): The module containing `SyncBatchNorm` layers. + + Returns: + module_output: The converted module with `BatchNormXd` layers. + """ + module_output = module + if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm): + module_output = _BatchNormXd(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + if hasattr(module, 'qconfig'): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, revert_sync_batchnorm(child)) + del module + return module_output diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py b/cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py new file mode 100755 index 00000000..a09bb15d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/ocr.py @@ -0,0 +1,878 @@ +#!/usr/bin/env python +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +import warnings +from argparse import ArgumentParser, Namespace +from pathlib import Path + +import dbnet_cv +import numpy as np +import torch +from dbnet_cv.image.misc import tensor2imgs +from dbnet_cv.runner import load_checkpoint +from dbnet_cv.utils.config import Config +from PIL import Image + +try: + import tesserocr +except ImportError: + tesserocr = None + +from dbnet.apis import init_detector +from dbnet.apis.inference import model_inference +from dbnet.core.visualize import det_recog_show_result +from dbnet.datasets.kie_dataset import KIEDataset +from dbnet.datasets.pipelines.crop import crop_img +from dbnet.models import build_detector +from dbnet.models.textdet.detectors import TextDetectorMixin +from dbnet.models.textrecog.recognizer import BaseRecognizer +from dbnet.utils import is_type_list +from dbnet.utils.box_util import stitch_boxes_into_lines +from dbnet.utils.fileio import list_from_file +from dbnet.utils.model import revert_sync_batchnorm + + +# Parse CLI arguments +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + 'img', type=str, help='Input image file or folder path.') + parser.add_argument( + '--output', + type=str, + default='', + help='Output file/folder name for visualization') + parser.add_argument( + '--det', + type=str, + default='PANet_IC15', + help='Pretrained text detection algorithm') + parser.add_argument( + '--det-config', + type=str, + default='', + help='Path to the custom config file of the selected det model. 
It ' + 'overrides the settings in det') + parser.add_argument( + '--det-ckpt', + type=str, + default='', + help='Path to the custom checkpoint file of the selected det model. ' + 'It overrides the settings in det') + parser.add_argument( + '--recog', + type=str, + default='SEG', + help='Pretrained text recognition algorithm') + parser.add_argument( + '--recog-config', + type=str, + default='', + help='Path to the custom config file of the selected recog model. It' + 'overrides the settings in recog') + parser.add_argument( + '--recog-ckpt', + type=str, + default='', + help='Path to the custom checkpoint file of the selected recog model. ' + 'It overrides the settings in recog') + parser.add_argument( + '--kie', + type=str, + default='', + help='Pretrained key information extraction algorithm') + parser.add_argument( + '--kie-config', + type=str, + default='', + help='Path to the custom config file of the selected kie model. It' + 'overrides the settings in kie') + parser.add_argument( + '--kie-ckpt', + type=str, + default='', + help='Path to the custom checkpoint file of the selected kie model. ' + 'It overrides the settings in kie') + parser.add_argument( + '--config-dir', + type=str, + default=os.path.join(str(Path.cwd()), 'configs/'), + help='Path to the config directory where all the config files ' + 'are located. Defaults to "configs/"') + parser.add_argument( + '--batch-mode', + action='store_true', + help='Whether use batch mode for inference') + parser.add_argument( + '--recog-batch-size', + type=int, + default=0, + help='Batch size for text recognition') + parser.add_argument( + '--det-batch-size', + type=int, + default=0, + help='Batch size for text detection') + parser.add_argument( + '--single-batch-size', + type=int, + default=0, + help='Batch size for separate det/recog inference') + parser.add_argument( + '--device', default=None, help='Device used for inference.') + parser.add_argument( + '--export', + type=str, + default='', + help='Folder where the results of each image are exported') + parser.add_argument( + '--export-format', + type=str, + default='json', + help='Format of the exported result file(s)') + parser.add_argument( + '--details', + action='store_true', + help='Whether include the text boxes coordinates and confidence values' + ) + parser.add_argument( + '--imshow', + action='store_true', + help='Whether show image with OpenCV.') + parser.add_argument( + '--print-result', + action='store_true', + help='Prints the recognised text') + parser.add_argument( + '--merge', action='store_true', help='Merge neighboring boxes') + parser.add_argument( + '--merge-xdist', + type=float, + default=20, + help='The maximum x-axis distance to merge boxes') + args = parser.parse_args() + if args.det == 'None': + args.det = None + if args.recog == 'None': + args.recog = None + # Warnings + if args.merge and not (args.det and args.recog): + warnings.warn( + 'Box merging will not work if the script is not' + ' running in detection + recognition mode.', UserWarning) + if not os.path.samefile(args.config_dir, os.path.join(str( + Path.cwd()))) and (args.det_config != '' + or args.recog_config != ''): + warnings.warn( + 'config_dir will be overridden by det-config or recog-config.', + UserWarning) + return args + + +class MMOCR: + + def __init__(self, + det='PANet_IC15', + det_config='', + det_ckpt='', + recog='SEG', + recog_config='', + recog_ckpt='', + kie='', + kie_config='', + kie_ckpt='', + config_dir=os.path.join(str(Path.cwd()), 'configs/'), + device=None, + **kwargs): + + 
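+        # The lookup tables below map shorthand algorithm names to config
+        # files (resolved under ``config_dir``) and to checkpoint files hosted
+        # at https://download.openmmlab.com/mmocr/. Note that only the DBNet
+        # configs are bundled in this repository's configs/ directory, so the
+        # other entries need the corresponding MMOCR configs and checkpoints.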
textdet_models = { + 'DB_r18': { + 'config': + 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', + 'ckpt': + 'dbnet/' + 'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth' + }, + 'DB_r50': { + 'config': + 'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py', + 'ckpt': + 'dbnet/' + 'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth' + }, + 'DBPP_r50': { + 'config': + 'dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py', + 'ckpt': + 'dbnet/' + 'dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth' + }, + 'DRRG': { + 'config': + 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py', + 'ckpt': + 'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth' + }, + 'FCE_IC15': { + 'config': + 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py', + 'ckpt': + 'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth' + }, + 'FCE_CTW_DCNv2': { + 'config': + 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py', + 'ckpt': + 'fcenet/' + + 'fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth' + }, + 'MaskRCNN_CTW': { + 'config': + 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py', + 'ckpt': + 'maskrcnn/' + 'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth' + }, + 'MaskRCNN_IC15': { + 'config': + 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py', + 'ckpt': + 'maskrcnn/' + 'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth' + }, + 'MaskRCNN_IC17': { + 'config': + 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py', + 'ckpt': + 'maskrcnn/' + 'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth' + }, + 'PANet_CTW': { + 'config': + 'panet/panet_r18_fpem_ffm_600e_ctw1500.py', + 'ckpt': + 'panet/' + 'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth' + }, + 'PANet_IC15': { + 'config': + 'panet/panet_r18_fpem_ffm_600e_icdar2015.py', + 'ckpt': + 'panet/' + 'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth' + }, + 'PS_CTW': { + 'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py', + 'ckpt': + 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth' + }, + 'PS_IC15': { + 'config': + 'psenet/psenet_r50_fpnf_600e_icdar2015.py', + 'ckpt': + 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth' + }, + 'TextSnake': { + 'config': + 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py', + 'ckpt': + 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth' + }, + 'Tesseract': {} + } + + textrecog_models = { + 'CRNN': { + 'config': 'crnn/crnn_academic_dataset.py', + 'ckpt': 'crnn/crnn_academic-a723a1c5.pth' + }, + 'SAR': { + 'config': 'sar/sar_r31_parallel_decoder_academic.py', + 'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth' + }, + 'SAR_CN': { + 'config': + 'sar/sar_r31_parallel_decoder_chinese.py', + 'ckpt': + 'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth' + }, + 'NRTR_1/16-1/8': { + 'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py', + 'ckpt': + 'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth' + }, + 'NRTR_1/8-1/4': { + 'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py', + 'ckpt': + 'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth' + }, + 'RobustScanner': { + 'config': 'robust_scanner/robustscanner_r31_academic.py', + 'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth' + }, + 'SATRN': { + 'config': 'satrn/satrn_academic.py', + 'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth' + }, + 'SATRN_sm': { + 'config': 'satrn/satrn_small.py', + 'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth' + }, + 'ABINet': { + 'config': 'abinet/abinet_academic.py', + 'ckpt': 'abinet/abinet_academic-f718abf6.pth' + }, + 'ABINet_Vision': { + 'config': 
'abinet/abinet_vision_only_academic.py', + 'ckpt': 'abinet/abinet_vision_only_academic-e6b9ea89.pth' + }, + 'SEG': { + 'config': 'seg/seg_r31_1by16_fpnocr_academic.py', + 'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth' + }, + 'CRNN_TPS': { + 'config': 'tps/crnn_tps_academic_dataset.py', + 'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth' + }, + 'Tesseract': {}, + 'MASTER': { + 'config': 'master/master_r31_12e_ST_MJ_SA.py', + 'ckpt': 'master/master_r31_12e_ST_MJ_SA-787edd36.pth' + } + } + + kie_models = { + 'SDMGR': { + 'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py', + 'ckpt': + 'sdmgr/sdmgr_unet16_60e_wildreceipt_20220706-57c220a6.pth' + } + } + + self.td = det + self.tr = recog + self.kie = kie + self.device = device + if self.device is None: + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + + # Check if the det/recog model choice is valid + if self.td and self.td not in textdet_models: + raise ValueError(self.td, + 'is not a supported text detection algorthm') + elif self.tr and self.tr not in textrecog_models: + raise ValueError(self.tr, + 'is not a supported text recognition algorithm') + elif self.kie: + if self.kie not in kie_models: + raise ValueError( + self.kie, 'is not a supported key information extraction' + ' algorithm') + elif not (self.td and self.tr): + raise NotImplementedError( + self.kie, 'has to run together' + ' with text detection and recognition algorithms.') + + self.detect_model = None + if self.td and self.td == 'Tesseract': + if tesserocr is None: + raise ImportError('Please install tesserocr first. ' + 'Check out the installation guide at ' + 'https://github.com/sirfz/tesserocr') + self.detect_model = 'Tesseract_det' + elif self.td: + # Build detection model + if not det_config: + det_config = os.path.join(config_dir, 'textdet/', + textdet_models[self.td]['config']) + if not det_ckpt: + det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \ + textdet_models[self.td]['ckpt'] + + self.detect_model = init_detector( + det_config, det_ckpt, device=self.device) + self.detect_model = revert_sync_batchnorm(self.detect_model) + + self.recog_model = None + if self.tr and self.tr == 'Tesseract': + if tesserocr is None: + raise ImportError('Please install tesserocr first. 
' + 'Check out the installation guide at ' + 'https://github.com/sirfz/tesserocr') + self.recog_model = 'Tesseract_recog' + elif self.tr: + # Build recognition model + if not recog_config: + recog_config = os.path.join( + config_dir, 'textrecog/', + textrecog_models[self.tr]['config']) + if not recog_ckpt: + recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \ + 'textrecog/' + textrecog_models[self.tr]['ckpt'] + + self.recog_model = init_detector( + recog_config, recog_ckpt, device=self.device) + self.recog_model = revert_sync_batchnorm(self.recog_model) + + self.kie_model = None + if self.kie: + # Build key information extraction model + if not kie_config: + kie_config = os.path.join(config_dir, 'kie/', + kie_models[self.kie]['config']) + if not kie_ckpt: + kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \ + 'kie/' + kie_models[self.kie]['ckpt'] + + kie_cfg = Config.fromfile(kie_config) + self.kie_model = build_detector( + kie_cfg.model, test_cfg=kie_cfg.get('test_cfg')) + self.kie_model = revert_sync_batchnorm(self.kie_model) + self.kie_model.cfg = kie_cfg + load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device) + + # Attribute check + for model in list(filter(None, [self.recog_model, self.detect_model])): + if hasattr(model, 'module'): + model = model.module + + @staticmethod + def get_tesserocr_api(): + """Get tesserocr api depending on different platform.""" + import subprocess + import sys + + if sys.platform == 'linux': + api = tesserocr.PyTessBaseAPI() + elif sys.platform == 'win32': + try: + p = subprocess.Popen( + 'where tesseract', stdout=subprocess.PIPE, shell=True) + s = p.communicate()[0].decode('utf-8').split('\\') + path = s[:-1] + ['tessdata'] + tessdata_path = '/'.join(path) + api = tesserocr.PyTessBaseAPI(path=tessdata_path) + except RuntimeError: + raise RuntimeError( + 'Please install tesseract first.\n Check out the' + ' installation guide at' + ' https://github.com/UB-Mannheim/tesseract/wiki') + else: + raise NotImplementedError + return api + + def tesseract_det_inference(self, imgs, **kwargs): + """Inference image(s) with the tesseract detector. + + Args: + imgs (ndarray or list[ndarray]): image(s) to inference. + + Returns: + result (dict): Predicted results. + """ + is_batch = True + if isinstance(imgs, np.ndarray): + is_batch = False + imgs = [imgs] + assert is_type_list(imgs, np.ndarray) + api = self.get_tesserocr_api() + + # Get detection result using tesseract + results = [] + for img in imgs: + image = Image.fromarray(img) + api.SetImage(image) + boxes = api.GetComponentImages(tesserocr.RIL.TEXTLINE, True) + boundaries = [] + for _, box, _, _ in boxes: + min_x = box['x'] + min_y = box['y'] + max_x = box['x'] + box['w'] + max_y = box['y'] + box['h'] + boundary = [ + min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y, 1.0 + ] + boundaries.append(boundary) + results.append({'boundary_result': boundaries}) + + # close tesserocr api + api.End() + + if not is_batch: + return results[0] + else: + return results + + def tesseract_recog_inference(self, imgs, **kwargs): + """Inference image(s) with the tesseract recognizer. + + Args: + imgs (ndarray or list[ndarray]): image(s) to inference. + + Returns: + result (dict): Predicted results. 
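+
+        Example (illustrative; requires ``tesserocr`` and a loaded image
+        ndarray ``img``):
+            >>> ocr = MMOCR(det=None, recog='Tesseract')
+            >>> ocr.tesseract_recog_inference(img)  # dict with 'text'/'score'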
+ """ + is_batch = True + if isinstance(imgs, np.ndarray): + is_batch = False + imgs = [imgs] + assert is_type_list(imgs, np.ndarray) + api = self.get_tesserocr_api() + + results = [] + for img in imgs: + image = Image.fromarray(img) + api.SetImage(image) + api.SetRectangle(0, 0, img.shape[1], img.shape[0]) + # Remove beginning and trailing spaces from Tesseract + text = api.GetUTF8Text().strip() + conf = api.MeanTextConf() / 100 + results.append({'text': text, 'score': conf}) + + # close tesserocr api + api.End() + + if not is_batch: + return results[0] + else: + return results + + def readtext(self, + img, + output=None, + details=False, + export=None, + export_format='json', + batch_mode=False, + recog_batch_size=0, + det_batch_size=0, + single_batch_size=0, + imshow=False, + print_result=False, + merge=False, + merge_xdist=20, + **kwargs): + args = locals().copy() + [args.pop(x, None) for x in ['kwargs', 'self']] + args = Namespace(**args) + + # Input and output arguments processing + self._args_processing(args) + self.args = args + + pp_result = None + + # Send args and models to the MMOCR model inference API + # and call post-processing functions for the output + if self.detect_model and self.recog_model: + det_recog_result = self.det_recog_kie_inference( + self.detect_model, self.recog_model, kie_model=self.kie_model) + pp_result = self.det_recog_pp(det_recog_result) + else: + for model in list( + filter(None, [self.recog_model, self.detect_model])): + result = self.single_inference(model, args.arrays, + args.batch_mode, + args.single_batch_size) + pp_result = self.single_pp(result, model) + + return pp_result + + # Post processing function for end2end ocr + def det_recog_pp(self, result): + final_results = [] + args = self.args + for arr, output, export, det_recog_result in zip( + args.arrays, args.output, args.export, result): + if output or args.imshow: + if self.kie_model: + res_img = det_recog_show_result(arr, det_recog_result) + else: + res_img = det_recog_show_result( + arr, det_recog_result, out_file=output) + if args.imshow and not self.kie_model: + dbnet_cv.imshow(res_img, 'inference results') + if not args.details: + simple_res = {} + simple_res['filename'] = det_recog_result['filename'] + simple_res['text'] = [ + x['text'] for x in det_recog_result['result'] + ] + final_result = simple_res + else: + final_result = det_recog_result + if export: + dbnet_cv.dump(final_result, export, indent=4) + if args.print_result: + print(final_result, end='\n\n') + final_results.append(final_result) + return final_results + + # Post processing function for separate det/recog inference + def single_pp(self, result, model): + for arr, output, export, res in zip(self.args.arrays, self.args.output, + self.args.export, result): + if export: + dbnet_cv.dump(res, export, indent=4) + if output or self.args.imshow: + if model == 'Tesseract_det': + res_img = TextDetectorMixin(show_score=False).show_result( + arr, res, out_file=output) + elif model == 'Tesseract_recog': + res_img = BaseRecognizer.show_result( + arr, res, out_file=output) + else: + res_img = model.show_result(arr, res, out_file=output) + if self.args.imshow: + dbnet_cv.imshow(res_img, 'inference results') + if self.args.print_result: + print(res, end='\n\n') + return result + + def generate_kie_labels(self, result, boxes, class_list): + idx_to_cls = {} + if class_list is not None: + for line in list_from_file(class_list): + class_idx, class_label = line.strip().split() + idx_to_cls[class_idx] = class_label + + max_value, max_idx = 
torch.max(result['nodes'].detach().cpu(), -1) + node_pred_label = max_idx.numpy().tolist() + node_pred_score = max_value.numpy().tolist() + labels = [] + for i in range(len(boxes)): + pred_label = str(node_pred_label[i]) + if pred_label in idx_to_cls: + pred_label = idx_to_cls[pred_label] + pred_score = node_pred_score[i] + labels.append((pred_label, pred_score)) + return labels + + def visualize_kie_output(self, + model, + data, + result, + out_file=None, + show=False): + """Visualizes KIE output.""" + img_tensor = data['img'].data + img_meta = data['img_metas'].data + gt_bboxes = data['gt_bboxes'].data.numpy().tolist() + if img_tensor.dtype == torch.uint8: + # The img tensor is the raw input not being normalized + # (For SDMGR non-visual) + img = img_tensor.cpu().numpy().transpose(1, 2, 0) + else: + img = tensor2imgs( + img_tensor.unsqueeze(0), **img_meta.get('img_norm_cfg', {}))[0] + h, w, _ = img_meta.get('img_shape', img.shape) + img_show = img[:h, :w, :] + model.show_result( + img_show, result, gt_bboxes, show=show, out_file=out_file) + + # End2end ocr inference pipeline + def det_recog_kie_inference(self, det_model, recog_model, kie_model=None): + end2end_res = [] + # Find bounding boxes in the images (text detection) + det_result = self.single_inference(det_model, self.args.arrays, + self.args.batch_mode, + self.args.det_batch_size) + bboxes_list = [res['boundary_result'] for res in det_result] + + if kie_model: + kie_dataset = KIEDataset( + dict_file=kie_model.cfg.data.test.dict_file) + + # For each bounding box, the image is cropped and + # sent to the recognition model either one by one + # or all together depending on the batch_mode + for filename, arr, bboxes, out_file in zip(self.args.filenames, + self.args.arrays, + bboxes_list, + self.args.output): + img_e2e_res = {} + img_e2e_res['filename'] = filename + img_e2e_res['result'] = [] + box_imgs = [] + for bbox in bboxes: + box_res = {} + box_res['box'] = [round(x) for x in bbox[:-1]] + box_res['box_score'] = float(bbox[-1]) + box = bbox[:8] + if len(bbox) > 9: + min_x = min(bbox[0:-1:2]) + min_y = min(bbox[1:-1:2]) + max_x = max(bbox[0:-1:2]) + max_y = max(bbox[1:-1:2]) + box = [ + min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y + ] + box_img = crop_img(arr, box) + if self.args.batch_mode: + box_imgs.append(box_img) + else: + if recog_model == 'Tesseract_recog': + recog_result = self.single_inference( + recog_model, box_img, batch_mode=True) + else: + recog_result = model_inference(recog_model, box_img) + text = recog_result['text'] + text_score = recog_result['score'] + if isinstance(text_score, list): + text_score = sum(text_score) / max(1, len(text)) + box_res['text'] = text + box_res['text_score'] = text_score + img_e2e_res['result'].append(box_res) + + if self.args.batch_mode: + recog_results = self.single_inference( + recog_model, box_imgs, True, self.args.recog_batch_size) + for i, recog_result in enumerate(recog_results): + text = recog_result['text'] + text_score = recog_result['score'] + if isinstance(text_score, (list, tuple)): + text_score = sum(text_score) / max(1, len(text)) + img_e2e_res['result'][i]['text'] = text + img_e2e_res['result'][i]['text_score'] = text_score + + if self.args.merge: + img_e2e_res['result'] = stitch_boxes_into_lines( + img_e2e_res['result'], self.args.merge_xdist, 0.5) + + if kie_model: + annotations = copy.deepcopy(img_e2e_res['result']) + # Customized for kie_dataset, which + # assumes that boxes are represented by only 4 points + for i, ann in enumerate(annotations): + min_x = 
min(ann['box'][::2]) + min_y = min(ann['box'][1::2]) + max_x = max(ann['box'][::2]) + max_y = max(ann['box'][1::2]) + annotations[i]['box'] = [ + min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y + ] + ann_info = kie_dataset._parse_anno_info(annotations) + ann_info['ori_bboxes'] = ann_info.get('ori_bboxes', + ann_info['bboxes']) + ann_info['gt_bboxes'] = ann_info.get('gt_bboxes', + ann_info['bboxes']) + kie_result, data = model_inference( + kie_model, + arr, + ann=ann_info, + return_data=True, + batch_mode=self.args.batch_mode) + # visualize KIE results + self.visualize_kie_output( + kie_model, + data, + kie_result, + out_file=out_file, + show=self.args.imshow) + gt_bboxes = data['gt_bboxes'].data.numpy().tolist() + labels = self.generate_kie_labels(kie_result, gt_bboxes, + kie_model.class_list) + for i in range(len(gt_bboxes)): + img_e2e_res['result'][i]['label'] = labels[i][0] + img_e2e_res['result'][i]['label_score'] = labels[i][1] + + end2end_res.append(img_e2e_res) + return end2end_res + + # Separate det/recog inference pipeline + def single_inference(self, model, arrays, batch_mode, batch_size=0): + + def inference(m, a, **kwargs): + if model == 'Tesseract_det': + return self.tesseract_det_inference(a) + elif model == 'Tesseract_recog': + return self.tesseract_recog_inference(a) + else: + return model_inference(m, a, **kwargs) + + result = [] + if batch_mode: + if batch_size == 0: + result = inference(model, arrays, batch_mode=True) + else: + n = batch_size + arr_chunks = [ + arrays[i:i + n] for i in range(0, len(arrays), n) + ] + for chunk in arr_chunks: + result.extend(inference(model, chunk, batch_mode=True)) + else: + for arr in arrays: + result.append(inference(model, arr, batch_mode=False)) + return result + + # Arguments pre-processing function + def _args_processing(self, args): + # Check if the input is a list/tuple that + # contains only np arrays or strings + if isinstance(args.img, (list, tuple)): + img_list = args.img + if not all([isinstance(x, (np.ndarray, str)) for x in args.img]): + raise AssertionError('Images must be strings or numpy arrays') + + # Create a list of the images + if isinstance(args.img, str): + img_path = Path(args.img) + if img_path.is_dir(): + img_list = [str(x) for x in img_path.glob('*')] + else: + img_list = [str(img_path)] + elif isinstance(args.img, np.ndarray): + img_list = [args.img] + + # Read all image(s) in advance to reduce wasted time + # re-reading the images for visualization output + args.arrays = [dbnet_cv.imread(x) for x in img_list] + + # Create a list of filenames (used for output images and result files) + if isinstance(img_list[0], str): + args.filenames = [str(Path(x).stem) for x in img_list] + else: + args.filenames = [str(x) for x in range(len(img_list))] + + # If given an output argument, create a list of output image filenames + num_res = len(img_list) + if args.output: + output_path = Path(args.output) + if output_path.is_dir(): + args.output = [ + str(output_path / f'out_{x}.png') for x in args.filenames + ] + else: + args.output = [str(args.output)] + if args.batch_mode: + raise AssertionError('Output of multiple images inference' + ' must be a directory') + else: + args.output = [None] * num_res + + # If given an export argument, create a list of + # result filenames for each image + if args.export: + export_path = Path(args.export) + args.export = [ + str(export_path / f'out_{x}.{args.export_format}') + for x in args.filenames + ] + else: + args.export = [None] * num_res + + return args + + +# Create an inference 
pipeline with parsed arguments +def main(): + args = parse_args() + ocr = MMOCR(**vars(args)) + ocr.readtext(**vars(args)) + + +if __name__ == '__main__': + main() diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py b/cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py new file mode 100755 index 00000000..21def2f0 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/setup_env.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import platform +import warnings + +import cv2 +import torch.multiprocessing as mp + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + # set multi-process start method as `fork` to speed up the training + if platform.system() != 'Windows': + mp_start_method = cfg.get('mp_start_method', 'fork') + current_method = mp.get_start_method(allow_none=True) + if current_method is not None and current_method != mp_start_method: + warnings.warn( + f'Multi-processing start method `{mp_start_method}` is ' + f'different from the previous setting `{current_method}`.' + f'It will be force set to `{mp_start_method}`. You can change ' + f'this behavior by changing `mp_start_method` in your config.') + mp.set_start_method(mp_start_method, force=True) + + # disable opencv multithreading to avoid system being overloaded + opencv_num_threads = cfg.get('opencv_num_threads', 0) + cv2.setNumThreads(opencv_num_threads) + + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: + omp_num_threads = 1 + warnings.warn( + f'Setting OMP_NUM_THREADS environment variable for each process ' + f'to be {omp_num_threads} in default, to avoid your system being ' + f'overloaded, please further tune the variable for optimal ' + f'performance in your application as needed.') + os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) + + # setup MKL threads + if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: + mkl_num_threads = 1 + warnings.warn( + f'Setting MKL_NUM_THREADS environment variable for each process ' + f'to be {mkl_num_threads} in default, to avoid your system being ' + f'overloaded, please further tune the variable for optimal ' + f'performance in your application as needed.') + os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) diff --git a/cv/ocr/dbnet/pytorch/dbnet/utils/string_util.py b/cv/ocr/dbnet/pytorch/dbnet/utils/string_util.py new file mode 100755 index 00000000..5a8946ee --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/utils/string_util.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +class StringStrip: + """Removing the leading and/or the trailing characters based on the string + argument passed. + + Args: + strip (bool): Whether remove characters from both left and right of + the string. Default: True. + strip_pos (str): Which position for removing, can be one of + ('both', 'left', 'right'), Default: 'both'. + strip_str (str|None): A string specifying the set of characters + to be removed from the left and right part of the string. + If None, all leading and trailing whitespaces + are removed from the string. Default: None. 
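+
+    Example (illustrative):
+        >>> strip = StringStrip(strip_pos='right')
+        >>> strip('  text  ')
+        '  text'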
+ """ + + def __init__(self, strip=True, strip_pos='both', strip_str=None): + assert isinstance(strip, bool) + assert strip_pos in ('both', 'left', 'right') + assert strip_str is None or isinstance(strip_str, str) + + self.strip = strip + self.strip_pos = strip_pos + self.strip_str = strip_str + + def __call__(self, in_str): + + if not self.strip: + return in_str + + if self.strip_pos == 'left': + return in_str.lstrip(self.strip_str) + elif self.strip_pos == 'right': + return in_str.rstrip(self.strip_str) + else: + return in_str.strip(self.strip_str) diff --git a/cv/ocr/dbnet/pytorch/dbnet/version.py b/cv/ocr/dbnet/pytorch/dbnet/version.py new file mode 100755 index 00000000..3626dcd1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet/version.py @@ -0,0 +1,4 @@ +# Copyright (c) Open-MMLab. All rights reserved. + +__version__ = '0.6.0' +short_version = __version__ diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py new file mode 100755 index 00000000..65b6016c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# flake8: noqa +# from .arraymisc import * +from .fileio import * +from .image import * +from .utils import * +from .version import * +# from .video import * +# from .visualization import * + +# The following modules are not imported to this level, so dbnet_cv may be used +# without PyTorch. +# - runner +# - parallel +# - op +# - device diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py new file mode 100755 index 00000000..d2e8f099 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# from .alexnet import AlexNet +# yapf: disable +# from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, +# PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS, +# ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule, +# ConvTranspose2d, ConvTranspose3d, ConvWS2d, +# DepthwiseSeparableConvModule, GeneralizedAttention, +# HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d, +# NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish, +# build_activation_layer, build_conv_layer, +# build_norm_layer, build_padding_layer, build_plugin_layer, +# build_upsample_layer, conv_ws_2d, is_norm) +from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, + PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS, + Conv2d, Conv3d, ConvModule, + ConvTranspose2d, ConvTranspose3d, Linear, MaxPool2d, MaxPool3d, + + build_activation_layer, build_conv_layer, + build_norm_layer, build_padding_layer, build_plugin_layer, + build_upsample_layer, is_norm) +from .builder import MODELS, build_model_from_cfg +# yapf: enable +from .resnet import ResNet, make_res_layer +from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit, + NormalInit, PretrainedInit, TruncNormalInit, UniformInit, + XavierInit, bias_init_with_prob, caffe2_xavier_init, + constant_init, fuse_conv_bn, get_model_complexity_info, + initialize, kaiming_init, normal_init, trunc_normal_init, + uniform_init, xavier_init) +# from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit, +# NormalInit, PretrainedInit, TruncNormalInit, UniformInit, +# XavierInit, bias_init_with_prob, caffe2_xavier_init, +# constant_init, fuse_conv_bn, get_model_complexity_info, +# initialize, kaiming_init, normal_init, trunc_normal_init, +# uniform_init, xavier_init) +# from .vgg import VGG, make_vgg_layer + +# 
__all__ = [ +# 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer', +# 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init', +# 'uniform_init', 'kaiming_init', 'caffe2_xavier_init', +# 'bias_init_with_prob', 'ConvModule', 'build_activation_layer', +# 'build_conv_layer', 'build_norm_layer', 'build_padding_layer', +# 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d', +# 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish', +# 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', +# 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', +# 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d', +# 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', +# 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', +# 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', +# 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', +# 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/__init__.py new file mode 100755 index 00000000..d1e17d38 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .activation import build_activation_layer + +from .conv import build_conv_layer +from .conv_module import ConvModule +from .conv2d_adaptive_padding import Conv2dAdaptivePadding +from .norm import build_norm_layer, is_norm +from .padding import build_padding_layer +from .plugin import build_plugin_layer +from .drop import Dropout, DropPath +from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, + PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS) +from .hsigmoid import HSigmoid +from .hswish import HSwish +from .upsample import build_upsample_layer +from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, + Linear, MaxPool2d, MaxPool3d) + diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/activation.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/activation.py new file mode 100755 index 00000000..d791a04c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/activation.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from dbnet_cv.utils import TORCH_VERSION, build_from_cfg, digit_version +from .registry import ACTIVATION_LAYERS + +for module in [ + nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU, + nn.Sigmoid, nn.Tanh +]: + ACTIVATION_LAYERS.register_module(module=module) + + +@ACTIVATION_LAYERS.register_module(name='Clip') +@ACTIVATION_LAYERS.register_module() +class Clamp(nn.Module): + """Clamp activation layer. + + This activation function is to clamp the feature map value within + :math:`[min, max]`. More details can be found in ``torch.clamp()``. + + Args: + min (Number | optional): Lower-bound of the range to be clamped to. + Default to -1. + max (Number | optional): Upper-bound of the range to be clamped to. + Default to 1. + """ + + def __init__(self, min=-1., max=1.): + super().__init__() + self.min = min + self.max = max + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: Clamped tensor. 
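+
+        Example (illustrative):
+            >>> clamp = Clamp(min=0., max=1.)
+            >>> clamp(torch.tensor([-2., 0.5, 3.]))
+            tensor([0.0000, 0.5000, 1.0000])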
+ """ + return torch.clamp(x, min=self.min, max=self.max) + + +class GELU(nn.Module): + r"""Applies the Gaussian Error Linear Units function: + + .. math:: + \text{GELU}(x) = x * \Phi(x) + where :math:`\Phi(x)` is the Cumulative Distribution Function for + Gaussian Distribution. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: scripts/activation_images/GELU.png + + Examples:: + + >>> m = nn.GELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input): + return F.gelu(input) + + +if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.4')): + ACTIVATION_LAYERS.register_module(module=GELU) +else: + ACTIVATION_LAYERS.register_module(module=nn.GELU) + + +def build_activation_layer(cfg): + """Build activation layer. + + Args: + cfg (dict): The activation layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate an activation layer. + + Returns: + nn.Module: Created activation layer. + """ + return build_from_cfg(cfg, ACTIVATION_LAYERS) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py new file mode 100755 index 00000000..f6c35fd7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from .registry import CONV_LAYERS + +CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d) +CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d) +CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d) +CONV_LAYERS.register_module('Conv', module=nn.Conv2d) + + +def build_conv_layer(cfg, *args, **kwargs): + """Build convolution layer. + + Args: + cfg (None or dict): The conv layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate an conv layer. + args (argument list): Arguments passed to the `__init__` + method of the corresponding conv layer. + kwargs (keyword arguments): Keyword arguments passed to the `__init__` + method of the corresponding conv layer. + + Returns: + nn.Module: Created conv layer. + """ + if cfg is None: + cfg_ = dict(type='Conv2d') + else: + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in CONV_LAYERS: + raise KeyError(f'Unrecognized layer type {layer_type}') + else: + conv_layer = CONV_LAYERS.get(layer_type) + + layer = conv_layer(*args, **kwargs, **cfg_) + + return layer diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv2d_adaptive_padding.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv2d_adaptive_padding.py new file mode 100755 index 00000000..b45e758a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv2d_adaptive_padding.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from torch import nn +from torch.nn import functional as F + +from .registry import CONV_LAYERS + + +@CONV_LAYERS.register_module() +class Conv2dAdaptivePadding(nn.Conv2d): + """Implementation of 2D convolution in tensorflow with `padding` as "same", + which applies padding to input (if needed) so that input image gets fully + covered by filter and stride you specified. For stride 1, this will ensure + that output image size is same as input. 
For stride of 2, output dimensions + will be half, for example. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, + dilation, groups, bias) + + def forward(self, x): + img_h, img_w = x.size()[-2:] + kernel_h, kernel_w = self.weight.size()[-2:] + stride_h, stride_w = self.stride + output_h = math.ceil(img_h / stride_h) + output_w = math.ceil(img_w / stride_w) + pad_h = ( + max((output_h - 1) * self.stride[0] + + (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0)) + pad_w = ( + max((output_w - 1) * self.stride[1] + + (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0)) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv_module.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv_module.py new file mode 100755 index 00000000..5c27f6e3 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/conv_module.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn + +from dbnet_cv.utils import _BatchNorm, _InstanceNorm +from ..utils import constant_init, kaiming_init +from .activation import build_activation_layer +from .conv import build_conv_layer +from .norm import build_norm_layer +from .padding import build_padding_layer +from .registry import PLUGIN_LAYERS + + +@PLUGIN_LAYERS.register_module() +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. 
+ Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = 'conv_block' + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias='auto', + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + inplace=True, + with_spectral_norm=False, + padding_mode='zeros', + order=('conv', 'norm', 'act')): + super().__init__() + assert conv_cfg is None or isinstance(conv_cfg, dict) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert act_cfg is None or isinstance(act_cfg, dict) + official_padding_mode = ['zeros', 'circular'] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == {'conv', 'norm', 'act'} + + self.with_norm = norm_cfg is not None + self.with_activation = act_cfg is not None + # if the conv layer is before a norm layer, bias is unnecessary. 
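+        # e.g. ConvModule(3, 16, 3, norm_cfg=dict(type='BN')) resolves
+        # bias='auto' to False, while ConvModule(3, 16, 3) keeps the conv bias.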
+ if bias == 'auto': + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + pad_cfg = dict(type=padding_mode) + self.padding_layer = build_padding_layer(pad_cfg, padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index('norm') > order.index('conv'): + norm_channels = out_channels + else: + norm_channels = in_channels + self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) + self.add_module(self.norm_name, norm) + if self.with_bias: + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn( + 'Unnecessary conv bias before batch/instance norm') + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + act_cfg_ = act_cfg.copy() + # nn.Tanh has no 'inplace' argument + if act_cfg_['type'] not in [ + 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU' + ]: + act_cfg_.setdefault('inplace', inplace) + self.activate = build_activation_layer(act_cfg_) + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, 'init_weights'): + if self.with_activation and self.act_cfg['type'] == 'LeakyReLU': + nonlinearity = 'leaky_relu' + a = self.act_cfg.get('negative_slope', 0.01) + else: + nonlinearity = 'relu' + a = 0 + kaiming_init(self.conv, a=a, nonlinearity=nonlinearity) + if self.with_norm: + constant_init(self.norm, 1, bias=0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == 'conv': + if self.with_explicit_padding: + x = self.padding_layer(x) + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + x = self.norm(x) + elif layer == 'act' and activate and self.with_activation: + x = self.activate(x) + return x diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py new file mode 100755 index 00000000..ff29bac8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/drop.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from dbnet_cv import build_from_cfg +from .registry import DROPOUT_LAYERS + + +def drop_path(x: torch.Tensor, + drop_prob: float = 0., + training: bool = False) -> torch.Tensor: + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + We follow the implementation + https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + # handle tensors with different dimensions, not just 4D tensors. + shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + output = x.div(keep_prob) * random_tensor.floor() + return output + + +@DROPOUT_LAYERS.register_module() +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + We follow the implementation + https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 + Args: + drop_prob (float): Probability of the path to be zeroed. Default: 0.1 + """ + + def __init__(self, drop_prob: float = 0.1): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + +@DROPOUT_LAYERS.register_module() +class Dropout(nn.Dropout): + """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of + ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with + ``DropPath`` + Args: + drop_prob (float): Probability of the elements to be + zeroed. Default: 0.5. + inplace (bool): Do the operation inplace or not. Default: False. + """ + + def __init__(self, drop_prob: float = 0.5, inplace: bool = False): + super().__init__(p=drop_prob, inplace=inplace) + + +def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any: + """Builder for drop out layers.""" + return build_from_cfg(cfg, DROPOUT_LAYERS, default_args) \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hsigmoid.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hsigmoid.py new file mode 100755 index 00000000..1854517c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hsigmoid.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn + +from .registry import ACTIVATION_LAYERS + + +@ACTIVATION_LAYERS.register_module() +class HSigmoid(nn.Module): + """Hard Sigmoid Module. Apply the hard sigmoid function: + Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value) + Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1) + Note: + In MMCV v1.4.4, we modified the default value of args to align with + PyTorch official. + Args: + bias (float): Bias of the input feature map. Default: 3.0. + divisor (float): Divisor of the input feature map. Default: 6.0. + min_value (float): Lower bound value. Default: 0.0. + max_value (float): Upper bound value. Default: 1.0. + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + bias: float = 3.0, + divisor: float = 6.0, + min_value: float = 0.0, + max_value: float = 1.0): + super().__init__() + warnings.warn( + 'In MMCV v1.4.4, we modified the default value of args to align ' + 'with PyTorch official. 
Previous Implementation: ' + 'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). ' + 'Current Implementation: ' + 'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).') + self.bias = bias + self.divisor = divisor + assert self.divisor != 0 + self.min_value = min_value + self.max_value = max_value + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = (x + self.bias) / self.divisor + + return x.clamp_(self.min_value, self.max_value) \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hswish.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hswish.py new file mode 100755 index 00000000..1d3ee70e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/hswish.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from dbnet_cv.utils import TORCH_VERSION, digit_version +from .registry import ACTIVATION_LAYERS + + +class HSwish(nn.Module): + """Hard Swish Module. + This module applies the hard swish function: + .. math:: + Hswish(x) = x * ReLU6(x + 3) / 6 + Args: + inplace (bool): can optionally do the operation in-place. + Default: False. + Returns: + Tensor: The output tensor. + """ + + def __init__(self, inplace: bool = False): + super().__init__() + self.act = nn.ReLU6(inplace) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.act(x + 3) / 6 + + +if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.7')): + # Hardswish is not supported when PyTorch version < 1.6. + # And Hardswish in PyTorch 1.6 does not support inplace. + ACTIVATION_LAYERS.register_module(module=HSwish) +else: + ACTIVATION_LAYERS.register_module(module=nn.Hardswish, name='HSwish') \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py new file mode 100755 index 00000000..9516d812 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/norm.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect + +import torch.nn as nn + +from dbnet_cv.utils import is_tuple_of +from dbnet_cv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm +from .registry import NORM_LAYERS + +NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d) +NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d) +NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d) +NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d) +NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm) +NORM_LAYERS.register_module('GN', module=nn.GroupNorm) +NORM_LAYERS.register_module('LN', module=nn.LayerNorm) +NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d) +NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d) +NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d) +NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d) + + +def infer_abbr(class_type): + """Infer abbreviation from the class name. + + When we build a norm layer with `build_norm_layer()`, we want to preserve + the norm type in variable names, e.g, self.bn1, self.gn. This method will + infer the abbreviation to map class types to abbreviations. + + Rule 1: If the class has the property "_abbr_", return the property. + Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or + InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and + "in" respectively. 
+ Rule 3: If the class name contains "batch", "group", "layer" or "instance", + the abbreviation of this layer will be "bn", "gn", "ln" and "in" + respectively. + Rule 4: Otherwise, the abbreviation falls back to "norm". + + Args: + class_type (type): The norm layer type. + + Returns: + str: The inferred abbreviation. + """ + if not inspect.isclass(class_type): + raise TypeError( + f'class_type must be a type, but got {type(class_type)}') + if hasattr(class_type, '_abbr_'): + return class_type._abbr_ + if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN + return 'in' + elif issubclass(class_type, _BatchNorm): + return 'bn' + elif issubclass(class_type, nn.GroupNorm): + return 'gn' + elif issubclass(class_type, nn.LayerNorm): + return 'ln' + else: + class_name = class_type.__name__.lower() + if 'batch' in class_name: + return 'bn' + elif 'group' in class_name: + return 'gn' + elif 'layer' in class_name: + return 'ln' + elif 'instance' in class_name: + return 'in' + else: + return 'norm_layer' + + +def build_norm_layer(cfg, num_features, postfix=''): + """Build normalization layer. + + Args: + cfg (dict): The norm layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate a norm layer. + - requires_grad (bool, optional): Whether stop gradient updates. + num_features (int): Number of input channels. + postfix (int | str): The postfix to be appended into norm abbreviation + to create named layer. + + Returns: + tuple[str, nn.Module]: The first element is the layer name consisting + of abbreviation and postfix, e.g., bn1, gn. The second element is the + created norm layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in NORM_LAYERS: + raise KeyError(f'Unrecognized norm type {layer_type}') + + norm_layer = NORM_LAYERS.get(layer_type) + abbr = infer_abbr(norm_layer) + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + else: + assert 'num_groups' in cfg_ + layer = norm_layer(num_channels=num_features, **cfg_) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return name, layer + + +def is_norm(layer, exclude=None): + """Check if a layer is a normalization layer. + + Args: + layer (nn.Module): The layer to be checked. + exclude (type | tuple[type]): Types to be excluded. + + Returns: + bool: Whether the layer is a norm layer. + """ + if exclude is not None: + if not isinstance(exclude, tuple): + exclude = (exclude, ) + if not is_tuple_of(exclude, type): + raise TypeError( + f'"exclude" must be either None or type or a tuple of types, ' + f'but got {type(exclude)}: {exclude}') + + if exclude and isinstance(layer, exclude): + return False + + all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) + return isinstance(layer, all_norm_bases) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/padding.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/padding.py new file mode 100755 index 00000000..e4ac6b28 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/padding.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import torch.nn as nn + +from .registry import PADDING_LAYERS + +PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d) +PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d) +PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d) + + +def build_padding_layer(cfg, *args, **kwargs): + """Build padding layer. + + Args: + cfg (None or dict): The padding layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate a padding layer. + + Returns: + nn.Module: Created padding layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + + cfg_ = cfg.copy() + padding_type = cfg_.pop('type') + if padding_type not in PADDING_LAYERS: + raise KeyError(f'Unrecognized padding type {padding_type}.') + else: + padding_layer = PADDING_LAYERS.get(padding_type) + + layer = padding_layer(*args, **kwargs, **cfg_) + + return layer diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/plugin.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/plugin.py new file mode 100755 index 00000000..6aa13f43 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/plugin.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import platform + +from .registry import PLUGIN_LAYERS + +if platform.system() == 'Windows': + import regex as re # type: ignore +else: + import re # type: ignore + + +def infer_abbr(class_type): + """Infer abbreviation from the class name. + + This method will infer the abbreviation to map class types to + abbreviations. + + Rule 1: If the class has the property "abbr", return the property. + Rule 2: Otherwise, the abbreviation falls back to snake case of class + name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``. + + Args: + class_type (type): The norm layer type. + + Returns: + str: The inferred abbreviation. + """ + + def camel2snack(word): + """Convert camel case word into snack case. + + Modified from `inflection lib + `_. + + Example:: + + >>> camel2snack("FancyBlock") + 'fancy_block' + """ + + word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word) + word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word) + word = word.replace('-', '_') + return word.lower() + + if not inspect.isclass(class_type): + raise TypeError( + f'class_type must be a type, but got {type(class_type)}') + if hasattr(class_type, '_abbr_'): + return class_type._abbr_ + else: + return camel2snack(class_type.__name__) + + +def build_plugin_layer(cfg, postfix='', **kwargs): + """Build plugin layer. + + Args: + cfg (None or dict): cfg should contain: + + - type (str): identify plugin layer type. + - layer args: args needed to instantiate a plugin layer. + postfix (int, str): appended into norm abbreviation to + create named layer. Default: ''. + + Returns: + tuple[str, nn.Module]: The first one is the concatenation of + abbreviation and postfix. The second is the created plugin layer. 
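+
+    Example (an illustrative sketch; ``FancyBlock`` is a hypothetical class
+    assumed to have been registered in ``PLUGIN_LAYERS``)::
+
+        >>> name, layer = build_plugin_layer(dict(type='FancyBlock'), postfix=1)
+        >>> name
+        'fancy_block1'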
+ """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in PLUGIN_LAYERS: + raise KeyError(f'Unrecognized plugin type {layer_type}') + + plugin_layer = PLUGIN_LAYERS.get(layer_type) + abbr = infer_abbr(plugin_layer) + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + layer = plugin_layer(**kwargs, **cfg_) + + return name, layer diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/registry.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/registry.py new file mode 100755 index 00000000..b56fbe18 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/registry.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dbnet_cv.utils import Registry + +CONV_LAYERS = Registry('conv layer') +NORM_LAYERS = Registry('norm layer') +ACTIVATION_LAYERS = Registry('activation layer') +PADDING_LAYERS = Registry('padding layer') +UPSAMPLE_LAYERS = Registry('upsample layer') +PLUGIN_LAYERS = Registry('plugin layer') + +DROPOUT_LAYERS = Registry('drop out layers') +POSITIONAL_ENCODING = Registry('position encoding') +ATTENTION = Registry('attention') +FEEDFORWARD_NETWORK = Registry('feed-forward Network') +TRANSFORMER_LAYER = Registry('transformerLayer') +TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence') diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/upsample.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/upsample.py new file mode 100755 index 00000000..15e4febd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/upsample.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import xavier_init +from .registry import UPSAMPLE_LAYERS + +UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample) +UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample) + + +@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle') +class PixelShufflePack(nn.Module): + """Pixel Shuffle upsample layer. + + This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to + achieve a simple upsampling with pixel shuffle. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + scale_factor (int): Upsample ratio. + upsample_kernel (int): Kernel size of the conv layer to expand the + channels. + """ + + def __init__(self, in_channels, out_channels, scale_factor, + upsample_kernel): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.scale_factor = scale_factor + self.upsample_kernel = upsample_kernel + self.upsample_conv = nn.Conv2d( + self.in_channels, + self.out_channels * scale_factor * scale_factor, + self.upsample_kernel, + padding=(self.upsample_kernel - 1) // 2) + self.init_weights() + + def init_weights(self): + xavier_init(self.upsample_conv, distribution='uniform') + + def forward(self, x): + x = self.upsample_conv(x) + x = F.pixel_shuffle(x, self.scale_factor) + return x + + +def build_upsample_layer(cfg, *args, **kwargs): + """Build upsample layer. + + Args: + cfg (dict): The upsample layer config, which should contain: + + - type (str): Layer type. + - scale_factor (int): Upsample ratio, which is not applicable to + deconv. + - layer args: Args needed to instantiate a upsample layer. 
+ args (argument list): Arguments passed to the ``__init__`` + method of the corresponding conv layer. + kwargs (keyword arguments): Keyword arguments passed to the + ``__init__`` method of the corresponding conv layer. + + Returns: + nn.Module: Created upsample layer. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + raise KeyError( + f'the cfg dict must contain the key "type", but got {cfg}') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in UPSAMPLE_LAYERS: + raise KeyError(f'Unrecognized upsample type {layer_type}') + else: + upsample = UPSAMPLE_LAYERS.get(layer_type) + + if upsample is nn.Upsample: + cfg_['mode'] = layer_type + layer = upsample(*args, **kwargs, **cfg_) + return layer diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/wrappers.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/wrappers.py new file mode 100755 index 00000000..8aebf67b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/bricks/wrappers.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501 + +Wrap some nn modules to support empty tensor input. Currently, these wrappers +are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask +heads are trained on only positive RoIs. +""" +import math + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair, _triple + +from .registry import CONV_LAYERS, UPSAMPLE_LAYERS + +if torch.__version__ == 'parrots': + TORCH_VERSION = torch.__version__ +else: + # torch.__version__ could be 1.3.1+cu92, we only need the first two + # for comparison + TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2]) + + +def obsolete_torch_version(torch_version, version_threshold): + return torch_version == 'parrots' or torch_version <= version_threshold + + +class NewEmptyTensorOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return NewEmptyTensorOp.apply(grad, shape), None + + +@CONV_LAYERS.register_module('Conv', force=True) +class Conv2d(nn.Conv2d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, + self.padding, self.stride, self.dilation): + o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module('Conv3d', force=True) +class Conv3d(nn.Conv3d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size, + self.padding, self.stride, self.dilation): + o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. 
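+                # (summing the first element of every parameter, scaled by 0.0,
+                # keeps all parameters in the autograd graph without changing
+                # the value of the empty output)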
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module() +@CONV_LAYERS.register_module('deconv') +@UPSAMPLE_LAYERS.register_module('deconv', force=True) +class ConvTranspose2d(nn.ConvTranspose2d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, + self.padding, self.stride, + self.dilation, self.output_padding): + out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module() +@CONV_LAYERS.register_module('deconv3d') +@UPSAMPLE_LAYERS.register_module('deconv3d', force=True) +class ConvTranspose3d(nn.ConvTranspose3d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size, + self.padding, self.stride, + self.dilation, self.output_padding): + out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +class MaxPool2d(nn.MaxPool2d): + + def forward(self, x): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + out_shape = list(x.shape[:2]) + for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), + _pair(self.padding), _pair(self.stride), + _pair(self.dilation)): + o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 + o = math.ceil(o) if self.ceil_mode else math.floor(o) + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + return empty + + return super().forward(x) + + +class MaxPool3d(nn.MaxPool3d): + + def forward(self, x): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + out_shape = list(x.shape[:2]) + for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), + _triple(self.padding), + _triple(self.stride), + _triple(self.dilation)): + o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 + o = math.ceil(o) if self.ceil_mode else math.floor(o) + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + return empty + + return super().forward(x) + + +class Linear(torch.nn.Linear): + + def forward(self, x): + # empty tensor forward of Linear layer is supported in Pytorch 1.6 + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)): + out_shape = [x.shape[0], self.out_features] + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. 
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py new file mode 100755 index 00000000..7567316c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/builder.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..runner import Sequential +from ..utils import Registry, build_from_cfg + + +def build_model_from_cfg(cfg, registry, default_args=None): + """Build a PyTorch model from config dict(s). Different from + ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. + + Args: + cfg (dict, list[dict]): The config of modules, is is either a config + dict or a list of config dicts. If cfg is a list, a + the built modules will be wrapped with ``nn.Sequential``. + registry (:obj:`Registry`): A registry the module belongs to. + default_args (dict, optional): Default arguments to build the module. + Defaults to None. + + Returns: + nn.Module: A built nn module. + """ + if isinstance(cfg, list): + modules = [ + build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + ] + return Sequential(*modules) + else: + return build_from_cfg(cfg, registry, default_args) + + +MODELS = Registry('model', build_func=build_model_from_cfg) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py new file mode 100755 index 00000000..fb29e625 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/resnet.py @@ -0,0 +1,322 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import Optional, Sequence, Tuple, Union + +import torch.nn as nn +import torch.utils.checkpoint as cp +from torch import Tensor + +from .utils import constant_init, kaiming_init + + +def conv3x3(in_planes: int, + out_planes: int, + stride: int = 1, + dilation: int = 1): + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + downsample: Optional[nn.Module] = None, + style: str = 'pytorch', + with_cp: bool = False): + super().__init__() + assert style in ['pytorch', 'caffe'] + self.conv1 = conv3x3(inplanes, planes, stride, dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + assert not with_cp + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + downsample: Optional[nn.Module] = None, + style: str = 'pytorch', + with_cp: bool = False): + """Bottleneck block. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. 
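+
+        Example (an illustrative sketch; assumes ``import torch`` and the
+        default ``style='pytorch'``)::
+
+            >>> downsample = nn.Sequential(
+            ...     nn.Conv2d(64, 256, kernel_size=1, stride=2, bias=False),
+            ...     nn.BatchNorm2d(256))
+            >>> block = Bottleneck(64, 64, stride=2, downsample=downsample)
+            >>> block(torch.rand(1, 64, 56, 56)).shape
+            torch.Size([1, 256, 28, 28])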
+ """ + super().__init__() + assert style in ['pytorch', 'caffe'] + if style == 'pytorch': + conv1_stride = 1 + conv2_stride = stride + else: + conv1_stride = stride + conv2_stride = 1 + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + def forward(self, x: Tensor) -> Tensor: + + def _inner_forward(x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def make_res_layer(block: nn.Module, + inplanes: int, + planes: int, + blocks: int, + stride: int = 1, + dilation: int = 1, + style: str = 'pytorch', + with_cp: bool = False) -> nn.Module: + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + inplanes, + planes, + stride, + dilation, + downsample, + style=style, + with_cp=with_cp)) + inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp)) + + return nn.Sequential(*layers) + + +class ResNet(nn.Module): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + num_stages (int): Resnet stages, normally 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze + running stats (mean and var). + bn_frozen (bool): Whether to freeze weight and bias of BN layers. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth: int, + num_stages: int = 4, + strides: Sequence[int] = (1, 2, 2, 2), + dilations: Sequence[int] = (1, 1, 1, 1), + out_indices: Sequence[int] = (0, 1, 2, 3), + style: str = 'pytorch', + frozen_stages: int = -1, + bn_eval: bool = True, + bn_frozen: bool = False, + with_cp: bool = False): + super().__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + assert num_stages >= 1 and num_stages <= 4 + block, stage_blocks = self.arch_settings[depth] + stage_blocks = stage_blocks[:num_stages] # type: ignore + assert len(strides) == len(dilations) == num_stages + assert max(out_indices) < num_stages + + self.out_indices = out_indices + self.style = style + self.frozen_stages = frozen_stages + self.bn_eval = bn_eval + self.bn_frozen = bn_frozen + self.with_cp = with_cp + + self.inplanes: int = 64 + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.res_layers = [] + for i, num_blocks in enumerate(stage_blocks): + stride = strides[i] + dilation = dilations[i] + planes = 64 * 2**i + res_layer = make_res_layer( + block, + self.inplanes, + planes, + num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + with_cp=with_cp) + self.inplanes = planes * block.expansion # type: ignore + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self.feat_dim = block.expansion * 64 * 2**( # type: ignore + len(stage_blocks) - 1) + + def init_weights(self, pretrained: Optional[str] = None) -> None: + if isinstance(pretrained, str): + logger = logging.getLogger() + from ..runner import load_checkpoint + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def train(self, mode: bool = True) -> None: + super().train(mode) + if self.bn_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + if self.bn_frozen: + for params in m.parameters(): + params.requires_grad = False + if mode and self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for param in self.bn1.parameters(): + param.requires_grad = False + self.bn1.eval() + self.bn1.weight.requires_grad = False + self.bn1.bias.requires_grad = False + for i in range(1, self.frozen_stages + 1): + mod = getattr(self, f'layer{i}') + mod.eval() + for param in mod.parameters(): + param.requires_grad = False diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/__init__.py new file mode 100755 index 00000000..a263e31c --- /dev/null 
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .flops_counter import get_model_complexity_info +from .fuse_conv_bn import fuse_conv_bn +from .sync_bn import revert_sync_batchnorm +from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit, + KaimingInit, NormalInit, PretrainedInit, + TruncNormalInit, UniformInit, XavierInit, + bias_init_with_prob, caffe2_xavier_init, + constant_init, initialize, kaiming_init, normal_init, + trunc_normal_init, uniform_init, xavier_init) + +__all__ = [ + 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init', + 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init', + 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize', + 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', + 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', + 'Caffe2XavierInit', 'revert_sync_batchnorm' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/flops_counter.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/flops_counter.py new file mode 100755 index 00000000..b8024ad3 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/flops_counter.py @@ -0,0 +1,603 @@ +# Modified from flops-counter.pytorch by Vladislav Sovrasov +# original repo: https://github.com/sovrasov/flops-counter.pytorch + +# MIT License + +# Copyright (c) 2018 Vladislav Sovrasov + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import sys +import warnings +from functools import partial +from typing import Any, Callable, Dict, Optional, TextIO, Tuple + +import numpy as np +import torch +import torch.nn as nn + +import dbnet_cv + + +def get_model_complexity_info(model: nn.Module, + input_shape: tuple, + print_per_layer_stat: bool = True, + as_strings: bool = True, + input_constructor: Optional[Callable] = None, + flush: bool = False, + ost: TextIO = sys.stdout) -> tuple: + """Get complexity information of a model. + + This method can calculate FLOPs and parameter counts of a model with + corresponding input shape. It can also print complexity information for + each layer in a model. + + Supported layers are listed as below: + - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``. + - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, + ``nn.LeakyReLU``, ``nn.ReLU6``. 
+ - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``, + ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``, + ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``, + ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``, + ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``. + - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``, + ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``, + ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``. + - Linear: ``nn.Linear``. + - Deconvolution: ``nn.ConvTranspose2d``. + - Upsample: ``nn.Upsample``. + + Args: + model (nn.Module): The model for complexity calculation. + input_shape (tuple): Input shape used for calculation. + print_per_layer_stat (bool): Whether to print complexity information + for each layer in a model. Default: True. + as_strings (bool): Output FLOPs and params counts in a string form. + Default: True. + input_constructor (None | callable): If specified, it takes a callable + method that generates input. otherwise, it will generate a random + tensor with input shape to calculate FLOPs. Default: None. + flush (bool): same as that in :func:`print`. Default: False. + ost (stream): same as ``file`` param in :func:`print`. + Default: sys.stdout. + + Returns: + tuple[float | str]: If ``as_strings`` is set to True, it will return + FLOPs and parameter counts in a string format. otherwise, it will + return those in a float number format. + """ + assert type(input_shape) is tuple + assert len(input_shape) >= 1 + assert isinstance(model, nn.Module) + flops_model = add_flops_counting_methods(model) + flops_model.eval() + flops_model.start_flops_count() + if input_constructor: + input = input_constructor(input_shape) + _ = flops_model(**input) + else: + try: + batch = torch.ones(()).new_empty( + (1, *input_shape), + dtype=next(flops_model.parameters()).dtype, + device=next(flops_model.parameters()).device) + except StopIteration: + # Avoid StopIteration for models which have no parameters, + # like `nn.Relu()`, `nn.AvgPool2d`, etc. + batch = torch.ones(()).new_empty((1, *input_shape)) + + _ = flops_model(batch) + + flops_count, params_count = flops_model.compute_average_flops_cost() + if print_per_layer_stat: + print_model_with_flops( + flops_model, flops_count, params_count, ost=ost, flush=flush) + flops_model.stop_flops_count() + + if as_strings: + return flops_to_string(flops_count), params_to_string(params_count) + + return flops_count, params_count + + +def flops_to_string(flops: float, + units: Optional[str] = 'GFLOPs', + precision: int = 2) -> str: + """Convert FLOPs number into a string. + + Note that Here we take a multiply-add counts as one FLOP. + + Args: + flops (float): FLOPs number to be converted. + units (str | None): Converted FLOPs units. Options are None, 'GFLOPs', + 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically + choose the most suitable unit for FLOPs. Default: 'GFLOPs'. + precision (int): Digit number after the decimal point. Default: 2. + + Returns: + str: The converted FLOPs number with units. 
+ + Examples: + >>> flops_to_string(1e9) + '1.0 GFLOPs' + >>> flops_to_string(2e5, 'MFLOPs') + '0.2 MFLOPs' + >>> flops_to_string(3e-9, None) + '3e-09 FLOPs' + """ + if units is None: + if flops // 10**9 > 0: + return str(round(flops / 10.**9, precision)) + ' GFLOPs' + elif flops // 10**6 > 0: + return str(round(flops / 10.**6, precision)) + ' MFLOPs' + elif flops // 10**3 > 0: + return str(round(flops / 10.**3, precision)) + ' KFLOPs' + else: + return str(flops) + ' FLOPs' + else: + if units == 'GFLOPs': + return str(round(flops / 10.**9, precision)) + ' ' + units + elif units == 'MFLOPs': + return str(round(flops / 10.**6, precision)) + ' ' + units + elif units == 'KFLOPs': + return str(round(flops / 10.**3, precision)) + ' ' + units + else: + return str(flops) + ' FLOPs' + + +def params_to_string(num_params: float, + units: Optional[str] = None, + precision: int = 2) -> str: + """Convert parameter number into a string. + + Args: + num_params (float): Parameter number to be converted. + units (str | None): Converted FLOPs units. Options are None, 'M', + 'K' and ''. If set to None, it will automatically choose the most + suitable unit for Parameter number. Default: None. + precision (int): Digit number after the decimal point. Default: 2. + + Returns: + str: The converted parameter number with units. + + Examples: + >>> params_to_string(1e9) + '1000.0 M' + >>> params_to_string(2e5) + '200.0 k' + >>> params_to_string(3e-9) + '3e-09' + """ + if units is None: + if num_params // 10**6 > 0: + return str(round(num_params / 10**6, precision)) + ' M' + elif num_params // 10**3: + return str(round(num_params / 10**3, precision)) + ' k' + else: + return str(num_params) + else: + if units == 'M': + return str(round(num_params / 10.**6, precision)) + ' ' + units + elif units == 'K': + return str(round(num_params / 10.**3, precision)) + ' ' + units + else: + return str(num_params) + + +def print_model_with_flops(model: nn.Module, + total_flops: float, + total_params: float, + units: Optional[str] = 'GFLOPs', + precision: int = 3, + ost: TextIO = sys.stdout, + flush: bool = False) -> None: + """Print a model with FLOPs for each layer. + + Args: + model (nn.Module): The model to be printed. + total_flops (float): Total FLOPs of the model. + total_params (float): Total parameter counts of the model. + units (str | None): Converted FLOPs units. Default: 'GFLOPs'. + precision (int): Digit number after the decimal point. Default: 3. + ost (stream): same as `file` param in :func:`print`. + Default: sys.stdout. + flush (bool): same as that in :func:`print`. Default: False. 
+ + Example: + >>> class ExampleModel(nn.Module): + + >>> def __init__(self): + >>> super().__init__() + >>> self.conv1 = nn.Conv2d(3, 8, 3) + >>> self.conv2 = nn.Conv2d(8, 256, 3) + >>> self.conv3 = nn.Conv2d(256, 8, 3) + >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + >>> self.flatten = nn.Flatten() + >>> self.fc = nn.Linear(8, 1) + + >>> def forward(self, x): + >>> x = self.conv1(x) + >>> x = self.conv2(x) + >>> x = self.conv3(x) + >>> x = self.avg_pool(x) + >>> x = self.flatten(x) + >>> x = self.fc(x) + >>> return x + + >>> model = ExampleModel() + >>> x = (3, 16, 16) + to print the complexity information state for each layer, you can use + >>> get_model_complexity_info(model, x) + or directly use + >>> print_model_with_flops(model, 4579784.0, 37361) + ExampleModel( + 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, + (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501 + (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1)) + (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1)) + (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1)) + (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, ) + (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True) + ) + """ + + def accumulate_params(self): + if is_supported_instance(self): + return self.__params__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_params() + return sum + + def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + def flops_repr(self): + accumulated_num_params = self.accumulate_params() + accumulated_flops_cost = self.accumulate_flops() + return ', '.join([ + params_to_string( + accumulated_num_params, units='M', precision=precision), + f'{accumulated_num_params / total_params:.3%} Params', + flops_to_string( + accumulated_flops_cost, units=units, precision=precision), + f'{accumulated_flops_cost / total_flops:.3%} FLOPs', + self.original_extra_repr() + ]) + + def add_extra_repr(m): + m.accumulate_flops = accumulate_flops.__get__(m) + m.accumulate_params = accumulate_params.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(model, file=ost, flush=flush) + model.apply(del_extra_repr) + + +def get_model_parameters_number(model: nn.Module) -> float: + """Calculate parameter number of a model. + + Args: + model (nn.module): The model for parameter number calculation. + + Returns: + float: Parameter number of the model. 
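+
+    Example (an illustrative sketch)::
+
+        >>> # 8 * 4 weights + 4 bias terms, all with requires_grad=True
+        >>> get_model_parameters_number(nn.Linear(8, 4))
+        36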
+ """ + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return num_params + + +def add_flops_counting_methods(net_main_module: nn.Module) -> nn.Module: + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + net_main_module.start_flops_count = start_flops_count.__get__( # type: ignore # noqa E501 + net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__( # type: ignore # noqa E501 + net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__( # type: ignore # noqa E501 + net_main_module) + net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # type: ignore # noqa E501 + net_main_module) + + net_main_module.reset_flops_count() + + return net_main_module + + +def compute_average_flops_cost(self) -> Tuple[float, float]: + """Compute average FLOPs cost. + + A method to compute average FLOPs cost, which will be available after + `add_flops_counting_methods()` is called on a desired net object. + + Returns: + float: Current mean flops consumption per image. + """ + batches_count = self.__batch_counter__ + flops_sum = 0 + for module in self.modules(): + if is_supported_instance(module): + flops_sum += module.__flops__ + params_sum = get_model_parameters_number(self) + return flops_sum / batches_count, params_sum + + +def start_flops_count(self) -> None: + """Activate the computation of mean flops consumption per image. + + A method to activate the computation of mean flops consumption per image. + which will be available after ``add_flops_counting_methods()`` is called on + a desired net object. It should be called before running the network. + """ + add_batch_counter_hook_function(self) + + def add_flops_counter_hook_function(module: nn.Module) -> None: + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + return + + else: + handle = module.register_forward_hook( + get_modules_mapping()[type(module)]) + + module.__flops_handle__ = handle + + self.apply(partial(add_flops_counter_hook_function)) + + +def stop_flops_count(self) -> None: + """Stop computing the mean flops consumption per image. + + A method to stop computing the mean flops consumption per image, which will + be available after ``add_flops_counting_methods()`` is called on a desired + net object. It can be called to pause the computation whenever. + """ + remove_batch_counter_hook_function(self) + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self) -> None: + """Reset statistics computed so far. + + A method to Reset computed statistics, which will be available after + `add_flops_counting_methods()` is called on a desired net object. 
+ """ + add_batch_counter_variables_or_reset(self) + self.apply(add_flops_counter_variable_or_reset) + + +# ---- Internal functions +def empty_flops_counter_hook(module: nn.Module, input: tuple, + output: Any) -> None: + module.__flops__ += 0 + + +def upsample_flops_counter_hook(module: nn.Module, input: tuple, + output: torch.Tensor) -> None: + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + + +def relu_flops_counter_hook(module: nn.Module, input: tuple, + output: torch.Tensor) -> None: + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + + +def linear_flops_counter_hook(module: nn.Module, input: tuple, + output: torch.Tensor) -> None: + output_last_dim = output.shape[ + -1] # pytorch checks dimensions, so here we don't care much + module.__flops__ += int(np.prod(input[0].shape) * output_last_dim) + + +def pool_flops_counter_hook(module: nn.Module, input: tuple, + output: torch.Tensor) -> None: + module.__flops__ += int(np.prod(input[0].shape)) + + +def norm_flops_counter_hook(module: nn.Module, input: tuple, + output: torch.Tensor) -> None: + batch_flops = np.prod(input[0].shape) + if (getattr(module, 'affine', False) + or getattr(module, 'elementwise_affine', False)): + batch_flops *= 2 + module.__flops__ += int(batch_flops) + + +def deconv_flops_counter_hook(conv_module: nn.Module, input: tuple, + output: torch.Tensor) -> None: + # Can have multiple inputs, getting the first one + batch_size = input[0].shape[0] + input_height, input_width = input[0].shape[2:] + + kernel_height, kernel_width = conv_module.kernel_size + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = ( + kernel_height * kernel_width * in_channels * filters_per_channel) + + active_elements_count = batch_size * input_height * input_width + overall_conv_flops = conv_per_position_flops * active_elements_count + bias_flops = 0 + if conv_module.bias is not None: + output_height, output_width = output.shape[2:] + bias_flops = out_channels * batch_size * output_height * output_width + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def conv_flops_counter_hook(conv_module: nn.Module, input: tuple, + output: torch.Tensor) -> None: + # Can have multiple inputs, getting the first one + batch_size = input[0].shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(conv_module.kernel_size) + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = int( + np.prod(kernel_dims)) * in_channels * filters_per_channel + + active_elements_count = batch_size * int(np.prod(output_dims)) + + overall_conv_flops = conv_per_position_flops * active_elements_count + + bias_flops = 0 + + if conv_module.bias is not None: + + bias_flops = out_channels * active_elements_count + + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def batch_counter_hook(module: nn.Module, input: tuple, output: Any) -> None: + batch_size = 1 + if len(input) > 0: + # Can have multiple inputs, getting the first one + batch_size = len(input[0]) + else: + warnings.warn('No positional inputs found 
for a module, ' + 'assuming batch size is 1.') + module.__batch_counter__ += batch_size + + +def add_batch_counter_variables_or_reset(module: nn.Module) -> None: + + module.__batch_counter__ = 0 + + +def add_batch_counter_hook_function(module: nn.Module) -> None: + if hasattr(module, '__batch_counter_handle__'): + return + + handle = module.register_forward_hook(batch_counter_hook) + module.__batch_counter_handle__ = handle + + +def remove_batch_counter_hook_function(module: nn.Module) -> None: + if hasattr(module, '__batch_counter_handle__'): + module.__batch_counter_handle__.remove() + del module.__batch_counter_handle__ + + +def add_flops_counter_variable_or_reset(module: nn.Module) -> None: + if is_supported_instance(module): + if hasattr(module, '__flops__') or hasattr(module, '__params__'): + warnings.warn('variables __flops__ or __params__ are already ' + 'defined for the module' + type(module).__name__ + + ' ptflops can affect your code!') + module.__flops__ = 0 + module.__params__ = get_model_parameters_number(module) + + +def is_supported_instance(module: nn.Module) -> bool: + if type(module) in get_modules_mapping(): + return True + return False + + +def remove_flops_counter_hook_function(module: nn.Module) -> None: + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + module.__flops_handle__.remove() + del module.__flops_handle__ + + +def get_modules_mapping() -> Dict: + return { + # convolutions + nn.Conv1d: conv_flops_counter_hook, + nn.Conv2d: conv_flops_counter_hook, + dbnet_cv.cnn.bricks.Conv2d: conv_flops_counter_hook, + nn.Conv3d: conv_flops_counter_hook, + dbnet_cv.cnn.bricks.Conv3d: conv_flops_counter_hook, + # activations + nn.ReLU: relu_flops_counter_hook, + nn.PReLU: relu_flops_counter_hook, + nn.ELU: relu_flops_counter_hook, + nn.LeakyReLU: relu_flops_counter_hook, + nn.ReLU6: relu_flops_counter_hook, + # poolings + nn.MaxPool1d: pool_flops_counter_hook, + nn.AvgPool1d: pool_flops_counter_hook, + nn.AvgPool2d: pool_flops_counter_hook, + nn.MaxPool2d: pool_flops_counter_hook, + dbnet_cv.cnn.bricks.MaxPool2d: pool_flops_counter_hook, + nn.MaxPool3d: pool_flops_counter_hook, + dbnet_cv.cnn.bricks.MaxPool3d: pool_flops_counter_hook, + nn.AvgPool3d: pool_flops_counter_hook, + nn.AdaptiveMaxPool1d: pool_flops_counter_hook, + nn.AdaptiveAvgPool1d: pool_flops_counter_hook, + nn.AdaptiveMaxPool2d: pool_flops_counter_hook, + nn.AdaptiveAvgPool2d: pool_flops_counter_hook, + nn.AdaptiveMaxPool3d: pool_flops_counter_hook, + nn.AdaptiveAvgPool3d: pool_flops_counter_hook, + # normalizations + nn.BatchNorm1d: norm_flops_counter_hook, + nn.BatchNorm2d: norm_flops_counter_hook, + nn.BatchNorm3d: norm_flops_counter_hook, + nn.GroupNorm: norm_flops_counter_hook, + nn.InstanceNorm1d: norm_flops_counter_hook, + nn.InstanceNorm2d: norm_flops_counter_hook, + nn.InstanceNorm3d: norm_flops_counter_hook, + nn.LayerNorm: norm_flops_counter_hook, + # FC + nn.Linear: linear_flops_counter_hook, + dbnet_cv.cnn.bricks.Linear: linear_flops_counter_hook, + # Upscale + nn.Upsample: upsample_flops_counter_hook, + # Deconvolution + nn.ConvTranspose2d: deconv_flops_counter_hook, + dbnet_cv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook, + } diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/fuse_conv_bn.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/fuse_conv_bn.py new file mode 100755 index 00000000..6ccaab3b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/fuse_conv_bn.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
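+# Illustrative usage (a minimal sketch; ``model`` is assumed to be any
+# ``nn.Module`` in which Conv2d layers are directly followed by BatchNorm
+# layers):
+#
+#     model.eval()
+#     model = fuse_conv_bn(model)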
+import torch +import torch.nn as nn + + +def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module: + """Fuse conv and bn into one module. + + Args: + conv (nn.Module): Conv to be fused. + bn (nn.Module): BN to be fused. + + Returns: + nn.Module: Fused module. + """ + conv_w = conv.weight + conv_b = conv.bias if conv.bias is not None else torch.zeros_like( + bn.running_mean) + + factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) + conv.weight = nn.Parameter(conv_w * + factor.reshape([conv.out_channels, 1, 1, 1])) + conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) + return conv + + +def fuse_conv_bn(module: nn.Module) -> nn.Module: + """Recursively fuse conv and bn in a module. + + During inference, the functionary of batch norm layers is turned off + but only the mean and var alone channels are used, which exposes the + chance to fuse it with the preceding conv layers to save computations and + simplify network structures. + + Args: + module (nn.Module): Module to be fused. + + Returns: + nn.Module: Fused module. + """ + last_conv = None + last_conv_name = None + + for name, child in module.named_children(): + if isinstance(child, + (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)): + if last_conv is None: # only fuse BN that is after Conv + continue + fused_conv = _fuse_conv_bn(last_conv, child) + module._modules[last_conv_name] = fused_conv + # To reduce changes, set BN as Identity instead of deleting it. + module._modules[name] = nn.Identity() + last_conv = None + elif isinstance(child, nn.Conv2d): + last_conv = child + last_conv_name = name + else: + fuse_conv_bn(child) + return module diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/sync_bn.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/sync_bn.py new file mode 100755 index 00000000..47e36588 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/sync_bn.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +import dbnet_cv + + +class _BatchNormXd(nn.modules.batchnorm._BatchNorm): + """A general BatchNorm layer without input dimension check. + + Reproduced from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc + is `_check_input_dim` that is designed for tensor sanity checks. + The check has been bypassed in this class for the convenience of converting + SyncBatchNorm. + """ + + def _check_input_dim(self, input: torch.Tensor): + return + + +def revert_sync_batchnorm(module: nn.Module) -> nn.Module: + """Helper function to convert all `SyncBatchNorm` (SyncBN) and + `dbnet_cv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to + `BatchNormXd` layers. + + Adapted from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + + Args: + module (nn.Module): The module containing `SyncBatchNorm` layers. + + Returns: + module_output: The converted module with `BatchNormXd` layers. 
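+
+    Example (an illustrative sketch)::
+
+        >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
+        >>> model = revert_sync_batchnorm(model)
+        >>> isinstance(model[1], _BatchNormXd)
+        True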
+ """ + module_output = module + module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] + if hasattr(dbnet_cv, 'ops'): + module_checklist.append(dbnet_cv.ops.SyncBatchNorm) + if isinstance(module, tuple(module_checklist)): + module_output = _BatchNormXd(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + # no_grad() may not be needed here but + # just to be consistent with `convert_sync_batchnorm()` + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + # qconfig exists in quantized models + if hasattr(module, 'qconfig'): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, revert_sync_batchnorm(child)) + del module + return module_output diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/weight_init.py b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/weight_init.py new file mode 100755 index 00000000..8ed99109 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/cnn/utils/weight_init.py @@ -0,0 +1,708 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +import warnings +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from dbnet_cv.utils import Registry, build_from_cfg, get_logger, print_log + +INITIALIZERS = Registry('initializer') + + +def update_init_info(module: nn.Module, init_info: str) -> None: + """Update the `_params_init_info` in the module if the value of parameters + are changed. + + Args: + module (obj:`nn.Module`): The module of PyTorch with a user-defined + attribute `_params_init_info` which records the initialization + information. + init_info (str): The string that describes the initialization. + """ + assert hasattr( + module, + '_params_init_info'), f'Can not find `_params_init_info` in {module}' + for name, param in module.named_parameters(): + + assert param in module._params_init_info, ( + f'Find a new :obj:`Parameter` ' + f'named `{name}` during executing the ' + f'`init_weights` of ' + f'`{module.__class__.__name__}`. ' + f'Please do not add or ' + f'replace parameters during executing ' + f'the `init_weights`. 
') + + # The parameter has been changed during executing the + # `init_weights` of module + mean_value = param.data.mean() + if module._params_init_info[param]['tmp_mean_value'] != mean_value: + module._params_init_info[param]['init_info'] = init_info + module._params_init_info[param]['tmp_mean_value'] = mean_value + + +def constant_init(module: nn.Module, val: float, bias: float = 0) -> None: + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def xavier_init(module: nn.Module, + gain: float = 1, + bias: float = 0, + distribution: str = 'normal') -> None: + assert distribution in ['uniform', 'normal'] + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def normal_init(module: nn.Module, + mean: float = 0, + std: float = 1, + bias: float = 0) -> None: + if hasattr(module, 'weight') and module.weight is not None: + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def trunc_normal_init(module: nn.Module, + mean: float = 0, + std: float = 1, + a: float = -2, + b: float = 2, + bias: float = 0) -> None: + if hasattr(module, 'weight') and module.weight is not None: + trunc_normal_(module.weight, mean, std, a, b) # type: ignore + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) # type: ignore + + +def uniform_init(module: nn.Module, + a: float = 0, + b: float = 1, + bias: float = 0) -> None: + if hasattr(module, 'weight') and module.weight is not None: + nn.init.uniform_(module.weight, a, b) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def kaiming_init(module: nn.Module, + a: float = 0, + mode: str = 'fan_out', + nonlinearity: str = 'relu', + bias: float = 0, + distribution: str = 'normal') -> None: + assert distribution in ['uniform', 'normal'] + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def caffe2_xavier_init(module: nn.Module, bias: float = 0) -> None: + # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch + # Acknowledgment to FAIR's internal code + kaiming_init( + module, + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + bias=bias, + distribution='uniform') + + +def bias_init_with_prob(prior_prob: float) -> float: + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-np.log((1 - prior_prob) / prior_prob)) + return bias_init + + +def _get_bases_name(m: nn.Module) -> List[str]: + return [b.__name__ for b in m.__class__.__bases__] + + +class BaseInit: + + def __init__(self, + *, + bias: float = 0, + bias_prob: Optional[float] = None, + layer: Union[str, List, None] = None): + self.wholemodule = False + if not isinstance(bias, (int, float)): + raise TypeError(f'bias must be a number, but got a {type(bias)}') 
+ + if bias_prob is not None: + if not isinstance(bias_prob, float): + raise TypeError(f'bias_prob type must be float, \ + but got {type(bias_prob)}') + + if layer is not None: + if not isinstance(layer, (str, list)): + raise TypeError(f'layer must be a str or a list of str, \ + but got a {type(layer)}') + else: + layer = [] + + if bias_prob is not None: + self.bias = bias_init_with_prob(bias_prob) + else: + self.bias = bias + self.layer = [layer] if isinstance(layer, str) else layer + + def _get_init_info(self) -> str: + info = f'{self.__class__.__name__}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Constant') +class ConstantInit(BaseInit): + """Initialize module parameters with constant values. + + Args: + val (int | float): the value to fill the weights in the module with + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, val: Union[int, float], **kwargs): + super().__init__(**kwargs) + self.val = val + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + constant_init(m, self.val, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + constant_init(m, self.val, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self) -> str: + info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Xavier') +class XavierInit(BaseInit): + r"""Initialize module parameters with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks - Glorot, X. & Bengio, Y. (2010). + `_ + + Args: + gain (int | float): an optional scaling factor. Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + distribution (str): distribution either be ``'normal'`` + or ``'uniform'``. Defaults to ``'normal'``. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, + gain: float = 1, + distribution: str = 'normal', + **kwargs): + super().__init__(**kwargs) + self.gain = gain + self.distribution = distribution + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + xavier_init(m, self.gain, self.bias, self.distribution) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + xavier_init(m, self.gain, self.bias, self.distribution) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self) -> str: + info = f'{self.__class__.__name__}: gain={self.gain}, ' \ + f'distribution={self.distribution}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Normal') +class NormalInit(BaseInit): + r"""Initialize module parameters with the values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + + Args: + mean (int | float):the mean of the normal distribution. Defaults to 0. 
+ std (int | float): the standard deviation of the normal distribution. + Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + + """ + + def __init__(self, mean: float = 0, std: float = 1, **kwargs): + super().__init__(**kwargs) + self.mean = mean + self.std = std + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + normal_init(m, self.mean, self.std, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + normal_init(m, self.mean, self.std, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self) -> str: + info = f'{self.__class__.__name__}: mean={self.mean},' \ + f' std={self.std}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='TruncNormal') +class TruncNormalInit(BaseInit): + r"""Initialize module parameters with the values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values + outside :math:`[a, b]`. + + Args: + mean (float): the mean of the normal distribution. Defaults to 0. + std (float): the standard deviation of the normal distribution. + Defaults to 1. + a (float): The minimum cutoff value. + b ( float): The maximum cutoff value. + bias (float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + + """ + + def __init__(self, + mean: float = 0, + std: float = 1, + a: float = -2, + b: float = 2, + **kwargs) -> None: + super().__init__(**kwargs) + self.mean = mean + self.std = std + self.a = a + self.b = b + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + trunc_normal_init(m, self.mean, self.std, self.a, self.b, + self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + trunc_normal_init(m, self.mean, self.std, self.a, self.b, + self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \ + f' mean={self.mean}, std={self.std}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Uniform') +class UniformInit(BaseInit): + r"""Initialize module parameters with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + + Args: + a (int | float): the lower bound of the uniform distribution. + Defaults to 0. + b (int | float): the upper bound of the uniform distribution. + Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. 
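+
+    Example (illustrative; ``model`` stands for any ``nn.Module`` of yours):
+        >>> init = UniformInit(a=0., b=1., layer='Conv2d', bias=0.1)
+        >>> init(model)  # Conv2d weights drawn from U(0, 1), biases set to 0.1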
+ """ + + def __init__(self, a: float = 0., b: float = 1., **kwargs): + super().__init__(**kwargs) + self.a = a + self.b = b + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + uniform_init(m, self.a, self.b, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + uniform_init(m, self.a, self.b, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self) -> str: + info = f'{self.__class__.__name__}: a={self.a},' \ + f' b={self.b}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Kaiming') +class KaimingInit(BaseInit): + r"""Initialize module parameters with the values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification - He, K. et al. (2015). + `_ + + Args: + a (int | float): the negative slope of the rectifier used after this + layer (only used with ``'leaky_relu'``). Defaults to 0. + mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing + ``'fan_in'`` preserves the magnitude of the variance of the weights + in the forward pass. Choosing ``'fan_out'`` preserves the + magnitudes in the backwards pass. Defaults to ``'fan_out'``. + nonlinearity (str): the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` . + Defaults to 'relu'. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + distribution (str): distribution either be ``'normal'`` or + ``'uniform'``. Defaults to ``'normal'``. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, + a: float = 0, + mode: str = 'fan_out', + nonlinearity: str = 'relu', + distribution: str = 'normal', + **kwargs): + super().__init__(**kwargs) + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + self.distribution = distribution + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + kaiming_init(m, self.a, self.mode, self.nonlinearity, + self.bias, self.distribution) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + kaiming_init(m, self.a, self.mode, self.nonlinearity, + self.bias, self.distribution) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self) -> str: + info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \ + f'nonlinearity={self.nonlinearity}, ' \ + f'distribution ={self.distribution}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Caffe2Xavier') +class Caffe2XavierInit(KaimingInit): + # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch + # Acknowledgment to FAIR's internal code + def __init__(self, **kwargs): + super().__init__( + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + distribution='uniform', + **kwargs) + + def __call__(self, module: nn.Module) -> None: + super().__call__(module) + + +@INITIALIZERS.register_module(name='Pretrained') +class PretrainedInit: + """Initialize module by loading a pretrained model. 
+ + Args: + checkpoint (str): the checkpoint file of the pretrained model should + be load. + prefix (str, optional): the prefix of a sub-module in the pretrained + model. it is for loading a part of the pretrained model to + initialize. For example, if we would like to only load the + backbone of a detector model, we can set ``prefix='backbone.'``. + Defaults to None. + map_location (str): map tensors into proper locations. + """ + + def __init__(self, + checkpoint: str, + prefix: Optional[str] = None, + map_location: Optional[str] = None): + self.checkpoint = checkpoint + self.prefix = prefix + self.map_location = map_location + + def __call__(self, module: nn.Module) -> None: + from dbnet_cv.runner import (_load_checkpoint_with_prefix, load_checkpoint, + load_state_dict) + logger = get_logger('dbnet_cv') + if self.prefix is None: + print_log(f'load model from: {self.checkpoint}', logger=logger) + load_checkpoint( + module, + self.checkpoint, + map_location=self.map_location, + strict=False, + logger=logger) + else: + print_log( + f'load {self.prefix} in model from: {self.checkpoint}', + logger=logger) + state_dict = _load_checkpoint_with_prefix( + self.prefix, self.checkpoint, map_location=self.map_location) + load_state_dict(module, state_dict, strict=False, logger=logger) + + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self) -> str: + info = f'{self.__class__.__name__}: load from {self.checkpoint}' + return info + + +def _initialize(module: nn.Module, + cfg: Dict, + wholemodule: bool = False) -> None: + func = build_from_cfg(cfg, INITIALIZERS) + # wholemodule flag is for override mode, there is no layer key in override + # and initializer will give init values for the whole module with the name + # in override. + func.wholemodule = wholemodule + func(module) + + +def _initialize_override(module: nn.Module, override: Union[Dict, List], + cfg: Dict) -> None: + if not isinstance(override, (dict, list)): + raise TypeError(f'override must be a dict or a list of dict, \ + but got {type(override)}') + + override = [override] if isinstance(override, dict) else override + + for override_ in override: + + cp_override = copy.deepcopy(override_) + name = cp_override.pop('name', None) + if name is None: + raise ValueError('`override` must contain the key "name",' + f'but got {cp_override}') + # if override only has name key, it means use args in init_cfg + if not cp_override: + cp_override.update(cfg) + # if override has name key and other args except type key, it will + # raise error + elif 'type' not in cp_override.keys(): + raise ValueError( + f'`override` need "type" key, but got {cp_override}') + + if hasattr(module, name): + _initialize(getattr(module, name), cp_override, wholemodule=True) + else: + raise RuntimeError(f'module did not have attribute {name}, ' + f'but init_cfg is {cp_override}.') + + +def initialize(module: nn.Module, init_cfg: Union[Dict, List[dict]]) -> None: + r"""Initialize a module. + + Args: + module (``torch.nn.Module``): the module will be initialized. + init_cfg (dict | list[dict]): initialization configuration dict to + define initializer. OpenMMLab has implemented 6 initializers + including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``, + ``Kaiming``, and ``Pretrained``. 
+ + Example: + >>> module = nn.Linear(2, 3, bias=True) + >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2) + >>> initialize(module, init_cfg) + + >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2)) + >>> # define key ``'layer'`` for initializing layer with different + >>> # configuration + >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1), + dict(type='Constant', layer='Linear', val=2)] + >>> initialize(module, init_cfg) + + >>> # define key``'override'`` to initialize some specific part in + >>> # module + >>> class FooNet(nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.feat = nn.Conv2d(3, 16, 3) + >>> self.reg = nn.Conv2d(16, 10, 3) + >>> self.cls = nn.Conv2d(16, 5, 3) + >>> model = FooNet() + >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d', + >>> override=dict(type='Constant', name='reg', val=3, bias=4)) + >>> initialize(model, init_cfg) + + >>> model = ResNet(depth=50) + >>> # Initialize weights with the pretrained model. + >>> init_cfg = dict(type='Pretrained', + checkpoint='torchvision://resnet50') + >>> initialize(model, init_cfg) + + >>> # Initialize weights of a sub-module with the specific part of + >>> # a pretrained model by using "prefix". + >>> url = 'http://download.openmmlab.com/dbnet_detection/v2.0/retinanet/'\ + >>> 'retinanet_r50_fpn_1x_coco/'\ + >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' + >>> init_cfg = dict(type='Pretrained', + checkpoint=url, prefix='backbone.') + """ + if not isinstance(init_cfg, (dict, list)): + raise TypeError(f'init_cfg must be a dict or a list of dict, \ + but got {type(init_cfg)}') + + if isinstance(init_cfg, dict): + init_cfg = [init_cfg] + + for cfg in init_cfg: + # should deeply copy the original config because cfg may be used by + # other modules, e.g., one init_cfg shared by multiple bottleneck + # blocks, the expected cfg will be changed after pop and will change + # the initialization behavior of other modules + cp_cfg = copy.deepcopy(cfg) + override = cp_cfg.pop('override', None) + _initialize(module, cp_cfg) + + if override is not None: + cp_cfg.pop('layer', None) + _initialize_override(module, override, cp_cfg) + else: + # All attributes in module have same initialization. + pass + + +def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, + b: float) -> Tensor: + # Method based on + # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + # Modified from + # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [lower, upper], then translate + # to [2lower-1, 2upper-1]. 
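+        # Added note: for the standard normal CDF Phi used in norm_cdf above,
+        # erfinv(2*u - 1) equals Phi^{-1}(u) / sqrt(2), so the erfinv_/mul_/
+        # add_ calls below turn the uniform samples into N(mean, std) values
+        # truncated to [a, b].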
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor: Tensor, + mean: float = 0., + std: float = 1., + a: float = -2., + b: float = 2.) -> Tensor: + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Modified from + https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py + + Args: + tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`. + mean (float): the mean of the normal distribution. + std (float): the standard deviation of the normal distribution. + a (float): the minimum cutoff value. + b (float): the maximum cutoff value. + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/engine/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/__init__.py new file mode 100755 index 00000000..3193b7f6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test, + single_gpu_test) + +__all__ = [ + 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', + 'single_gpu_test' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py new file mode 100755 index 00000000..99b68971 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/engine/test.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import pickle +import shutil +import tempfile +import time +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.utils.data import DataLoader + +import dbnet_cv +from dbnet_cv.runner import get_dist_info + + +def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list: + """Test model with a single gpu. + + This method tests model with a single gpu and displays test progress bar. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = dbnet_cv.ProgressBar(len(dataset)) + for data in data_loader: + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + # Assume result has the same length of batch_size + # refer to https://github.com/open-mmlab/dbnet_cv/issues/985 + batch_size = len(result) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model: nn.Module, + data_loader: DataLoader, + tmpdir: Optional[str] = None, + gpu_collect: bool = False) -> Optional[list]: + """Test model with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting + ``gpu_collect=True``, it encodes results to gpu tensors and use gpu + communication for results collection. 
On cpu mode it saves the results on + different gpus to ``tmpdir`` and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = dbnet_cv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + if rank == 0: + batch_size = len(result) + batch_size_all = batch_size * world_size + if batch_size_all + prog_bar.completed > len(dataset): + batch_size_all = len(dataset) - prog_bar.completed + for _ in range(batch_size_all): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + result_from_ranks = collect_results_gpu(results, len(dataset)) + else: + result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir) + return result_from_ranks + + +def collect_results_cpu(result_part: list, + size: int, + tmpdir: Optional[str] = None) -> Optional[list]: + """Collect results under cpu mode. + + On cpu mode, this function will save the results on different gpus to + ``tmpdir`` and collect them by the rank 0 worker. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + tmpdir (str | None): temporal directory for collected results to + store. If set to None, it will create a random temporal directory + for it. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + dbnet_cv.mkdir_or_exist('.dist_test') + tmpdir = tempfile.mkdtemp(dir='.dist_test') + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + dbnet_cv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + part_file = osp.join(tmpdir, f'part_{rank}.pkl') # type: ignore + dbnet_cv.dump(result_part, part_file) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore + part_result = dbnet_cv.load(part_file) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) # type: ignore + return ordered_results + + +def collect_results_gpu(result_part: list, size: int) -> Optional[list]: + """Collect results under gpu mode. 
+ + On gpu mode, this function will encode results to gpu tensors and use gpu + communication for results collection. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results + else: + return None diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/__init__.py new file mode 100755 index 00000000..2051b85f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .file_client import BaseStorageBackend, FileClient +from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler +from .io import dump, load, register_handler +from .parse import dict_from_file, list_from_file + +__all__ = [ + 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler', + 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', + 'list_from_file', 'dict_from_file' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/file_client.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/file_client.py new file mode 100755 index 00000000..1f99eb54 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/file_client.py @@ -0,0 +1,1173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import os +import os.path as osp +import re +import tempfile +import warnings +from abc import ABCMeta, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator, Iterator, Optional, Tuple, Union +from urllib.request import urlopen + +import dbnet_cv +from dbnet_cv.utils.misc import has_method +from dbnet_cv.utils.path import is_filepath + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. 
+ """ + + # a flag to indicate whether the backend can create a symlink for a file + _allow_symlink = False + + @property + def name(self): + return self.__class__.__name__ + + @property + def allow_symlink(self): + return self._allow_symlink + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class CephBackend(BaseStorageBackend): + """Ceph storage backend (for internal use). + + Args: + path_mapping (dict|None): path mapping dict from local path to Petrel + path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath`` + will be replaced by ``dst``. Default: None. + + .. warning:: + :class:`dbnet_cv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`dbnet_cv.fileio.file_client.PetrelBackend` instead. + """ + + def __init__(self, path_mapping=None): + try: + import ceph + except ImportError: + raise ImportError('Please install ceph to enable CephBackend.') + + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead', + DeprecationWarning) + self._client = ceph.S3Client() + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def get(self, filepath): + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + value = self._client.Get(filepath) + value_buf = memoryview(value) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class PetrelBackend(BaseStorageBackend): + """Petrel storage backend (for internal use). + + PetrelBackend supports reading and writing data to multiple clusters. + If the file path contains the cluster name, PetrelBackend will read data + from specified cluster or write data to it. Otherwise, PetrelBackend will + access the default cluster. + + Args: + path_mapping (dict, optional): Path mapping dict from local path to + Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in + ``filepath`` will be replaced by ``dst``. Default: None. + enable_mc (bool, optional): Whether to enable memcached support. + Default: True. + + Examples: + >>> filepath1 = 's3://path/of/file' + >>> filepath2 = 'cluster-name:s3://path/of/file' + >>> client = PetrelBackend() + >>> client.get(filepath1) # get data from default cluster + >>> client.get(filepath2) # get data from 'cluster-name' cluster + """ + + def __init__(self, + path_mapping: Optional[dict] = None, + enable_mc: bool = True): + try: + from petrel_client import client + except ImportError: + raise ImportError('Please install petrel_client to enable ' + 'PetrelBackend.') + + self._client = client.Client(enable_mc=enable_mc) + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def _map_path(self, filepath: Union[str, Path]) -> str: + """Map ``filepath`` to a string path whose prefix will be replaced by + :attr:`self.path_mapping`. + + Args: + filepath (str): Path to be mapped. + """ + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + return filepath + + def _format_path(self, filepath: str) -> str: + """Convert a ``filepath`` to standard format of petrel oss. + + If the ``filepath`` is concatenated by ``os.path.join``, in a Windows + environment, the ``filepath`` will be the format of + 's3://bucket_name\\image.jpg'. 
By invoking :meth:`_format_path`, the + above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. + + Args: + filepath (str): Path to be formatted. + """ + return re.sub(r'\\+', '/', filepath) + + def get(self, filepath: Union[str, Path]) -> memoryview: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + memoryview: A memory view of expected bytes object to avoid + copying. The memoryview object can be converted to bytes by + ``value_buf.tobytes()``. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + value = self._client.Get(filepath) + value_buf = memoryview(value) + return value_buf + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return str(self.get(filepath), encoding=encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Save data to a given ``filepath``. + + Args: + obj (bytes): Data to be saved. + filepath (str or Path): Path to write data. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.put(filepath, obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Save data to a given ``filepath``. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to encode the ``obj``. + Default: 'utf-8'. + """ + self.put(bytes(obj, encoding=encoding), filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + if not has_method(self._client, 'delete'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `delete` method, please use a higher version or dev' + ' branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.delete(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + if not (has_method(self._client, 'contains') + and has_method(self._client, 'isdir')): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `contains` and `isdir` methods, please use a higher' + 'version or dev branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) or self._client.isdir(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. 
+ """ + if not has_method(self._client, 'isdir'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `isdir` method, please use a higher version or dev' + ' branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + if not has_method(self._client, 'contains'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `contains` method, please use a higher version or ' + 'dev branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result after concatenation. + """ + filepath = self._format_path(self._map_path(filepath)) + if filepath.endswith('/'): + filepath = filepath[:-1] + formatted_paths = [filepath] + for path in filepaths: + formatted_paths.append(self._format_path(self._map_path(path))) + return '/'.join(formatted_paths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, + Path]) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath`` and return a temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str | Path): Download a file from ``filepath``. + + Examples: + >>> client = PetrelBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('s3://path/of/your/file') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one temporary path. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + assert self.isfile(filepath) + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + Petrel has no concept of directories but it simulates the directory + hierarchy in the filesystem through public prefixes. In addition, + if the returned path ends with '/', it means the path is a public + prefix which is a logical directory. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + In addition, the returned path of directory will not contains the + suffix '/' which is consistent with other backends. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. 
+ recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + if not has_method(self._client, 'list'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `list` method, please use a higher version or dev' + ' branch instead.') + + dir_path = self._map_path(dir_path) + dir_path = self._format_path(dir_path) + if list_dir and suffix is not None: + raise TypeError( + '`list_dir` should be False when `suffix` is not None') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + # Petrel's simulated directory hierarchy assumes that directory paths + # should end with `/` + if not dir_path.endswith('/'): + dir_path += '/' + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for path in self._client.list(dir_path): + # the `self.isdir` is not used here to determine whether path + # is a directory, because `self.isdir` relies on + # `self._client.list` + if path.endswith('/'): # a directory path + next_dir_path = self.join_path(dir_path, path) + if list_dir: + # get the relative path and exclude the last + # character '/' + rel_dir = next_dir_path[len(root):-1] + yield rel_dir + if recursive: + yield from _list_dir_or_file(next_dir_path, list_dir, + list_file, suffix, + recursive) + else: # a file path + absolute_path = self.join_path(dir_path, path) + rel_path = absolute_path[len(root):] + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_path (str): Lmdb database path. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_path (str): Lmdb database path. 
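+
+    Example (illustrative; the database path and key are placeholders):
+        >>> backend = LmdbBackend('/path/to/data.lmdb')
+        >>> value_buf = backend.get('image-000000001')  # filepath is the lmdb key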
+ """ + + def __init__(self, + db_path, + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb # NOQA + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + self.db_path = str(db_path) + self.readonly = readonly + self.lock = lock + self.readahead = readahead + self.kwargs = kwargs + self._client = None + + def get(self, filepath): + """Get values according to the filepath. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + """ + if self._client is None: + self._client = self._get_client() + + with self._client.begin(write=False) as txn: + value_buf = txn.get(str(filepath).encode('utf-8')) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + def _get_client(self): + import lmdb + + return lmdb.open( + self.db_path, + readonly=self.readonly, + lock=self.lock, + readahead=self.readahead, + **self.kwargs) + + def __del__(self): + self._client.close() + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + with open(filepath, encoding=encoding) as f: + value_buf = f.read() + return value_buf + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + dbnet_cv.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'wb') as f: + f.write(obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + dbnet_cv.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + os.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. 
+        """
+        return osp.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+            otherwise.
+        """
+        return osp.isfile(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Join one or more filepath components intelligently. The return value
+        is the concatenation of filepath and any members of *filepaths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+
+        Returns:
+            str: The result of concatenation.
+        """
+        return osp.join(filepath, *filepaths)
+
+    @contextmanager
+    def get_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
+        """Only for unified API and do nothing."""
+        yield filepath
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the interested directories or files in
+        arbitrary order.
+
+        Note:
+            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional): File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        if list_dir and suffix is not None:
+            raise TypeError('`suffix` should be None when `list_dir` is True')
+
+        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+            raise TypeError('`suffix` must be a string or tuple of strings')
+
+        root = dir_path
+
+        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                              recursive):
+            for entry in os.scandir(dir_path):
+                if not entry.name.startswith('.') and entry.is_file():
+                    rel_path = osp.relpath(entry.path, root)
+                    if (suffix is None
+                            or rel_path.endswith(suffix)) and list_file:
+                        yield rel_path
+                elif osp.isdir(entry.path):
+                    if list_dir:
+                        rel_dir = osp.relpath(entry.path, root)
+                        yield rel_dir
+                    if recursive:
+                        yield from _list_dir_or_file(entry.path, list_dir,
+                                                     list_file, suffix,
+                                                     recursive)
+
+        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                                 recursive)
+
+
+class HTTPBackend(BaseStorageBackend):
+    """HTTP and HTTPS storage backend."""
+
+    def get(self, filepath):
+        value_buf = urlopen(filepath).read()
+        return value_buf
+
+    def get_text(self, filepath, encoding='utf-8'):
+        value_buf = urlopen(filepath).read()
+        return value_buf.decode(encoding)
+
+    @contextmanager
+    def get_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called with a ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> client = HTTPBackend()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with client.get_local_path('http://path/of/your/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.get(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+
+class FileClient:
+    """A general file client to access files in different backends.
+
+    The client loads a file or text in a specified backend from its path
+    and returns it as a binary or text file. There are two ways to choose a
+    backend: the name of the backend and the prefix of the path. Although
+    both of them can be used to choose a storage backend, ``backend`` has a
+    higher priority: if both are set, the storage backend will be chosen by
+    the ``backend`` argument. If both are `None`, the disk backend will be
+    chosen. Note that it can also register other backend accessors with a
+    given name, prefixes, and backend class. In addition, we use the
+    singleton pattern to avoid repeated object creation. If the arguments
+    are the same, the same object will be returned.
+
+    Args:
+        backend (str, optional): The storage backend type. Options are "disk",
+            "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+        prefix (str, optional): The prefix of the registered storage backend.
+            Options are "s3", "http", "https". Default: None.
+
+    Examples:
+        >>> # only set backend
+        >>> file_client = FileClient(backend='petrel')
+        >>> # only set prefix
+        >>> file_client = FileClient(prefix='s3')
+        >>> # set both backend and prefix but use backend to choose client
+        >>> file_client = FileClient(backend='petrel', prefix='s3')
+        >>> # if the arguments are the same, the same object is returned
+        >>> file_client1 = FileClient(backend='petrel')
+        >>> file_client1 is file_client
+        True
+
+    Attributes:
+        client (:obj:`BaseStorageBackend`): The backend object.
+    """
+
+    _backends = {
+        'disk': HardDiskBackend,
+        'ceph': CephBackend,
+        'memcached': MemcachedBackend,
+        'lmdb': LmdbBackend,
+        'petrel': PetrelBackend,
+        'http': HTTPBackend,
+    }
+
+    _prefix_to_backends = {
+        's3': PetrelBackend,
+        'http': HTTPBackend,
+        'https': HTTPBackend,
+    }
+
+    _instances: dict = {}
+
+    client: Any
+
+    def __new__(cls, backend=None, prefix=None, **kwargs):
+        if backend is None and prefix is None:
+            backend = 'disk'
+        if backend is not None and backend not in cls._backends:
+            raise ValueError(
+                f'Backend {backend} is not supported. Currently supported ones'
+                f' are {list(cls._backends.keys())}')
+        if prefix is not None and prefix not in cls._prefix_to_backends:
+            raise ValueError(
+                f'prefix {prefix} is not supported. Currently supported ones '
+                f'are {list(cls._prefix_to_backends.keys())}')
+
+        # concatenate the arguments to a unique key for determining whether
+        # objects with the same arguments were created
+        arg_key = f'{backend}:{prefix}'
+        for key, value in kwargs.items():
+            arg_key += f':{key}:{value}'
+
+        if arg_key in cls._instances:
+            _instance = cls._instances[arg_key]
+        else:
+            # create a new object and put it into _instances
+            _instance = super().__new__(cls)
+            if backend is not None:
+                _instance.client = cls._backends[backend](**kwargs)
+            else:
+                _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+            cls._instances[arg_key] = _instance
+
+        return _instance
+
+    @property
+    def name(self):
+        return self.client.name
+
+    @property
+    def allow_symlink(self):
+        return self.client.allow_symlink
+
+    @staticmethod
+    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+        """Parse the prefix of a uri.
+
+        Args:
+            uri (str | Path): Uri to be parsed that contains the file prefix.
+ + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' else + ``None``. + """ + assert is_filepath(uri) + uri = str(uri) + if '://' not in uri: + return None + else: + prefix, _ = uri.split('://') + # In the case of PetrelBackend, the prefix may contains the cluster + # name like clusterName:s3 + if ':' in prefix: + _, prefix = prefix.split(':') + return prefix + + @classmethod + def infer_client(cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None) -> 'FileClient': + """Infer a suitable file client based on the URI and arguments. + + Args: + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Default: None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Default: None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 'petrel'} + >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefix=file_prefix) + else: + return cls(**file_client_args) + + @classmethod + def _register_backend(cls, name, backend, force=False, prefixes=None): + if not isinstance(name, str): + raise TypeError('the backend name should be a string, ' + f'but got {type(name)}') + if not inspect.isclass(backend): + raise TypeError( + f'backend should be a class but got {type(backend)}') + if not issubclass(backend, BaseStorageBackend): + raise TypeError( + f'backend {backend} is not a subclass of BaseStorageBackend') + if not force and name in cls._backends: + raise KeyError( + f'{name} is already registered as a storage backend, ' + 'add "force=True" if you want to override it') + + if name in cls._backends and force: + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, cls._backends[name]): + cls._instances.pop(arg_key) + cls._backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + overridden_backend = cls._prefix_to_backends[prefix] + if isinstance(overridden_backend, list): + overridden_backend = tuple(overridden_backend) + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, overridden_backend): + cls._instances.pop(arg_key) + cls._prefix_to_backends[prefix] = backend + else: + raise KeyError( + f'{prefix} is already registered as a storage backend,' + ' add "force=True" if you want to override it') + + @classmethod + def register_backend(cls, name, backend=None, force=False, prefixes=None): + """Register a backend to FileClient. + + This method can be used as a normal class method or a decorator. + + .. code-block:: python + + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + FileClient.register_backend('new', NewBackend) + + or + + .. 
code-block:: python + + @FileClient.register_backend('new') + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool, optional): Whether to override the backend if the name + has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Default: None. + `New in version 1.3.15.` + """ + if backend is not None: + cls._register_backend( + name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + cls._register_backend( + name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given ``filepath`` with 'rb' mode. + + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. + """ + return self.client.get(filepath) + + def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return self.client.get_text(filepath, encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + self.client.put(obj, filepath) + + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ + self.client.put_text(obj, filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ + self.client.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return self.client.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. 
+ """ + return self.client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return self.client.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return self.client.join_path(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, + Path]) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. + + .. warning:: + ``get_local_path`` is an experimental interface that may change in + the future. + + Args: + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one path. + """ + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, + suffix, recursive) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/__init__.py new file mode 100755 index 00000000..aa24d919 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseFileHandler +from .json_handler import JsonHandler +from .pickle_handler import PickleHandler +from .yaml_handler import YamlHandler + +__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler'] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/base.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/base.py new file mode 100755 index 00000000..0c9cc15b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/base.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import ABCMeta, abstractmethod + + +class BaseFileHandler(metaclass=ABCMeta): + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + str_like = True + + @abstractmethod + def load_from_fileobj(self, file, **kwargs): + pass + + @abstractmethod + def dump_to_fileobj(self, obj, file, **kwargs): + pass + + @abstractmethod + def dump_to_str(self, obj, **kwargs): + pass + + def load_from_path(self, filepath: str, mode: str = 'r', **kwargs): + with open(filepath, mode) as f: + return self.load_from_fileobj(f, **kwargs) + + def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs): + with open(filepath, mode) as f: + self.dump_to_fileobj(obj, f, **kwargs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/json_handler.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/json_handler.py new file mode 100755 index 00000000..18d4f15f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/json_handler.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + +import numpy as np + +from .base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f'{type(obj)} is unsupported for json dump') + + +class JsonHandler(BaseFileHandler): + + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('default', set_default) + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('default', set_default) + return json.dumps(obj, **kwargs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/pickle_handler.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/pickle_handler.py new file mode 100755 index 00000000..073856fd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/pickle_handler.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pickle + +from .base import BaseFileHandler + + +class PickleHandler(BaseFileHandler): + + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return pickle.load(file, **kwargs) + + def load_from_path(self, filepath, **kwargs): + return super().load_from_path(filepath, mode='rb', **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('protocol', 2) + return pickle.dumps(obj, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('protocol', 2) + pickle.dump(obj, file, **kwargs) + + def dump_to_path(self, obj, filepath, **kwargs): + super().dump_to_path(obj, filepath, mode='wb', **kwargs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/yaml_handler.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/yaml_handler.py new file mode 100755 index 00000000..1c1b0779 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/handlers/yaml_handler.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
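+# Usage sketch (illustrative only): YAML files are handled transparently by
+# file extension through the generic ``load``/``dump`` entry points defined
+# in dbnet_cv/fileio/io.py. The 'config.yml' path is hypothetical, and the
+# sketch assumes ``load``/``dump`` are re-exported under ``dbnet_cv.fileio``
+# as in upstream mmcv:
+#
+#     from dbnet_cv import fileio
+#
+#     cfg = fileio.load('config.yml')           # dispatched to YamlHandler
+#     fileio.dump(cfg, 'config_backup.yaml')    # CDumper is used if available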
+import yaml + +try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader, Dumper # type: ignore + +from .base import BaseFileHandler # isort:skip + + +class YamlHandler(BaseFileHandler): + + def load_from_fileobj(self, file, **kwargs): + kwargs.setdefault('Loader', Loader) + return yaml.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('Dumper', Dumper) + yaml.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('Dumper', Dumper) + return yaml.dump(obj, **kwargs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py new file mode 100755 index 00000000..c5cf5249 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/io.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from io import BytesIO, StringIO +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, TextIO, Union + +from ..utils import is_list_of +from .file_client import FileClient +from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler + +FileLikeObject = Union[TextIO, StringIO, BytesIO] + +file_handlers = { + 'json': JsonHandler(), + 'yaml': YamlHandler(), + 'yml': YamlHandler(), + 'pickle': PickleHandler(), + 'pkl': PickleHandler() +} + + +def load(file: Union[str, Path, FileLikeObject], + file_format: Optional[str] = None, + file_client_args: Optional[Dict] = None, + **kwargs): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + Note: + In v1.3.16 and later, ``load`` supports loading data from serialized + files those can be storaged in different backends. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml" and + "pickle/pkl". + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in petrel + + Returns: + The content from the file. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None and isinstance(file, str): + file_format = file.split('.')[-1] + if file_format not in file_handlers: + raise TypeError(f'Unsupported format: {file_format}') + + handler = file_handlers[file_format] + f: FileLikeObject + if isinstance(file, str): + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO(file_client.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(file_client.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + elif hasattr(file, 'read'): + obj = handler.load_from_fileobj(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def dump(obj: Any, + file: Optional[Union[str, Path, FileLikeObject]] = None, + file_format: Optional[str] = None, + file_client_args: Optional[Dict] = None, + **kwargs): + """Dump data to json/yaml/pickle strings or files. 
+ + This method provides a unified api for dumping data as strings or to files, + and also supports custom arguments for each file format. + + Note: + In v1.3.16 and later, ``dump`` supports dumping data as strings or to + files which is saved to different backends. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dumped to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel + + Returns: + bool: True for success, False otherwise. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if isinstance(file, str): + file_format = file.split('.')[-1] + elif file is None: + raise ValueError( + 'file_format must be specified since file is None') + if file_format not in file_handlers: + raise TypeError(f'Unsupported format: {file_format}') + f: FileLikeObject + handler = file_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif isinstance(file, str): + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put(f.getvalue(), file) + elif hasattr(file, 'write'): + handler.dump_to_fileobj(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') + + +def _register_handler(handler: BaseFileHandler, + file_formats: Union[str, List[str]]) -> None: + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. + file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError( + f'handler must be a child of BaseFileHandler, not {type(handler)}') + if isinstance(file_formats, str): + file_formats = [file_formats] + if not is_list_of(file_formats, str): + raise TypeError('file_formats must be a str or a list of str') + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats: Union[str, list], **kwargs) -> Callable: + + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py new file mode 100755 index 00000000..298e1f99 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/fileio/parse.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from io import StringIO +from pathlib import Path +from typing import Dict, List, Optional, Union + +from .file_client import FileClient + + +def list_from_file(filename: Union[str, Path], + prefix: str = '', + offset: int = 0, + max_num: int = 0, + encoding: str = 'utf-8', + file_client_args: Optional[Dict] = None) -> List: + """Load a text file and parse the content as a list of strings. 
+ + Note: + In v1.3.16 and later, ``list_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a list for strings. + + Args: + filename (str): Filename. + prefix (str): The prefix to be inserted to the beginning of each item. + offset (int): The offset of lines. + max_num (int): The maximum number of lines to be read, + zeros and negatives mean no limitation. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> list_from_file('/path/of/your/file') # disk + ['hello', 'world'] + >>> list_from_file('s3://path/of/your/file') # ceph or petrel + ['hello', 'world'] + + Returns: + list[str]: A list of strings. + """ + cnt = 0 + item_list = [] + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: + for _ in range(offset): + f.readline() + for line in f: + if 0 < max_num <= cnt: + break + item_list.append(prefix + line.rstrip('\n\r')) + cnt += 1 + return item_list + + +def dict_from_file(filename: Union[str, Path], + key_type: type = str, + encoding: str = 'utf-8', + file_client_args: Optional[Dict] = None) -> Dict: + """Load a text file and parse the content as a dict. + + Each line of the text file will be two or more columns split by + whitespaces or tabs. The first column will be parsed as dict keys, and + the following columns will be parsed as dict values. + + Note: + In v1.3.16 and later, ``dict_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a dict. + + Args: + filename(str): Filename. + key_type(type): Type of the dict keys. str is user by default and + type conversion will be performed if specified. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dict_from_file('/path/of/your/file') # disk + {'key1': 'value1', 'key2': 'value2'} + >>> dict_from_file('s3://path/of/your/file') # ceph or petrel + {'key1': 'value1', 'key2': 'value2'} + + Returns: + dict: The parsed contents. + """ + mapping = {} + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: + for line in f: + items = line.rstrip('\n').split() + assert len(items) >= 2 + key = key_type(items[0]) + val = items[1:] if len(items) > 2 else items[1] + mapping[key] = val + return mapping diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/__init__.py new file mode 100755 index 00000000..7bbcdefd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
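+# Usage sketch (illustrative only; the image paths are hypothetical). The
+# functions shown are the ones imported below from ``.io``, ``.geometric``
+# and ``.colorspace``:
+#
+#     from dbnet_cv import image
+#
+#     img = image.imread('/path/to/img.jpg')    # BGR ndarray by default
+#     half = image.imrescale(img, 0.5)          # resize, keep aspect ratio
+#     image.imwrite(half, '/path/to/out.jpg')
+#     rgb = image.bgr2rgb(half)                 # convert for RGB consumers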
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr, + gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert, + rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb) +from .geometric import (cutout, imcrop, imflip, imflip_, impad, + impad_to_multiple, imrescale, imresize, imresize_like, + imresize_to_multiple, imrotate, imshear, imtranslate, + rescale_size) +from .io import imfrombytes, imread, imwrite, supported_backends, use_backend +from .misc import tensor2imgs +# from .photometric import (adjust_brightness, adjust_color, adjust_contrast, +# adjust_hue, adjust_lighting, adjust_sharpness, +# auto_contrast, clahe, imdenormalize, imequalize, +# iminvert, imnormalize, imnormalize_, lut_transform, +# posterize, solarize) +from .photometric import (adjust_brightness, adjust_color, adjust_contrast, + adjust_hue, adjust_lighting, adjust_sharpness, + auto_contrast, clahe, imdenormalize, imequalize, + iminvert, imnormalize, imnormalize_, lut_transform, + posterize, solarize) +# __all__ = [ +# 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb', +# 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale', +# 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size', +# 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate', +# 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend', +# 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize', +# 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', +# 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize', +# 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe', +# 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting', +# 'adjust_hue' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/colorspace.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/colorspace.py new file mode 100755 index 00000000..08f99524 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/colorspace.py @@ -0,0 +1,309 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Union + +import cv2 +import numpy as np + + +def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray: + """Convert an image from the src colorspace to dst colorspace. + + Args: + img (ndarray): The input image. + src (str): The source colorspace, e.g., 'rgb', 'hsv'. + dst (str): The destination colorspace, e.g., 'rgb', 'hsv'. + + Returns: + ndarray: The converted image. + """ + code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}') + out_img = cv2.cvtColor(img, code) + return out_img + + +def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray: + """Convert a BGR image to grayscale image. + + Args: + img (ndarray): The input image. + keepdim (bool): If False (by default), then return the grayscale image + with 2 dims, otherwise 3 dims. + + Returns: + ndarray: The converted grayscale image. + """ + out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + if keepdim: + out_img = out_img[..., None] + return out_img + + +def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray: + """Convert a RGB image to grayscale image. + + Args: + img (ndarray): The input image. + keepdim (bool): If False (by default), then return the grayscale image + with 2 dims, otherwise 3 dims. + + Returns: + ndarray: The converted grayscale image. 
+ """ + out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + if keepdim: + out_img = out_img[..., None] + return out_img + + +def gray2bgr(img: np.ndarray) -> np.ndarray: + """Convert a grayscale image to BGR image. + + Args: + img (ndarray): The input image. + + Returns: + ndarray: The converted BGR image. + """ + img = img[..., None] if img.ndim == 2 else img + out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + return out_img + + +def gray2rgb(img: np.ndarray) -> np.ndarray: + """Convert a grayscale image to RGB image. + + Args: + img (ndarray): The input image. + + Returns: + ndarray: The converted RGB image. + """ + img = img[..., None] if img.ndim == 2 else img + out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + return out_img + + +def _convert_input_type_range(img: np.ndarray) -> np.ndarray: + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' + f'but got {img_type}') + return img + + +def _convert_output_type_range( + img: np.ndarray, dst_type: Union[np.uint8, np.float32]) -> np.ndarray: + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' + f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray: + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. 
The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray: + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img: np.ndarray) -> np.ndarray: + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img: np.ndarray) -> np.ndarray: + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. 
+ """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [ + -276.836, 135.576, -222.921 + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def convert_color_factory(src: str, dst: str) -> Callable: + + code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}') + + def convert_color(img: np.ndarray) -> np.ndarray: + out_img = cv2.cvtColor(img, code) + return out_img + + convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()} + image. + + Args: + img (ndarray or str): The input image. + + Returns: + ndarray: The converted {dst.upper()} image. + """ + + return convert_color + + +bgr2rgb = convert_color_factory('bgr', 'rgb') + +rgb2bgr = convert_color_factory('rgb', 'bgr') + +bgr2hsv = convert_color_factory('bgr', 'hsv') + +hsv2bgr = convert_color_factory('hsv', 'bgr') + +bgr2hls = convert_color_factory('bgr', 'hls') + +hls2bgr = convert_color_factory('hls', 'bgr') diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/geometric.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/geometric.py new file mode 100755 index 00000000..38e440b6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/geometric.py @@ -0,0 +1,741 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers + +import cv2 +import numpy as np + +from ..utils import to_2tuple +from .io import imread_backend + +try: + from PIL import Image +except ImportError: + Image = None + + +def _scale_size(size, scale): + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float | tuple(float)): Scaling factor. + + Returns: + tuple[int]: scaled size. + """ + if isinstance(scale, (float, int)): + scale = (scale, scale) + w, h = size + return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) + + +cv2_interp_codes = { + 'nearest': cv2.INTER_NEAREST, + 'bilinear': cv2.INTER_LINEAR, + 'bicubic': cv2.INTER_CUBIC, + 'area': cv2.INTER_AREA, + 'lanczos': cv2.INTER_LANCZOS4 +} + +# Pillow >=v9.1.0 use a slightly different naming scheme for filters. +# Set pillow_interp_codes according to the naming scheme used. +if Image is not None: + if hasattr(Image, 'Resampling'): + pillow_interp_codes = { + 'nearest': Image.Resampling.NEAREST, + 'bilinear': Image.Resampling.BILINEAR, + 'bicubic': Image.Resampling.BICUBIC, + 'box': Image.Resampling.BOX, + 'lanczos': Image.Resampling.LANCZOS, + 'hamming': Image.Resampling.HAMMING + } + else: + pillow_interp_codes = { + 'nearest': Image.NEAREST, + 'bilinear': Image.BILINEAR, + 'bicubic': Image.BICUBIC, + 'box': Image.BOX, + 'lanczos': Image.LANCZOS, + 'hamming': Image.HAMMING + } + + +def imresize(img, + size, + return_scale=False, + interpolation='bilinear', + out=None, + backend=None): + """Resize image to a given size. + + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``dbnet_cv.use_backend()`` will be used. Default: None. 
+ + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if backend is None: + backend = imread_backend + if backend not in ['cv2', 'pillow']: + raise ValueError(f'backend: {backend} is not supported for resize.' + f"Supported backends are 'cv2', 'pillow'") + + if backend == 'pillow': + assert img.dtype == np.uint8, 'Pillow backend only support uint8 type' + pil_image = Image.fromarray(img) + pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) + resized_img = np.array(pil_image) + else: + resized_img = cv2.resize( + img, size, dst=out, interpolation=cv2_interp_codes[interpolation]) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +def imresize_to_multiple(img, + divisor, + size=None, + scale_factor=None, + keep_ratio=False, + return_scale=False, + interpolation='bilinear', + out=None, + backend=None): + """Resize image according to a given size or scale factor and then rounds + up the the resized or rescaled image size to the nearest value that can be + divided by the divisor. + + Args: + img (ndarray): The input image. + divisor (int | tuple): Resized image size will be a multiple of + divisor. If divisor is a tuple, divisor should be + (w_divisor, h_divisor). + size (None | int | tuple[int]): Target size (w, h). Default: None. + scale_factor (None | float | tuple[float]): Multiplier for spatial + size. Should match input size if it is a tuple and the 2D style is + (w_scale_factor, h_scale_factor). Default: None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Default: False. + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``dbnet_cv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if size is not None and scale_factor is not None: + raise ValueError('only one of size or scale_factor should be defined') + elif size is None and scale_factor is None: + raise ValueError('one of size or scale_factor should be defined') + elif size is not None: + size = to_2tuple(size) + if keep_ratio: + size = rescale_size((w, h), size, return_scale=False) + else: + size = _scale_size((w, h), scale_factor) + + divisor = to_2tuple(divisor) + size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor)) + resized_img, w_scale, h_scale = imresize( + img, + size, + return_scale=True, + interpolation=interpolation, + out=out, + backend=backend) + if return_scale: + return resized_img, w_scale, h_scale + else: + return resized_img + + +def imresize_like(img, + dst_img, + return_scale=False, + interpolation='bilinear', + backend=None): + """Resize image to the same size of a given image. + + Args: + img (ndarray): The input image. + dst_img (ndarray): The target image. + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. 
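+
+    Examples:
+        >>> # minimal sketch with synthetic arrays (illustrative only)
+        >>> import numpy as np
+        >>> src = np.zeros((100, 200, 3), dtype=np.uint8)
+        >>> dst = np.zeros((50, 80, 3), dtype=np.uint8)
+        >>> imresize_like(src, dst).shape
+        (50, 80, 3)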
+ + Returns: + tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = dst_img.shape[:2] + return imresize(img, (w, h), return_scale, interpolation, backend=backend) + + +def rescale_size(old_size, scale, return_scale=False): + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. + """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f'Invalid scale {scale}, must be positive.') + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + else: + raise TypeError( + f'Scale must be a number or tuple of int, but got {type(scale)}') + + new_size = _scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def imrescale(img, + scale, + return_scale=False, + interpolation='bilinear', + backend=None): + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. + """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize( + img, new_size, interpolation=interpolation, backend=backend) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + +def imflip(img, direction='horizontal'): + """Flip an image horizontally or vertically. + + Args: + img (ndarray): Image to be flipped. + direction (str): The flip direction, either "horizontal" or + "vertical" or "diagonal". + + Returns: + ndarray: The flipped image. + """ + assert direction in ['horizontal', 'vertical', 'diagonal'] + if direction == 'horizontal': + return np.flip(img, axis=1) + elif direction == 'vertical': + return np.flip(img, axis=0) + else: + return np.flip(img, axis=(0, 1)) + + +def imflip_(img, direction='horizontal'): + """Inplace flip an image horizontally or vertically. + + Args: + img (ndarray): Image to be flipped. + direction (str): The flip direction, either "horizontal" or + "vertical" or "diagonal". + + Returns: + ndarray: The flipped image (inplace). + """ + assert direction in ['horizontal', 'vertical', 'diagonal'] + if direction == 'horizontal': + return cv2.flip(img, 1, img) + elif direction == 'vertical': + return cv2.flip(img, 0, img) + else: + return cv2.flip(img, -1, img) + + +def imrotate(img, + angle, + center=None, + scale=1.0, + border_value=0, + interpolation='bilinear', + auto_bound=False): + """Rotate an image. + + Args: + img (ndarray): Image to be rotated. 
+ angle (float): Rotation angle in degrees, positive values mean + clockwise rotation. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. + scale (float): Isotropic scale factor. + border_value (int): Border value. + interpolation (str): Same as :func:`resize`. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. + + Returns: + ndarray: The rotated image. + """ + if center is not None and auto_bound: + raise ValueError('`auto_bound` conflicts with `center`') + h, w = img.shape[:2] + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + assert isinstance(center, tuple) + + matrix = cv2.getRotationMatrix2D(center, -angle, scale) + if auto_bound: + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) + new_w = h * sin + w * cos + new_h = h * cos + w * sin + matrix[0, 2] += (new_w - w) * 0.5 + matrix[1, 2] += (new_h - h) * 0.5 + w = int(np.round(new_w)) + h = int(np.round(new_h)) + rotated = cv2.warpAffine( + img, + matrix, (w, h), + flags=cv2_interp_codes[interpolation], + borderValue=border_value) + return rotated + + +def bbox_clip(bboxes, img_shape): + """Clip bboxes to fit the image shape. + + Args: + bboxes (ndarray): Shape (..., 4*k) + img_shape (tuple[int]): (height, width) of the image. + + Returns: + ndarray: Clipped bboxes. + """ + assert bboxes.shape[-1] % 4 == 0 + cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype) + cmin[0::2] = img_shape[1] - 1 + cmin[1::2] = img_shape[0] - 1 + clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0) + return clipped_bboxes + + +def bbox_scaling(bboxes, scale, clip_shape=None): + """Scaling bboxes w.r.t the box center. + + Args: + bboxes (ndarray): Shape(..., 4). + scale (float): Scaling factor. + clip_shape (tuple[int], optional): If specified, bboxes that exceed the + boundary will be clipped according to the given shape (h, w). + + Returns: + ndarray: Scaled bboxes. + """ + if float(scale) == 1.0: + scaled_bboxes = bboxes.copy() + else: + w = bboxes[..., 2] - bboxes[..., 0] + 1 + h = bboxes[..., 3] - bboxes[..., 1] + 1 + dw = (w * (scale - 1)) * 0.5 + dh = (h * (scale - 1)) * 0.5 + scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1) + if clip_shape is not None: + return bbox_clip(scaled_bboxes, clip_shape) + else: + return scaled_bboxes + + +def imcrop(img, bboxes, scale=1.0, pad_fill=None): + """Crop image patches. + + 3 steps: scale the bboxes -> clip bboxes -> crop and pad. + + Args: + img (ndarray): Image to be cropped. + bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes. + scale (float, optional): Scale ratio of bboxes, the default value + 1.0 means no padding. + pad_fill (Number | list[Number]): Value to be filled for padding. + Default: None, which means no padding. + + Returns: + list[ndarray] | ndarray: The cropped image patches. + """ + chn = 1 if img.ndim == 2 else img.shape[2] + if pad_fill is not None: + if isinstance(pad_fill, (int, float)): + pad_fill = [pad_fill for _ in range(chn)] + assert len(pad_fill) == chn + + _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes + scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32) + clipped_bbox = bbox_clip(scaled_bboxes, img.shape) + + patches = [] + for i in range(clipped_bbox.shape[0]): + x1, y1, x2, y2 = tuple(clipped_bbox[i, :]) + if pad_fill is None: + patch = img[y1:y2 + 1, x1:x2 + 1, ...] 
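+        # when `pad_fill` is given, build a patch of the full scaled-bbox
+        # size pre-filled with `pad_fill`, then copy the in-image region into
+        # it, so any part of the box lying outside the image borders stays
+        # padded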
+ else: + _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :]) + if chn == 1: + patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1) + else: + patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn) + patch = np.array( + pad_fill, dtype=img.dtype) * np.ones( + patch_shape, dtype=img.dtype) + x_start = 0 if _x1 >= 0 else -_x1 + y_start = 0 if _y1 >= 0 else -_y1 + w = x2 - x1 + 1 + h = y2 - y1 + 1 + patch[y_start:y_start + h, x_start:x_start + w, + ...] = img[y1:y1 + h, x1:x1 + w, ...] + patches.append(patch) + + if bboxes.ndim == 1: + return patches[0] + else: + return patches + + +def impad(img, + *, + shape=None, + padding=None, + pad_val=0, + padding_mode='constant'): + """Pad the given image to a certain shape or pad on all sides with + specified padding mode and padding value. + + Args: + img (ndarray): Image to be padded. + shape (tuple[int]): Expected padding shape (h, w). Default: None. + padding (int or tuple[int]): Padding on each border. If a single int is + provided this is used to pad all borders. If tuple of length 2 is + provided this is the padding on left/right and top/bottom + respectively. If a tuple of length 4 is provided this is the + padding for the left, top, right and bottom borders respectively. + Default: None. Note that `shape` and `padding` can not be both + set. + pad_val (Number | Sequence[Number]): Values to be filled in padding + areas when padding_mode is 'constant'. Default: 0. + padding_mode (str): Type of padding. Should be: constant, edge, + reflect or symmetric. Default: constant. + - constant: pads with a constant value, this value is specified + with pad_val. + - edge: pads with the last value at the edge of the image. + - reflect: pads with reflection of image without repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with 2 + elements on both sides in reflect mode will result in + [3, 2, 1, 2, 3, 4, 3, 2]. + - symmetric: pads with reflection of image repeating the last value + on the edge. For example, padding [1, 2, 3, 4] with 2 elements on + both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + ndarray: The padded image. + """ + + assert (shape is not None) ^ (padding is not None) + if shape is not None: + width = max(shape[1] - img.shape[1], 0) + height = max(shape[0] - img.shape[0], 0) + padding = (0, 0, width, height) + + # check pad_val + if isinstance(pad_val, tuple): + assert len(pad_val) == img.shape[-1] + elif not isinstance(pad_val, numbers.Number): + raise TypeError('pad_val must be a int or a tuple. ' + f'But received {type(pad_val)}') + + # check padding + if isinstance(padding, tuple) and len(padding) in [2, 4]: + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + elif isinstance(padding, numbers.Number): + padding = (padding, padding, padding, padding) + else: + raise ValueError('Padding must be a int or a 2, or 4 element tuple.' + f'But received {padding}') + + # check padding mode + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + + border_type = { + 'constant': cv2.BORDER_CONSTANT, + 'edge': cv2.BORDER_REPLICATE, + 'reflect': cv2.BORDER_REFLECT_101, + 'symmetric': cv2.BORDER_REFLECT + } + img = cv2.copyMakeBorder( + img, + padding[1], + padding[3], + padding[0], + padding[2], + border_type[padding_mode], + value=pad_val) + + return img + + +def impad_to_multiple(img, divisor, pad_val=0): + """Pad an image to ensure each edge to be multiple to some number. + + Args: + img (ndarray): Image to be padded. 
+ divisor (int): Padded image edges will be multiple to divisor. + pad_val (Number | Sequence[Number]): Same as :func:`impad`. + + Returns: + ndarray: The padded image. + """ + pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor + pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor + return impad(img, shape=(pad_h, pad_w), pad_val=pad_val) + + +def cutout(img, shape, pad_val=0): + """Randomly cut out a rectangle from the original img. + + Args: + img (ndarray): Image to be cutout. + shape (int | tuple[int]): Expected cutout shape (h, w). If given as a + int, the value will be used for both h and w. + pad_val (int | float | tuple[int | float]): Values to be filled in the + cut area. Defaults to 0. + + Returns: + ndarray: The cutout image. + """ + + channels = 1 if img.ndim == 2 else img.shape[2] + if isinstance(shape, int): + cut_h, cut_w = shape, shape + else: + assert isinstance(shape, tuple) and len(shape) == 2, \ + f'shape must be a int or a tuple with length 2, but got type ' \ + f'{type(shape)} instead.' + cut_h, cut_w = shape + if isinstance(pad_val, (int, float)): + pad_val = tuple([pad_val] * channels) + elif isinstance(pad_val, tuple): + assert len(pad_val) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(pad_val), channels) + else: + raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`') + + img_h, img_w = img.shape[:2] + y0 = np.random.uniform(img_h) + x0 = np.random.uniform(img_w) + + y1 = int(max(0, y0 - cut_h / 2.)) + x1 = int(max(0, x0 - cut_w / 2.)) + y2 = min(img_h, y1 + cut_h) + x2 = min(img_w, x1 + cut_w) + + if img.ndim == 2: + patch_shape = (y2 - y1, x2 - x1) + else: + patch_shape = (y2 - y1, x2 - x1, channels) + + img_cutout = img.copy() + patch = np.array( + pad_val, dtype=img.dtype) * np.ones( + patch_shape, dtype=img.dtype) + img_cutout[y1:y2, x1:x2, ...] = patch + + return img_cutout + + +def _get_shear_matrix(magnitude, direction='horizontal'): + """Generate the shear matrix for transformation. + + Args: + magnitude (int | float): The magnitude used for shear. + direction (str): The flip direction, either "horizontal" + or "vertical". + + Returns: + ndarray: The shear matrix with dtype float32. + """ + if direction == 'horizontal': + shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]]) + elif direction == 'vertical': + shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]]) + return shear_matrix + + +def imshear(img, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Shear an image. + + Args: + img (ndarray): Image to be sheared with format (h, w) + or (h, w, c). + magnitude (int | float): The magnitude used for shear. + direction (str): The flip direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as :func:`resize`. + + Returns: + ndarray: The sheared image. + """ + assert direction in ['horizontal', + 'vertical'], f'Invalid direction: {direction}' + height, width = img.shape[:2] + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = img.shape[-1] + if isinstance(border_value, int): + border_value = tuple([border_value] * channels) + elif isinstance(border_value, tuple): + assert len(border_value) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. 
Found {} vs {}'.format( + len(border_value), channels) + else: + raise ValueError( + f'Invalid type {type(border_value)} for `border_value`') + shear_matrix = _get_shear_matrix(magnitude, direction) + sheared = cv2.warpAffine( + img, + shear_matrix, + (width, height), + # Note case when the number elements in `border_value` + # greater than 3 (e.g. shearing masks whose channels large + # than 3) will raise TypeError in `cv2.warpAffine`. + # Here simply slice the first 3 values in `border_value`. + borderValue=border_value[:3], + flags=cv2_interp_codes[interpolation]) + return sheared + + +def _get_translate_matrix(offset, direction='horizontal'): + """Generate the translate matrix. + + Args: + offset (int | float): The offset used for translate. + direction (str): The translate direction, either + "horizontal" or "vertical". + + Returns: + ndarray: The translate matrix with dtype float32. + """ + if direction == 'horizontal': + translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]]) + elif direction == 'vertical': + translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]]) + return translate_matrix + + +def imtranslate(img, + offset, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Translate an image. + + Args: + img (ndarray): Image to be translated with format + (h, w) or (h, w, c). + offset (int | float): The offset used for translate. + direction (str): The translate direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as :func:`resize`. + + Returns: + ndarray: The translated image. + """ + assert direction in ['horizontal', + 'vertical'], f'Invalid direction: {direction}' + height, width = img.shape[:2] + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = img.shape[-1] + if isinstance(border_value, int): + border_value = tuple([border_value] * channels) + elif isinstance(border_value, tuple): + assert len(border_value) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(border_value), channels) + else: + raise ValueError( + f'Invalid type {type(border_value)} for `border_value`.') + translate_matrix = _get_translate_matrix(offset, direction) + translated = cv2.warpAffine( + img, + translate_matrix, + (width, height), + # Note case when the number elements in `border_value` + # greater than 3 (e.g. translating masks whose channels + # large than 3) will raise TypeError in `cv2.warpAffine`. + # Here simply slice the first 3 values in `border_value`. + borderValue=border_value[:3], + flags=cv2_interp_codes[interpolation]) + return translated diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py new file mode 100755 index 00000000..ba0b7918 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/io.py @@ -0,0 +1,314 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
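+# Usage sketch (illustrative only; the image paths are hypothetical and the
+# top-level ``dbnet_cv.*`` names follow the docstring examples further down
+# in this file):
+#
+#     import dbnet_cv
+#
+#     dbnet_cv.use_backend('pillow')   # switch the global decoding backend
+#     img = dbnet_cv.imread('/path/to/img.jpg', flag='grayscale')
+#     dbnet_cv.imwrite(img, '/path/to/out.png')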
+import io +import os.path as osp +import warnings +from pathlib import Path + +import cv2 +import numpy as np +from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION, + IMREAD_UNCHANGED) + +from dbnet_cv.fileio import FileClient +from dbnet_cv.utils import is_filepath, is_str + +try: + from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG +except ImportError: + TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None + +try: + from PIL import Image, ImageOps +except ImportError: + Image = None + +try: + import tifffile +except ImportError: + tifffile = None + +jpeg = None +supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile'] + +imread_flags = { + 'color': IMREAD_COLOR, + 'grayscale': IMREAD_GRAYSCALE, + 'unchanged': IMREAD_UNCHANGED, + 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR, + 'grayscale_ignore_orientation': + IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE +} + +imread_backend = 'cv2' + + +def use_backend(backend): + """Select a backend for image decoding. + + Args: + backend (str): The image decoding backend type. Options are `cv2`, + `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG) + and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg` + file format. + """ + assert backend in supported_backends + global imread_backend + imread_backend = backend + if imread_backend == 'turbojpeg': + if TurboJPEG is None: + raise ImportError('`PyTurboJPEG` is not installed') + global jpeg + if jpeg is None: + jpeg = TurboJPEG() + elif imread_backend == 'pillow': + if Image is None: + raise ImportError('`Pillow` is not installed') + elif imread_backend == 'tifffile': + if tifffile is None: + raise ImportError('`tifffile` is not installed') + + +def _jpegflag(flag='color', channel_order='bgr'): + channel_order = channel_order.lower() + if channel_order not in ['rgb', 'bgr']: + raise ValueError('channel order must be either "rgb" or "bgr"') + + if flag == 'color': + if channel_order == 'bgr': + return TJPF_BGR + elif channel_order == 'rgb': + return TJCS_RGB + elif flag == 'grayscale': + return TJPF_GRAY + else: + raise ValueError('flag must be "color" or "grayscale"') + + +def _pillow2array(img, flag='color', channel_order='bgr'): + """Convert a pillow image to numpy array. + + Args: + img (:obj:`PIL.Image.Image`): The image loaded using PIL + flag (str): Flags specifying the color type of a loaded image, + candidates are 'color', 'grayscale' and 'unchanged'. + Default to 'color'. + channel_order (str): The channel order of the output image array, + candidates are 'bgr' and 'rgb'. Default to 'bgr'. + + Returns: + np.ndarray: The converted numpy array + """ + channel_order = channel_order.lower() + if channel_order not in ['rgb', 'bgr']: + raise ValueError('channel order must be either "rgb" or "bgr"') + + if flag == 'unchanged': + array = np.array(img) + if array.ndim >= 3 and array.shape[2] >= 3: # color image + array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR + else: + # Handle exif orientation tag + if flag in ['color', 'grayscale']: + img = ImageOps.exif_transpose(img) + # If the image mode is not 'RGB', convert it to 'RGB' first. + if img.mode != 'RGB': + if img.mode != 'LA': + # Most formats except 'LA' can be directly converted to RGB + img = img.convert('RGB') + else: + # When the mode is 'LA', the default conversion will fill in + # the canvas with black, which sometimes shadows black objects + # in the foreground. 
+ # + # Therefore, a random color (124, 117, 104) is used for canvas + img_rgba = img.convert('RGBA') + img = Image.new('RGB', img_rgba.size, (124, 117, 104)) + img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha + if flag in ['color', 'color_ignore_orientation']: + array = np.array(img) + if channel_order != 'rgb': + array = array[:, :, ::-1] # RGB to BGR + elif flag in ['grayscale', 'grayscale_ignore_orientation']: + img = img.convert('L') + array = np.array(img) + else: + raise ValueError( + 'flag must be "color", "grayscale", "unchanged", ' + f'"color_ignore_orientation" or "grayscale_ignore_orientation"' + f' but got {flag}') + return array + + +def imread(img_or_path, + flag='color', + channel_order='bgr', + backend=None, + file_client_args=None): + """Read an image. + + Note: + In v1.4.1 and later, add `file_client_args` parameters. + + Args: + img_or_path (ndarray or str or Path): Either a numpy array or str or + pathlib.Path. If it is a numpy array (loaded image), then + it will be returned as is. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale`, `unchanged`, + `color_ignore_orientation` and `grayscale_ignore_orientation`. + By default, `cv2` and `pillow` backend would rotate the image + according to its EXIF info unless called with `unchanged` or + `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend + always ignore image's EXIF info regardless of the flag. + The `turbojpeg` backend only supports `color` and `grayscale`. + channel_order (str): Order of channel, candidates are `bgr` and `rgb`. + backend (str | None): The image decoding backend type. Options are + `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. + If backend is None, the global imread_backend specified by + ``dbnet_cv.use_backend()`` will be used. Default: None. + file_client_args (dict | None): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + + Returns: + ndarray: Loaded image array. + + Examples: + >>> import dbnet_cv + >>> img_path = '/path/to/img.jpg' + >>> img = dbnet_cv.imread(img_path) + >>> img = dbnet_cv.imread(img_path, flag='color', channel_order='rgb', + ... backend='cv2') + >>> img = dbnet_cv.imread(img_path, flag='color', channel_order='bgr', + ... backend='pillow') + >>> s3_img_path = 's3://bucket/img.jpg' + >>> # infer the file backend by the prefix s3 + >>> img = dbnet_cv.imread(s3_img_path) + >>> # manually set the file backend petrel + >>> img = dbnet_cv.imread(s3_img_path, file_client_args={ + ... 'backend': 'petrel'}) + >>> http_img_path = 'http://path/to/img.jpg' + >>> img = dbnet_cv.imread(http_img_path) + >>> img = dbnet_cv.imread(http_img_path, file_client_args={ + ... 'backend': 'http'}) + """ + + if isinstance(img_or_path, Path): + img_or_path = str(img_or_path) + + if isinstance(img_or_path, np.ndarray): + return img_or_path + elif is_str(img_or_path): + file_client = FileClient.infer_client(file_client_args, img_or_path) + img_bytes = file_client.get(img_or_path) + return imfrombytes(img_bytes, flag, channel_order, backend) + else: + raise TypeError('"img" must be a numpy array or a str or ' + 'a pathlib.Path object') + + +def imfrombytes(content, flag='color', channel_order='bgr', backend=None): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Same as :func:`imread`. + channel_order (str): The channel order of the output, candidates + are 'bgr' and 'rgb'. Default to 'bgr'. 
+ backend (str | None): The image decoding backend type. Options are + `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is + None, the global imread_backend specified by ``dbnet_cv.use_backend()`` + will be used. Default: None. + + Returns: + ndarray: Loaded image array. + + Examples: + >>> img_path = '/path/to/img.jpg' + >>> with open(img_path, 'rb') as f: + >>> img_buff = f.read() + >>> img = dbnet_cv.imfrombytes(img_buff) + >>> img = dbnet_cv.imfrombytes(img_buff, flag='color', channel_order='rgb') + >>> img = dbnet_cv.imfrombytes(img_buff, backend='pillow') + >>> img = dbnet_cv.imfrombytes(img_buff, backend='cv2') + """ + + if backend is None: + backend = imread_backend + if backend not in supported_backends: + raise ValueError( + f'backend: {backend} is not supported. Supported ' + "backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'") + if backend == 'turbojpeg': + img = jpeg.decode(content, _jpegflag(flag, channel_order)) + if img.shape[-1] == 1: + img = img[:, :, 0] + return img + elif backend == 'pillow': + with io.BytesIO(content) as buff: + img = Image.open(buff) + img = _pillow2array(img, flag, channel_order) + return img + elif backend == 'tifffile': + with io.BytesIO(content) as buff: + img = tifffile.imread(buff) + return img + else: + img_np = np.frombuffer(content, np.uint8) + flag = imread_flags[flag] if is_str(flag) else flag + img = cv2.imdecode(img_np, flag) + if flag == IMREAD_COLOR and channel_order == 'rgb': + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + return img + + +def imwrite(img, + file_path, + params=None, + auto_mkdir=None, + file_client_args=None): + """Write image to file. + + Note: + In v1.4.1 and later, add `file_client_args` parameters. + + Warning: + The parameter `auto_mkdir` will be deprecated in the future and every + file clients will make directory automatically. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. It will be deprecated. + file_client_args (dict | None): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + + Returns: + bool: Successful or not. + + Examples: + >>> # write to hard disk client + >>> ret = dbnet_cv.imwrite(img, '/path/to/img.jpg') + >>> # infer the file backend by the prefix s3 + >>> ret = dbnet_cv.imwrite(img, 's3://bucket/img.jpg') + >>> # manually set the file backend petrel + >>> ret = dbnet_cv.imwrite(img, 's3://bucket/img.jpg', file_client_args={ + ... 'backend': 'petrel'}) + """ + assert is_filepath(file_path) + file_path = str(file_path) + if auto_mkdir is not None: + warnings.warn( + 'The parameter `auto_mkdir` will be deprecated in the future and ' + 'every file clients will make directory automatically.') + file_client = FileClient.infer_client(file_client_args, file_path) + img_ext = osp.splitext(file_path)[-1] + # Encode image according to image suffix. + # For example, if image path is '/path/your/img.jpg', the encode + # format is '.jpg'. + flag, img_buff = cv2.imencode(img_ext, img, params) + file_client.put(img_buff.tobytes(), file_path) + return flag diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py new file mode 100755 index 00000000..468f44c6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/misc.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import numpy as np + +import dbnet_cv + +try: + import torch +except ImportError: + torch = None + + +def tensor2imgs(tensor, mean=None, std=None, to_rgb=True): + """Convert tensor to 3-channel images or 1-channel gray images. + + Args: + tensor (torch.Tensor): Tensor that contains multiple images, shape ( + N, C, H, W). :math:`C` can be either 3 or 1. + mean (tuple[float], optional): Mean of images. If None, + (0, 0, 0) will be used for tensor with 3-channel, + while (0, ) for tensor with 1-channel. Defaults to None. + std (tuple[float], optional): Standard deviation of images. If None, + (1, 1, 1) will be used for tensor with 3-channel, + while (1, ) for tensor with 1-channel. Defaults to None. + to_rgb (bool, optional): Whether the tensor was converted to RGB + format in the first place. If so, convert it back to BGR. + For the tensor with 1 channel, it must be False. Defaults to True. + + Returns: + list[np.ndarray]: A list that contains multiple images. + """ + + if torch is None: + raise RuntimeError('pytorch is not installed') + assert torch.is_tensor(tensor) and tensor.ndim == 4 + channels = tensor.size(1) + assert channels in [1, 3] + if mean is None: + mean = (0, ) * channels + if std is None: + std = (1, ) * channels + assert (channels == len(mean) == len(std) == 3) or \ + (channels == len(mean) == len(std) == 1 and not to_rgb) + + num_imgs = tensor.size(0) + mean = np.array(mean, dtype=np.float32) + std = np.array(std, dtype=np.float32) + imgs = [] + for img_id in range(num_imgs): + img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0) + img = dbnet_cv.imdenormalize( + img, mean, std, to_bgr=to_rgb).astype(np.uint8) + imgs.append(np.ascontiguousarray(img)) + return imgs diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/image/photometric.py b/cv/ocr/dbnet/pytorch/dbnet_cv/image/photometric.py new file mode 100755 index 00000000..bba818f6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/image/photometric.py @@ -0,0 +1,471 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + +from ..utils import is_tuple_of +from .colorspace import bgr2gray, gray2bgr + + +def imnormalize(img, mean, std, to_rgb=True): + """Normalize an image with mean and std. + + Args: + img (ndarray): Image to be normalized. + mean (ndarray): The mean to be used for normalize. + std (ndarray): The std to be used for normalize. + to_rgb (bool): Whether to convert to rgb. + + Returns: + ndarray: The normalized image. + """ + img = img.copy().astype(np.float32) + return imnormalize_(img, mean, std, to_rgb) + + +def imnormalize_(img, mean, std, to_rgb=True): + """Inplace normalize an image with mean and std. + + Args: + img (ndarray): Image to be normalized. + mean (ndarray): The mean to be used for normalize. + std (ndarray): The std to be used for normalize. + to_rgb (bool): Whether to convert to rgb. + + Returns: + ndarray: The normalized image. 
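+
+    Examples:
+        >>> # A minimal sketch with illustrative ImageNet-style statistics;
+        >>> # the input must not be uint8 and is modified in place.
+        >>> img = (np.random.rand(64, 64, 3) * 255).astype(np.float32)
+        >>> mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+        >>> std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+        >>> out = imnormalize_(img, mean, std, to_rgb=True)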
+ """ + # cv2 inplace normalization does not accept uint8 + assert img.dtype != np.uint8 + mean = np.float64(mean.reshape(1, -1)) + stdinv = 1 / np.float64(std.reshape(1, -1)) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + cv2.subtract(img, mean, img) # inplace + cv2.multiply(img, stdinv, img) # inplace + return img + + +def imdenormalize(img, mean, std, to_bgr=True): + assert img.dtype != np.uint8 + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = cv2.multiply(img, std) # make a copy + cv2.add(img, mean, img) # inplace + if to_bgr: + cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace + return img + + +def iminvert(img): + """Invert (negate) an image. + + Args: + img (ndarray): Image to be inverted. + + Returns: + ndarray: The inverted image. + """ + return np.full_like(img, 255) - img + + +def solarize(img, thr=128): + """Solarize an image (invert all pixel values above a threshold) + + Args: + img (ndarray): Image to be solarized. + thr (int): Threshold for solarizing (0 - 255). + + Returns: + ndarray: The solarized image. + """ + img = np.where(img < thr, img, 255 - img) + return img + + +def posterize(img, bits): + """Posterize an image (reduce the number of bits for each color channel) + + Args: + img (ndarray): Image to be posterized. + bits (int): Number of bits (1 to 8) to use for posterizing. + + Returns: + ndarray: The posterized image. + """ + shift = 8 - bits + img = np.left_shift(np.right_shift(img, shift), shift) + return img + + +def adjust_color(img, alpha=1, beta=None, gamma=0): + r"""It blends the source image and its gray image: + + .. math:: + output = img * alpha + gray\_img * beta + gamma + + Args: + img (ndarray): The input source image. + alpha (int | float): Weight for the source image. Default 1. + beta (int | float): Weight for the converted gray image. + If None, it's assigned the value (1 - `alpha`). + gamma (int | float): Scalar added to each sum. + Same as :func:`cv2.addWeighted`. Default 0. + + Returns: + ndarray: Colored image which has the same size and dtype as input. + """ + gray_img = bgr2gray(img) + gray_img = np.tile(gray_img[..., None], [1, 1, 3]) + if beta is None: + beta = 1 - alpha + colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma) + if not colored_img.dtype == np.uint8: + # Note when the dtype of `img` is not the default `np.uint8` + # (e.g. np.float32), the value in `colored_img` got from cv2 + # is not guaranteed to be in range [0, 255], so here clip + # is needed. + colored_img = np.clip(colored_img, 0, 255) + return colored_img + + +def imequalize(img): + """Equalize the image histogram. + + This function applies a non-linear mapping to the input image, + in order to create a uniform distribution of grayscale values + in the output image. + + Args: + img (ndarray): Image to be equalized. + + Returns: + ndarray: The equalized image. + """ + + def _scale_channel(im, c): + """Scale the data in the corresponding channel.""" + im = im[:, :, c] + # Compute the histogram of the image channel. + histo = np.histogram(im, 256, (0, 255))[0] + # For computing the step, filter out the nonzeros. + nonzero_histo = histo[histo > 0] + step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255 + if not step: + lut = np.array(range(256)) + else: + # Compute the cumulative sum, shifted by step // 2 + # and then normalized by step. + lut = (np.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. 
+ lut = np.concatenate([[0], lut[:-1]], 0) + # handle potential integer overflow + lut[lut > 255] = 255 + # If step is zero, return the original image. + # Otherwise, index from lut. + return np.where(np.equal(step, 0), im, lut[im]) + + # Scales each channel independently and then stacks + # the result. + s1 = _scale_channel(img, 0) + s2 = _scale_channel(img, 1) + s3 = _scale_channel(img, 2) + equalized_img = np.stack([s1, s2, s3], axis=-1) + return equalized_img.astype(img.dtype) + + +def adjust_brightness(img, factor=1.): + """Adjust image brightness. + + This function controls the brightness of an image. An + enhancement factor of 0.0 gives a black image. + A factor of 1.0 gives the original image. This function + blends the source image and the degenerated black image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be brightened. + factor (float): A value controls the enhancement. + Factor 1.0 returns the original image, lower + factors mean less color (brightness, contrast, + etc), and higher values more. Default 1. + + Returns: + ndarray: The brightened image. + """ + degenerated = np.zeros_like(img) + # Note manually convert the dtype to np.float32, to + # achieve as close results as PIL.ImageEnhance.Brightness. + # Set beta=1-factor, and gamma=0 + brightened_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + brightened_img = np.clip(brightened_img, 0, 255) + return brightened_img.astype(img.dtype) + + +def adjust_contrast(img, factor=1.): + """Adjust image contrast. + + This function controls the contrast of an image. An + enhancement factor of 0.0 gives a solid grey + image. A factor of 1.0 gives the original image. It + blends the source image and the degenerated mean image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be contrasted. BGR order. + factor (float): Same as :func:`dbnet_cv.adjust_brightness`. + + Returns: + ndarray: The contrasted image. + """ + gray_img = bgr2gray(img) + hist = np.histogram(gray_img, 256, (0, 255))[0] + mean = round(np.sum(gray_img) / np.sum(hist)) + degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype) + degenerated = gray2bgr(degenerated) + contrasted_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + contrasted_img = np.clip(contrasted_img, 0, 255) + return contrasted_img.astype(img.dtype) + + +def auto_contrast(img, cutoff=0): + """Auto adjust image contrast. + + This function maximize (normalize) image contrast by first removing cutoff + percent of the lightest and darkest pixels from the histogram and remapping + the image so that the darkest pixel becomes black (0), and the lightest + becomes white (255). + + Args: + img (ndarray): Image to be contrasted. BGR order. + cutoff (int | float | tuple): The cutoff percent of the lightest and + darkest pixels to be removed. If given as tuple, it shall be + (low, high). Otherwise, the single value will be used for both. + Defaults to 0. + + Returns: + ndarray: The contrasted image. + """ + + def _auto_contrast_channel(im, c, cutoff): + im = im[:, :, c] + # Compute the histogram of the image channel. 
+ histo = np.histogram(im, 256, (0, 255))[0] + # Remove cut-off percent pixels from histo + histo_sum = np.cumsum(histo) + cut_low = histo_sum[-1] * cutoff[0] // 100 + cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100 + histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low + histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0) + + # Compute mapping + low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1] + # If all the values have been cut off, return the origin img + if low >= high: + return im + scale = 255.0 / (high - low) + offset = -low * scale + lut = np.array(range(256)) + lut = lut * scale + offset + lut = np.clip(lut, 0, 255) + return lut[im] + + if isinstance(cutoff, (int, float)): + cutoff = (cutoff, cutoff) + else: + assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \ + f'float or tuple, but got {type(cutoff)} instead.' + # Auto adjusts contrast for each channel independently and then stacks + # the result. + s1 = _auto_contrast_channel(img, 0, cutoff) + s2 = _auto_contrast_channel(img, 1, cutoff) + s3 = _auto_contrast_channel(img, 2, cutoff) + contrasted_img = np.stack([s1, s2, s3], axis=-1) + return contrasted_img.astype(img.dtype) + + +def adjust_sharpness(img, factor=1., kernel=None): + """Adjust image sharpness. + + This function controls the sharpness of an image. An + enhancement factor of 0.0 gives a blurred image. A + factor of 1.0 gives the original image. And a factor + of 2.0 gives a sharpened image. It blends the source + image and the degenerated mean image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be sharpened. BGR order. + factor (float): Same as :func:`dbnet_cv.adjust_brightness`. + kernel (np.ndarray, optional): Filter kernel to be applied on the img + to obtain the degenerated img. Defaults to None. + + Note: + No value sanity check is enforced on the kernel set by users. So with + an inappropriate kernel, the ``adjust_sharpness`` may fail to perform + the function its name indicates but end up performing whatever + transform determined by the kernel. + + Returns: + ndarray: The sharpened image. + """ + + if kernel is None: + # adopted from PIL.ImageFilter.SMOOTH + kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13 + assert isinstance(kernel, np.ndarray), \ + f'kernel must be of type np.ndarray, but got {type(kernel)} instead.' + assert kernel.ndim == 2, \ + f'kernel must have a dimension of 2, but got {kernel.ndim} instead.' + + degenerated = cv2.filter2D(img, -1, kernel) + sharpened_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + sharpened_img = np.clip(sharpened_img, 0, 255) + return sharpened_img.astype(img.dtype) + + +def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True): + """AlexNet-style PCA jitter. + + This data augmentation is proposed in `ImageNet Classification with Deep + Convolutional Neural Networks + `_. + + Args: + img (ndarray): Image to be adjusted lighting. BGR order. + eigval (ndarray): the eigenvalue of the convariance matrix of pixel + values, respectively. + eigvec (ndarray): the eigenvector of the convariance matrix of pixel + values, respectively. + alphastd (float): The standard deviation for distribution of alpha. + Defaults to 0.1 + to_rgb (bool): Whether to convert img to rgb. + + Returns: + ndarray: The adjusted image. 
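+
+    Examples:
+        >>> # A minimal sketch; the eigenvalues/eigenvectors below are
+        >>> # illustrative ImageNet-style statistics, and eigvec must have
+        >>> # shape (3, eigval.shape[0]).
+        >>> img = np.random.randint(0, 256, (32, 32, 3)).astype(np.float32)
+        >>> eigval = np.array([55.46, 4.794, 1.148])
+        >>> eigvec = np.array([[-0.5675, 0.7192, 0.4009],
+        ...                    [-0.5808, -0.0045, -0.8140],
+        ...                    [-0.5836, -0.6948, 0.4203]])
+        >>> out = adjust_lighting(img, eigval, eigvec, alphastd=0.1)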
+ """ + assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \ + f'eigval and eigvec should both be of type np.ndarray, got ' \ + f'{type(eigval)} and {type(eigvec)} instead.' + + assert eigval.ndim == 1 and eigvec.ndim == 2 + assert eigvec.shape == (3, eigval.shape[0]) + n_eigval = eigval.shape[0] + assert isinstance(alphastd, float), 'alphastd should be of type float, ' \ + f'got {type(alphastd)} instead.' + + img = img.copy().astype(np.float32) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + + alpha = np.random.normal(0, alphastd, n_eigval) + alter = eigvec \ + * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \ + * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval)) + alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape) + img_adjusted = img + alter + return img_adjusted + + +def lut_transform(img, lut_table): + """Transform array by look-up table. + + The function lut_transform fills the output array with values from the + look-up table. Indices of the entries are taken from the input array. + + Args: + img (ndarray): Image to be transformed. + lut_table (ndarray): look-up table of 256 elements; in case of + multi-channel input array, the table should either have a single + channel (in this case the same table is used for all channels) or + the same number of channels as in the input array. + + Returns: + ndarray: The transformed image. + """ + assert isinstance(img, np.ndarray) + assert 0 <= np.min(img) and np.max(img) <= 255 + assert isinstance(lut_table, np.ndarray) + assert lut_table.shape == (256, ) + + return cv2.LUT(np.array(img, dtype=np.uint8), lut_table) + + +def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + img (ndarray): Image to be processed. + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + + Returns: + ndarray: The processed image. + """ + assert isinstance(img, np.ndarray) + assert img.ndim == 2 + assert isinstance(clip_limit, (float, int)) + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + + clahe = cv2.createCLAHE(clip_limit, tile_grid_size) + return clahe.apply(np.array(img, dtype=np.uint8)) + + +def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray: + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and cyclically + shifting the intensities in the hue channel (H). The image is then + converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + Modified from + https://github.com/pytorch/vision/blob/main/torchvision/ + transforms/functional.py + + Args: + img (ndarray): Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + ndarray: Hue adjusted image. 
+ """ + + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].') + if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})): + raise TypeError('img should be ndarray with dim=[2 or 3].') + + dtype = img.dtype + img = img.astype(np.uint8) + hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) + h, s, v = cv2.split(hsv_img) + h = h.astype(np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + h += np.uint8(hue_factor * 255) + hsv_img = cv2.merge([h, s, v]) + return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/deprecated.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/deprecated.json new file mode 100755 index 00000000..25cf6f28 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/deprecated.json @@ -0,0 +1,6 @@ +{ + "resnet50_caffe": "detectron/resnet50_caffe", + "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr", + "resnet101_caffe": "detectron/resnet101_caffe", + "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr" +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/mmcls.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/mmcls.json new file mode 100755 index 00000000..c073a41d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/mmcls.json @@ -0,0 +1,59 @@ +{ + "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth", + "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth", + "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth", + "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth", + "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth", + "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth", + "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth", + "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth", + "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth", + "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth", + "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth", + "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth", + "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth", + "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth", + "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth", + "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth", + "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth", + "resnext101_32x4d": 
"https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth", + "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth", + "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth", + "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth", + "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth", + "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth", + "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth", + "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth", + "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth", + "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth", + "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth", + "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth", + "mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth", + "mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth", + "repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth", + "repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth", + "repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth", + "repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth", + "repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth", + "repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth", + "repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth", + "repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth", + "repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth", + "repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth", + "repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth", + "repvgg_D2se": 
"https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth", + "res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth", + "res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth", + "res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth", + "swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth", + "swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth", + "swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth", + "swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth", + "t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth", + "t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth", + "t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth", + "tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth", + "vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth", + "vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth", + "vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth" +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/open_mmlab.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/open_mmlab.json new file mode 100755 index 00000000..a2a0772e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/open_mmlab.json @@ -0,0 +1,50 @@ +{ + "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth", + "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth", + "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth", + "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth", + "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth", + "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth", + "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth", + "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth", + "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth", + "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth", + 
"detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth", + "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth", + "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth", + "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth", + "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth", + "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth", + "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth", + "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth", + "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth", + "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth", + "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth", + "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth", + "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth", + "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth", + "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth", + "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth", + "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_dbnet_detv2-f0a600f9.pth", + "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth", + "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth", + "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth", + "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth", + "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth", + "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth", + "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth", + "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth", + "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth", + "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth", + "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth", + "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth", + "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth", + "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth", + "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth", + "contrib/mobilenet_v3_small": 
"https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth", + "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth", + "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth", + "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth", + "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth", + "dbnet_det/mobilenet_v2": "https://download.openmmlab.com/dbnet_detection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth" +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/torchvision_0.12.json b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/torchvision_0.12.json new file mode 100755 index 00000000..06defe67 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/model_zoo/torchvision_0.12.json @@ -0,0 +1,57 @@ +{ + "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", + "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + "regnet_x_8gf": 
"https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + "shufflenetv2_x1.5": null, + "shufflenetv2_x2.0": null, + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py new file mode 100755 index 00000000..035b0517 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/__init__.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# from .active_rotated_filter import active_rotated_filter +# from .assign_score_withk import assign_score_withk +# from .ball_query import ball_query +# from .bbox import bbox_overlaps +# from .border_align import BorderAlign, border_align +# from .box_iou_rotated import box_iou_rotated +# from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive +# from .cc_attention import CrissCrossAttention +# from .chamfer_distance import chamfer_distance +# from .contour_expand import contour_expand +# from .convex_iou import convex_giou, convex_iou +# from .corner_pool import CornerPool +# from .correlation import Correlation +from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d +from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack, + ModulatedDeformRoIPoolPack, deform_roi_pool) +from .deprecated_wrappers import Conv2d_deprecated as Conv2d +from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d +from .deprecated_wrappers import Linear_deprecated as Linear +from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d +# from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d +# from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, +# sigmoid_focal_loss, softmax_focal_loss) +# from .furthest_point_sample import (furthest_point_sample, +# furthest_point_sample_with_dist) +# from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu +# from .gather_points import gather_points +# from .group_points import GroupAll, QueryAndGroup, grouping_operation +from .info import (get_compiler_version, get_compiling_cuda_version, + get_onnxruntime_op_path) +# from .iou3d import (boxes_iou3d, boxes_iou_bev, boxes_overlap_bev, nms3d, +# nms3d_normal, nms_bev, nms_normal_bev) +# from .knn import knn +# from .masked_conv import MaskedConv2d, masked_conv2d +# from .min_area_polygons import min_area_polygons +from .modulated_deform_conv import (ModulatedDeformConv2d, + ModulatedDeformConv2dPack, + modulated_deform_conv2d) +from .multi_scale_deform_attn import MultiScaleDeformableAttention +# from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms +# from .pixel_group import pixel_group +# from .point_sample import (SimpleRoIAlign, point_sample, +# rel_roi_point_to_rel_img_point) +# from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu, +# points_in_boxes_part) +# from .points_in_polygons import points_in_polygons +# from .points_sampler import PointsSampler +# from .psa_mask import PSAMask +# from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated +from .roi_align import RoIAlign, roi_align +# from .roi_align_rotated import RoIAlignRotated, roi_align_rotated +from .roi_pool import RoIPool, roi_pool +# from .roiaware_pool3d import RoIAwarePool3d +# from .roipoint_pool3d import RoIPointPool3d +# from .rotated_feature_align import rotated_feature_align +# from .saconv import SAConv2d +# from .scatter_points import DynamicScatter, dynamic_scatter +# from .sparse_conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d, +# SparseConvTranspose3d, SparseInverseConv2d, +# SparseInverseConv3d, SubMConv2d, SubMConv3d) +# from .sparse_modules import SparseModule, SparseSequential +# from .sparse_pool import SparseMaxPool2d, SparseMaxPool3d +# from .sparse_structure import SparseConvTensor, scatter_nd +from .sync_bn import SyncBatchNorm +# from .three_interpolate import three_interpolate +# from .three_nn import three_nn +# from .tin_shift import TINShift, tin_shift +# 
from .upfirdn2d import upfirdn2d +# from .voxelize import Voxelization, voxelization + +# __all__ = [ +# 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', +# 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack', +# 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack', +# 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss', +# 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss', +# 'get_compiler_version', 'get_compiling_cuda_version', +# 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d', +# 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack', +# 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match', +# 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d', +# 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', +# 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', +# 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', +# 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', +# 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', +# 'rotated_feature_align', 'RiRoIAlignRotated', 'riroi_align_rotated', +# 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup', +# 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn', +# 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', +# 'border_align', 'gather_points', 'furthest_point_sample', +# 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', +# 'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev', 'nms_bev', +# 'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization', +# 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d', +# 'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d', +# 'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d', +# 'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d', +# 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part', +# 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons', +# 'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou', +# 'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/README.md b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/README.md new file mode 100755 index 00000000..3a979b10 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/README.md @@ -0,0 +1,170 @@ +# Code Structure of CUDA operators + +This folder contains all non-python code for DBNET_CV custom ops. Please follow the same architecture if you want to add new ops. + +## Directories Tree + +```folder +. +├── common +│ ├── box_iou_rotated_utils.hpp +│ ├── parrots_cpp_helper.hpp +│ ├── parrots_cuda_helper.hpp +│ ├── pytorch_cpp_helper.hpp +│ ├── pytorch_cuda_helper.hpp +│ ├── pytorch_device_registry.hpp +│   └── cuda +│   ├── common_cuda_helper.hpp +│   ├── parrots_cudawarpfunction.cuh +│   ├── ... +│   └── ops_cuda_kernel.cuh +├── onnxruntime +│   ├── onnxruntime_register.h +│   ├── onnxruntime_session_options_config_keys.h +│   ├── ort_dbnet_cv_utils.h +│   ├── ... +│   ├── onnx_ops.h +│   └── cpu +│ ├── onnxruntime_register.cpp +│      ├── ... +│      └── onnx_ops_impl.cpp +├── parrots +│   ├── ... +│   ├── ops.cpp +│   ├── ops_parrots.cpp +│   └── ops_pytorch.h +├── pytorch +│   ├── info.cpp +│   ├── pybind.cpp +│   ├── ... +│   ├── ops.cpp +│   ├── cuda +│   │   ├── ... +│   │   └── ops_cuda.cu +│   └── cpu +│      ├── ... 
+│      └── ops.cpp +└── tensorrt + ├── trt_cuda_helper.cuh + ├── trt_plugin_helper.hpp + ├── trt_plugin.hpp + ├── trt_serialize.hpp + ├── ... + ├── trt_ops.hpp + └── plugins +    ├── trt_cuda_helper.cu +    ├── trt_plugin.cpp +    ├── ... +    ├── trt_ops.cpp +    └── trt_ops_kernel.cu +``` + +## Components + +- `common`: This directory contains all tools and shared codes. + - `cuda`: The cuda kernels which can be shared by all backends. **HIP** kernel is also here since they have similar syntax. +- `onnxruntime`: **ONNX Runtime** support for custom ops. + - `cpu`: CPU implementation of supported ops. +- `parrots`: **Parrots** is a deep learning frame for model training and inference. Parrots custom ops are placed in this directory. +- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The ops implementation and binding codes are placed in this directory. + - `cuda`: This directory contains cuda kernel launchers, which feed memory pointers of tensor to the cuda kernel in `common/cuda`. The launchers provide c++ interface of cuda implementation of corresponding custom ops. + - `cpu`: This directory contain cpu implementations of corresponding custom ops. +- `tensorrt`: **TensorRT** support for custom ops. + - `plugins`: This directory contains the implementation of the supported custom ops. Some ops might also use shared cuda kernel in `common/cuda`. + +## How to add new PyTorch ops? + +1. (Optional) Add shared kernel in `common` to support special hardware platform. + + ```c++ + // src/common/cuda/new_ops_cuda_kernel.cuh + + template + __global__ void new_ops_forward_cuda_kernel(const T* input, T* output, ...) { + // forward here + } + + ``` + + Add cuda kernel launcher in `pytorch/cuda`. + + ```c++ + // src/pytorch/cuda + #include + + void NewOpsForwardCUDAKernelLauncher(Tensor input, Tensor output, ...){ + // initialize + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + ... + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "new_ops_forward_cuda_kernel", ([&] { + new_ops_forward_cuda_kernel + <<>>( + input.data_ptr(), output.data_ptr(),...); + })); + AT_CUDA_CHECK(cudaGetLastError()); + } + ``` + +2. Register implementation for different devices. + + ```c++ + // src/pytorch/cuda/cudabind.cpp + ... + + Tensor new_ops_forward_cuda(Tensor input, Tensor output, ...){ + // implement cuda forward here + // use `NewOpsForwardCUDAKernelLauncher` here + } + // declare interface here. + Tensor new_ops_forward_impl(Tensor input, Tensor output, ...); + // register the implementation for given device (CUDA here). + REGISTER_DEVICE_IMPL(new_ops_forward_impl, CUDA, new_ops_forward_cuda); + ``` + +3. Add ops implementation in `pytorch` directory. Select different implementations according to device type. + + ```c++ + // src/pytorch/new_ops.cpp + Tensor new_ops_forward_impl(Tensor input, Tensor output, ...){ + // dispatch the implementation according to the device type of input. + DISPATCH_DEVICE_IMPL(new_ops_forward_impl, input, output, ...); + } + ... + + Tensor new_ops_forward(Tensor input, Tensor output, ...){ + return new_ops_forward_impl(input, output, ...); + } + ``` + +4. Binding the implementation in `pytorch/pybind.cpp` + + ```c++ + // src/pytorch/pybind.cpp + + ... + + Tensor new_ops_forward(Tensor input, Tensor output, ...); + + ... + + // bind with pybind11 + m.def("new_ops_forward", &new_ops_forward, "new_ops_forward", + py::arg("input"), py::arg("output"), ...); + + ... 
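+  // Note: the name registered with m.def ("new_ops_forward") is the symbol
+  // that ext_loader.load_ext looks up from Python in step 5 below.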
+ + ``` + +5. Build DBNET_CV again. Enjoy new ops in python + + ```python + from ..utils import ext_loader + ext_module = ext_loader.load_ext('_ext', ['new_ops_forward']) + + ... + + ext_module.new_ops_forward(input, output, ...) + + ``` diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/common_cuda_helper.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/common_cuda_helper.hpp new file mode 100755 index 00000000..e18036ba --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/common_cuda_helper.hpp @@ -0,0 +1,120 @@ +#ifndef COMMON_CUDA_HELPER +#define COMMON_CUDA_HELPER + +#include + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) \ + for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \ + j += blockDim.y * gridDim.y) + +#define CUDA_2D_KERNEL_BLOCK_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x; i < (n); i += gridDim.x) \ + for (size_t j = blockIdx.y; j < (m); j += gridDim.y) + +#define THREADS_PER_BLOCK 512 + +inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) { + int optimal_block_num = (N + num_threads - 1) / num_threads; + int max_block_num = 4096; + return std::min(optimal_block_num, max_block_num); +} + +template +__device__ T bilinear_interpolate(const T* input, const int height, + const int width, T y, T x, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__device__ void bilinear_interpolate_gradient( + const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4, + int& x_low, int& x_high, int& y_low, int& y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} +#endif // COMMON_CUDA_HELPER diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh new file mode 100755 index 00000000..9e864bb8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh @@ -0,0 +1,367 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer + ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/dbnet_detection/dbnet_det/ops/dcn/src/deform_conv_cuda_kernel.cu + +#ifndef DEFORM_CONV_CUDA_KERNEL_CUH +#define DEFORM_CONV_CUDA_KERNEL_CUH + +#include +#ifdef DBNET_CV_WITH_TRT +#include "common_cuda_helper.hpp" +#else // DBNET_CV_WITH_TRT +#ifdef DBNET_CV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else // DBNET_CV_USE_PARROTS +#include "pytorch_cuda_helper.hpp" +#endif // DBNET_CV_USE_PARROTS +#endif // DBNET_CV_WITH_TRT + +template +__device__ T deformable_im2col_bilinear(const T *input, const int data_width, + const int height, const int width, T h, + T w) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = input[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = input[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = input[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h, + const int w, const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height, + const int width, const T *im_data, + const int data_width, const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if 
(argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel( + const int n, const T *data_im, const T *data_offset, const int height, + const int width, const int kernel_h, const int kernel_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T *data_col) { + CUDA_1D_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + T *data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T *data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T *data_offset_ptr = + data_offset + (b_col * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, + h_im, w_im); + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const T *data_col, const T *data_offset, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int deformable_group, const int height_col, const int width_col, + T *grad_im) { + CUDA_1D_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / 
kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, + cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel( + const int n, const T *data_col, const T *data_im, const T *data_offset, + const int channels, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int offset_channels, const int deformable_group, const int height_col, + const int width_col, T *grad_offset) { + CUDA_1D_KERNEL_LOOP(index, n) { + T val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T *data_col_ptr = data_col + deformable_group_index * + channel_per_deformable_group * + batch_size * width_col * height_col; + const T *data_im_ptr = + data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * + height * width; + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int 
w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + inv_h = inv_w = -2; + const T weight = get_coordinate_weight(inv_h, inv_w, height, width, + data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +#endif // DEFORM_CONV_CUDA_KERNEL_CUH diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_roi_pool_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_roi_pool_cuda_kernel.cuh new file mode 100755 index 00000000..a1c0520d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/deform_roi_pool_cuda_kernel.cuh @@ -0,0 +1,186 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef DEFORM_ROI_POOL_CUDA_KERNEL_CUH +#define DEFORM_ROI_POOL_CUDA_KERNEL_CUH + +#ifdef DBNET_CV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__global__ void deform_roi_pool_forward_cuda_kernel( + const int nthreads, const T* input, const T* rois, const T* offset, + T* output, const int pooled_height, const int pooled_width, + const T spatial_scale, const int sampling_ratio, const T gamma, + const int channels, const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - 0.5; + T roi_start_h = offset_rois[2] * spatial_scale - 0.5; + T roi_end_w = offset_rois[3] * spatial_scale - 0.5; + T roi_end_h = offset_rois[4] * spatial_scale - 0.5; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ? 
sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + + // Compute roi offset + if (offset != NULL) { + const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 + + ph * pooled_width + pw; + T offset_roi_w = gamma * roi_width * offset_cur_w[0]; + T offset_roi_h = + gamma * roi_height * offset_cur_w[pooled_width * pooled_height]; + roi_start_w += offset_roi_w; + roi_start_h += offset_roi_h; + } + + // We do average pooling inside a bin + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output[index] = output_val / count; + } +} + +template +__global__ void deform_roi_pool_backward_cuda_kernel( + const int nthreads, const T* grad_output, const T* input, const T* rois, + const T* offset, T* grad_input, T* grad_offset, const int pooled_height, + const int pooled_width, const T spatial_scale, const int sampling_ratio, + const T gamma, const int channels, const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + const T* offset_input = + input + ((roi_batch_ind * channels + c) * height * width); + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - 0.5; + T roi_start_h = offset_rois[2] * spatial_scale - 0.5; + T roi_end_w = offset_rois[3] * spatial_scale - 0.5; + T roi_end_h = offset_rois[4] * spatial_scale - 0.5; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + + // Compute roi offset + if (offset != NULL) { + const T* offset_cur_w = offset + n * pooled_width * pooled_height * 2 + + ph * pooled_width + pw; + T offset_roi_w = gamma * roi_width * offset_cur_w[0]; + T offset_roi_h = + gamma * roi_height * offset_cur_w[pooled_width * pooled_height]; + roi_start_w += offset_roi_w; + roi_start_h += offset_roi_h; + } + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + const T grad_output_this_bin = grad_output[index] / count; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_grad_input + y_low * width + x_low, + grad_output_this_bin * w1); + atomicAdd(offset_grad_input + y_low * width + x_high, + grad_output_this_bin * w2); + atomicAdd(offset_grad_input + y_high * width + x_low, + grad_output_this_bin * w3); + atomicAdd(offset_grad_input + y_high * width + x_high, + grad_output_this_bin * w4); + if (offset != NULL) { + T input_00 = offset_input[y_low * width + x_low]; + T input_10 = offset_input[y_low * width + x_high]; + T input_01 = offset_input[y_high * width + x_low]; + T input_11 = offset_input[y_high * width + x_high]; + T ogx = gamma * roi_width * grad_output_this_bin * + (input_11 * (y - y_low) + input_10 * (y_high - y) + + input_01 * (y_low - y) + input_00 * (y - y_high)); + T ogy = gamma * roi_height * grad_output_this_bin * + (input_11 * (x - x_low) + input_01 * (x_high - x) + + input_10 * (x_low - x) + input_00 * (x - x_high)); + atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 + + ph * pooled_width + pw, + ogx); + atomicAdd(grad_offset + n * pooled_width * pooled_height * 2 + + pooled_width * pooled_height + ph * pooled_width + pw, + ogy); + } + } + } + } + } +} + +#endif // DEFORM_ROI_POOL_CUDA_KERNEL_CUH diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/modulated_deform_conv_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/modulated_deform_conv_cuda_kernel.cuh new file mode 100755 index 00000000..5046c8d4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/modulated_deform_conv_cuda_kernel.cuh @@ -0,0 +1,399 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer + ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/dbnet_detection/dbnet_det/ops/dcn/src/deform_conv_cuda_kernel.cu + +#ifndef MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH +#define MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH + +#include +#ifdef DBNET_CV_WITH_TRT +#include "common_cuda_helper.hpp" +#else // DBNET_CV_WITH_TRT +#ifdef DBNET_CV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else // DBNET_CV_USE_PARROTS +#include "pytorch_cuda_helper.hpp" +#endif // DBNET_CV_USE_PARROTS +#endif // DBNET_CV_WITH_TRT + +template +__device__ T dmcn_im2col_bilinear(const T *input, const int data_width, + const int height, const int width, T h, T w) { + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = input[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = input[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = input[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ T dmcn_get_gradient_weight(T argmax_h, T argmax_w, const int h, + const int w, const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * 
(argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ T dmcn_get_coordinate_weight(T argmax_h, T argmax_w, + const int height, const int width, + const T *im_data, const int data_width, + const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel( + const int n, const T *data_im, const T *data_offset, const T *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T *data_col) { + CUDA_1D_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T *data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T *data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T *data_offset_ptr = + data_offset + (b_col * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const T *data_mask_ptr = + data_mask + (b_col * deformable_group + deformable_group_index) * + kernel_h * kernel_w * 
height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, + w_im); + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel( + const int n, const T *data_col, const T *data_offset, const T *data_mask, + const int channels, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int deformable_group, const int height_col, const int width_col, + T *grad_im) { + CUDA_1D_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + const T *data_mask_ptr = + data_mask + (b * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = + dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, + cur_h + dy, cur_w + dx, height, width); + 
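// accumulate into grad_im atomically: sampling windows of neighbouring output
+          // columns can scatter gradients into the same input pixel
+          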
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel( + const int n, const T *data_col, const T *data_im, const T *data_offset, + const T *data_mask, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, T *grad_offset, T *grad_mask) { + CUDA_1D_KERNEL_LOOP(index, n) { + T val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T *data_col_ptr = data_col + deformable_group_index * + channel_per_deformable_group * + batch_size * width_col * height_col; + const T *data_im_ptr = + data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * + height * width; + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + const T *data_mask_ptr = + data_mask + (b * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + inv_h = inv_w = -2; + else + mval += data_col_ptr[col_pos] * + dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, + height, width, inv_h, inv_w); + const T weight = dmcn_get_coordinate_weight( + inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + + // deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * + // height_col + h) * 
width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * + kernel_w + + offset_c / 2) * + height_col + + h) * + width_col + + w] = mval; + } +} + +#endif // MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh new file mode 100755 index 00000000..12225ffd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh @@ -0,0 +1,801 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ +#ifndef DEFORM_ATTN_CUDA_KERNEL +#define DEFORM_ATTN_CUDA_KERNEL + +#include "common_cuda_helper.hpp" +#include "pytorch_cuda_helper.hpp" + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const scalar_t &h, + const scalar_t &w, const int &m, const int &c) { + const int h_low = floorf(h); + const int w_low = floorf(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ void ms_deform_attn_col2im_bilinear( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const scalar_t &h, + const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, + const scalar_t &attn_weight, scalar_t *&grad_value, + scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { + const int h_low = floorf(h); + const int w_low = floorf(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int 
w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + +template +__device__ void ms_deform_attn_col2im_bilinear_gm( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const scalar_t &h, + const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, + const scalar_t &attn_weight, scalar_t *&grad_value, + scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { + const int h_low = floorf(h); + const int w_low = floorf(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height 
- 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + +template +__global__ void ms_deformable_im2col_gpu_kernel( + const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, const int batch_size, + const int spatial_size, const int num_heads, const int channels, + const int num_levels, const int num_query, const int num_point, + scalar_t *data_col) { + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = + data_value + + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, + spatial_w, num_heads, channels, + h_im, w_im, m_col, c_col) * + weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + const int qid_stride = num_heads * channels; + 
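// each thread writes its per-point gradient contributions into the shared-memory
+  // caches above; after __syncthreads(), thread 0 of the block reduces them and
+  // writes the sampling-location and attention-weight gradients
+  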
CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + scalar_t *grad_sampling_loc_out = + grad_sampling_loc + (grad_sampling_ptr << 1); + scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + scalar_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int _tid = 1; _tid < blockSize; ++_tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[_tid]; + sid += 2; + } + + *grad_sampling_loc_out = _grad_w; + *(grad_sampling_loc_out + 1) = _grad_h; + *grad_attn_weight_out = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight_out += grad_weight_stride; + grad_sampling_loc_out += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % 
channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + scalar_t *grad_sampling_loc_out = + grad_sampling_loc + (grad_sampling_ptr << 1); + scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc_out = cache_grad_sampling_loc[0]; + *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight_out = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight_out += grad_weight_stride; + grad_sampling_loc_out += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + extern __shared__ int _s[]; + scalar_t *cache_grad_sampling_loc = reinterpret_cast(_s); + scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid 
= threadIdx.x; + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + scalar_t *grad_sampling_loc_out = + grad_sampling_loc + (grad_sampling_ptr << 1); + scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + scalar_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[_tid]; + sid += 2; + } + + *grad_sampling_loc_out = _grad_w; + *(grad_sampling_loc_out + 1) = _grad_h; + *grad_attn_weight_out = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight_out += grad_weight_stride; + grad_sampling_loc_out += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + extern __shared__ int _s[]; + scalar_t *cache_grad_sampling_loc = reinterpret_cast(_s); + scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid 
= threadIdx.x; + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + scalar_t *grad_sampling_loc_out = + grad_sampling_loc + (grad_sampling_ptr << 1); + scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc_out = cache_grad_sampling_loc[0]; + *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight_out = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight_out += grad_weight_stride; + grad_sampling_loc_out += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + 
const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + extern __shared__ int _s[]; + scalar_t *cache_grad_sampling_loc = reinterpret_cast(_s); + scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + scalar_t *grad_sampling_loc_out = + grad_sampling_loc + (grad_sampling_ptr << 1); + scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(grad_sampling_loc_out, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc_out + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight_out, cache_grad_attn_weight[0]); + } + __syncthreads(); + + 
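+        // Move to the next sampling point: each (level, point) pair owns one attention weight and one (x, y) sampling location, so the weight pointers advance by 1 and the location pointers by 2.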
data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight_out += grad_weight_stride; + grad_sampling_loc_out += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + scalar_t *grad_sampling_loc_out = + grad_sampling_loc + (grad_sampling_ptr << 1); + scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + grad_sampling_loc_out, grad_attn_weight_out); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight_out += grad_weight_stride; + grad_sampling_loc_out += grad_loc_stride; + } + } + } +} +#endif // DEFORM_ATTN_CUDA_KERNEL diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_align_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_align_cuda_kernel.cuh new file mode 100755 index 00000000..e608474a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_align_cuda_kernel.cuh @@ -0,0 +1,212 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef ROI_ALIGN_CUDA_KERNEL_CUH +#define ROI_ALIGN_CUDA_KERNEL_CUH + +#include <float.h> +#ifdef DBNET_CV_WITH_TRT +#include "common_cuda_helper.hpp" +#else // DBNET_CV_WITH_TRT +#ifdef DBNET_CV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else // DBNET_CV_USE_PARROTS +#include "pytorch_cuda_helper.hpp" +#endif // DBNET_CV_USE_PARROTS +#endif // DBNET_CV_WITH_TRT + +/*** Forward ***/ +template <typename T> +__global__ void roi_align_forward_cuda_kernel( + const int nthreads, const T* input, const T* rois, T* output, T* argmax_y, + T* argmax_x, const int pooled_height, const int pooled_width, + const T spatial_scale, const int sampling_ratio, + const int pool_mode, // 0 - max pool, 1 - avg pool + const bool aligned, const int channels, const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { // for backward-compatibility only + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast<int>(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ?
sampling_ratio + : static_cast<int>(ceilf(roi_width / pooled_width)); + + if (pool_mode == 0) { + // We do max pooling inside a bin + T maxval = -FLT_MAX; + T maxidx_y = -1.f, maxidx_x = -1.f; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast<T>(iy + .5f) * bin_size_h / + static_cast<T>(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast<T>(ix + .5f) * bin_size_w / + static_cast<T>(roi_bin_grid_w); + T val = + bilinear_interpolate(offset_input, height, width, y, x, index); + if (val > maxval) { + maxval = val; + maxidx_y = y; + maxidx_x = x; + } + } + } + output[index] = maxval; + argmax_y[index] = maxidx_y; + argmax_x[index] = maxidx_x; + } else if (pool_mode == 1) { + // We do average pooling inside a bin + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast<T>(iy + .5f) * bin_size_h / + static_cast<T>(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast<T>(ix + .5f) * bin_size_w / + static_cast<T>(roi_bin_grid_w); + T val = + bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output[index] = output_val / count; + } + } +} + +/*** Backward ***/ +template <typename T> +__global__ void roi_align_backward_cuda_kernel( + const int nthreads, const T* grad_output, const T* rois, const T* argmax_y, + const T* argmax_x, T* grad_input, const int pooled_height, + const int pooled_width, const T spatial_scale, const int sampling_ratio, + const int pool_mode, // 0 - max pool, 1 - avg pool + const bool aligned, const int channels, const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T grad_output_this_bin = grad_output[index]; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + if (pool_mode == 0) { + T y = argmax_y[index], x = argmax_x[index]; + if (y != -1.f) { + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_grad_input + y_low * width + x_low, + grad_output_this_bin * w1); + atomicAdd(offset_grad_input + y_low * width + x_high, + grad_output_this_bin * w2); + atomicAdd(offset_grad_input + y_high * width + x_low, + grad_output_this_bin * w3); + atomicAdd(offset_grad_input + y_high * width + x_high, + grad_output_this_bin * w4); + } + } + } else if (pool_mode == 1) { + // Do not use rounding; this implementation detail is critical + T offset = aligned ?
(T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { // for backward-compatibility only + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast<int>(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast<int>(ceilf(roi_width / pooled_width)); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast<T>(iy + .5f) * bin_size_h / + static_cast<T>(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast<T>(ix + .5f) * bin_size_w / + static_cast<T>(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_grad_input + y_low * width + x_low, + grad_output_this_bin * w1 / count); + atomicAdd(offset_grad_input + y_low * width + x_high, + grad_output_this_bin * w2 / count); + atomicAdd(offset_grad_input + y_high * width + x_low, + grad_output_this_bin * w3 / count); + atomicAdd(offset_grad_input + y_high * width + x_high, + grad_output_this_bin * w4 / count); + } + } + } + } + } +} + +#endif // ROI_ALIGN_CUDA_KERNEL_CUH diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh new file mode 100755 index 00000000..ba732685 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh @@ -0,0 +1,93 @@ +// Copyright (c) OpenMMLab.
All rights reserved +#ifndef ROI_POOL_CUDA_KERNEL_CUH +#define ROI_POOL_CUDA_KERNEL_CUH + +#ifdef DBNET_CV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template <typename T> +__global__ void roi_pool_forward_cuda_kernel( + const int nthreads, const T* input, const T* rois, T* output, int* argmax, + const int pooled_height, const int pooled_width, const T spatial_scale, + const int channels, const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + // calculate the roi region on feature maps + T roi_x1 = offset_rois[1] * spatial_scale; + T roi_y1 = offset_rois[2] * spatial_scale; + T roi_x2 = (offset_rois[3] + 1) * spatial_scale; + T roi_y2 = (offset_rois[4] + 1) * spatial_scale; + + // force malformed rois to be 1x1 + T roi_w = roi_x2 - roi_x1; + T roi_h = roi_y2 - roi_y1; + if (roi_w <= 0 || roi_h <= 0) continue; + + T bin_size_w = roi_w / static_cast<T>(pooled_width); + T bin_size_h = roi_h / static_cast<T>(pooled_height); + + // the corresponding bin region + int bin_x1 = floorf(static_cast<T>(pw) * bin_size_w + roi_x1); + int bin_y1 = floorf(static_cast<T>(ph) * bin_size_h + roi_y1); + int bin_x2 = ceilf(static_cast<T>(pw + 1) * bin_size_w + roi_x1); + int bin_y2 = ceilf(static_cast<T>(ph + 1) * bin_size_h + roi_y1); + + // add roi offsets and clip to input boundaries + bin_x1 = min(max(bin_x1, 0), width); + bin_y1 = min(max(bin_y1, 0), height); + bin_x2 = min(max(bin_x2, 0), width); + bin_y2 = min(max(bin_y2, 0), height); + bool is_empty = (bin_y2 <= bin_y1) || (bin_x2 <= bin_x1); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + // Define an empty pooling region to be zero + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + T max_val = is_empty ?
0 : -FLT_MAX; + int max_idx = -1; + for (int h = bin_y1; h < bin_y2; ++h) { + for (int w = bin_x1; w < bin_x2; ++w) { + int offset = h * width + w; + if (offset_input[offset] > max_val) { + max_val = offset_input[offset]; + max_idx = offset; + } + } + } + output[index] = max_val; + if (argmax != NULL) argmax[index] = max_idx; + } +} + +template +__global__ void roi_pool_backward_cuda_kernel( + const int nthreads, const T* grad_output, const T* rois, const int* argmax, + T* grad_input, const int pooled_height, const int pooled_width, + const int channels, const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c) is an element in the pooled output + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + int roi_batch_ind = rois[n * 5]; + T* grad_input_offset = + grad_input + ((roi_batch_ind * channels + c) * height * width); + int argmax_index = argmax[index]; + + if (argmax_index != -1) { + atomicAdd(grad_input_offset + argmax_index, grad_output[index]); + } + } +} + +#endif // ROI_POOL_CUDA_KERNEL_CUH diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/sync_bn_cuda_kernel.cuh b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/sync_bn_cuda_kernel.cuh new file mode 100755 index 00000000..78d97bed --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/cuda/sync_bn_cuda_kernel.cuh @@ -0,0 +1,331 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef SYNCBN_CUDA_KERNEL_CUH +#define SYNCBN_CUDA_KERNEL_CUH + +#ifdef DBNET_CV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__global__ void sync_bn_forward_mean_cuda_kernel(const T *input, float *mean, + int num, int channels, + int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer[tid] += input[index]; + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] += buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + mean[c] = buffer[0] / total; + } +} + +template <> +__global__ void sync_bn_forward_mean_cuda_kernel(const phalf *input, + float *mean, int num, + int channels, int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer[tid] += static_cast(input[index]); + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] += buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + mean[c] = buffer[0] / total; + } +} + +template +__global__ void sync_bn_forward_var_cuda_kernel(const T *input, + const float *mean, float *var, + int num, int channels, + int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + float td = input[index] - mean[c]; + buffer[tid] += td * td; + } + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] 
+= buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + var[c] = buffer[0] / total; + } +} + +template <> +__global__ void sync_bn_forward_var_cuda_kernel(const phalf *input, + const float *mean, float *var, + int num, int channels, + int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + float td = static_cast(input[index]) - mean[c]; + buffer[tid] += td * td; + } + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] += buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + var[c] = buffer[0] / total; + } +} + +template +__global__ void sync_bn_forward_output_cuda_kernel( + const T *input, const float *mean, const float *var, float *running_mean, + float *running_var, const float *weight, const float *bias, float *norm, + float *std, T *output, int num, int channels, int spatial, float eps, + float momentum, int group_size) { + int tid = threadIdx.x; + int c = blockIdx.x; + float mean_value = mean[c]; + float std_value = sqrt(var[c] + eps); + + if (weight != nullptr) { + float weight_value = weight[c]; + float bias_value = bias[c]; + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + norm[index] = (input[index] - mean_value) / std_value; + output[index] = norm[index] * weight_value + bias_value; + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = + (input[index] - mean_value) / std_value * weight_value + bias_value; + } + } + } else { + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = norm[index] = (input[index] - mean_value) / std_value; + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = (input[index] - mean_value) / std_value; + } + } + } + if (tid == 0) { + if (std != nullptr) std[c] = std_value; + if (running_mean != nullptr) { + running_mean[c] = + momentum * mean_value + (1 - momentum) * running_mean[c]; + int count = num * spatial * group_size; + float var_unbias = count > 1 ? 
var[c] * count / (count - 1) : var[c]; + running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c]; + } + } +} + +template <> +__global__ void sync_bn_forward_output_cuda_kernel( + const phalf *input, const float *mean, const float *var, + float *running_mean, float *running_var, const float *weight, + const float *bias, float *norm, float *std, phalf *output, int num, + int channels, int spatial, float eps, float momentum, int group_size) { + int tid = threadIdx.x; + int c = blockIdx.x; + float mean_value = mean[c]; + float std_value = sqrt(var[c] + eps); + if (weight != nullptr) { + float weight_value = weight[c]; + float bias_value = bias[c]; + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + norm[index] = + (static_cast(input[index]) - mean_value) / std_value; + output[index] = + static_cast(norm[index] * weight_value + bias_value); + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = + static_cast((static_cast(input[index]) - mean_value) / + std_value * weight_value + + bias_value); + } + } + } else { + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + norm[index] = + (static_cast(input[index]) - mean_value) / std_value; + output[index] = static_cast(norm[index]); + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = static_cast( + (static_cast(input[index]) - mean_value) / std_value); + } + } + } + if (tid == 0) { + if (std != nullptr) std[c] = std_value; + if (running_mean != nullptr) { + running_mean[c] = + momentum * mean_value + (1 - momentum) * running_mean[c]; + int count = num * spatial * group_size; + float var_unbias = count > 1 ? 
var[c] * count / (count - 1) : var[c]; + running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c]; + } + } +} + +template +__global__ void sync_bn_backward_param_cuda_kernel(const T *grad_output, + const float *norm, + float *grad_weight, + float *grad_bias, int num, + int channels, int spatial) { + __shared__ float buffer1[THREADS_PER_BLOCK]; + __shared__ float buffer2[THREADS_PER_BLOCK]; + + int tid = threadIdx.x; + int c = blockIdx.x; + buffer1[tid] = buffer2[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer1[tid] += grad_output[index] * norm[index]; + buffer2[tid] += grad_output[index]; + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer1[tid] += buffer1[tid + s]; + buffer2[tid] += buffer2[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + grad_weight[c] = buffer1[0]; + grad_bias[c] = buffer2[0]; + } +} + +template <> +__global__ void sync_bn_backward_param_cuda_kernel(const phalf *grad_output, + const float *norm, + float *grad_weight, + float *grad_bias, int num, + int channels, int spatial) { + __shared__ float buffer1[THREADS_PER_BLOCK]; + __shared__ float buffer2[THREADS_PER_BLOCK]; + + int tid = threadIdx.x; + int c = blockIdx.x; + buffer1[tid] = buffer2[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer1[tid] += static_cast(grad_output[index]) * norm[index]; + buffer2[tid] += static_cast(grad_output[index]); + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer1[tid] += buffer1[tid + s]; + buffer2[tid] += buffer2[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + grad_weight[c] = buffer1[0]; + grad_bias[c] = buffer2[0]; + } +} + +template +__global__ void sync_bn_backward_data_cuda_kernel( + int output_size, const T *grad_output, const float *weight, + const float *grad_weight, const float *grad_bias, const float *norm, + const float *std, T *grad_input, int num, int channels, int spatial) { + int factor = num * spatial; + CUDA_1D_KERNEL_LOOP(index, output_size) { + int c = (index / spatial) % channels; + grad_input[index] = + weight[c] * + (grad_output[index] - + (grad_weight[c] * norm[index] + grad_bias[c]) / factor) / + std[c]; + } +} + +template <> +__global__ void sync_bn_backward_data_cuda_kernel( + int output_size, const phalf *grad_output, const float *weight, + const float *grad_weight, const float *grad_bias, const float *norm, + const float *std, phalf *grad_input, int num, int channels, int spatial) { + int factor = num * spatial; + CUDA_1D_KERNEL_LOOP(index, output_size) { + int c = (index / spatial) % channels; + grad_input[index] = static_cast( + weight[c] * + (static_cast(grad_output[index]) - + (grad_weight[c] * norm[index] + grad_bias[c]) / factor) / + std[c]); + } +} + +#endif // SYNCBN_CUDA_KERNEL_CUH diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cpp_helper.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cpp_helper.hpp new file mode 100755 index 00000000..f68e8740 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cpp_helper.hpp @@ -0,0 +1,27 @@ +#ifndef PYTORCH_CPP_HELPER +#define PYTORCH_CPP_HELPER +#include + +#include + +using namespace at; + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_MLU(x) \ + 
TORCH_CHECK(x.device().type() == at::kMLU, #x " must be a MLU tensor") +#define CHECK_CPU(x) \ + TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_MLU_INPUT(x) \ + CHECK_MLU(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_CPU_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x) + +#endif // PYTORCH_CPP_HELPER diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cuda_helper.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cuda_helper.hpp new file mode 100755 index 00000000..9869b535 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_cuda_helper.hpp @@ -0,0 +1,19 @@ +#ifndef PYTORCH_CUDA_HELPER +#define PYTORCH_CUDA_HELPER + +#include +#include +#include + +#include +#include + +#include "common_cuda_helper.hpp" + +using at::Half; +using at::Tensor; +using phalf = at::Half; + +#define __PHALF(x) (x) + +#endif // PYTORCH_CUDA_HELPER diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_device_registry.hpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_device_registry.hpp new file mode 100755 index 00000000..2a32b727 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/common/pytorch_device_registry.hpp @@ -0,0 +1,141 @@ +#ifndef PYTORCH_DEVICE_REGISTRY_H +#define PYTORCH_DEVICE_REGISTRY_H + +// Using is recommended in the official documentation in +// https://pytorch.org/tutorials/advanced/cpp_extension.html#writing-the-c-op. +// However, we use for compatibility with CUDA 9.0 +// Read https://github.com/pytorch/extension-cpp/issues/35 for more details. +#include + +#include +#include +#include +#include + +inline std::string GetDeviceStr(const at::Device& device) { + std::string str = DeviceTypeName(device.type(), true); + if (device.has_index()) { + str.push_back(':'); + str.append(std::to_string(device.index())); + } + return str; +} + +// Registry +template +class DeviceRegistry; + +template +class DeviceRegistry { + public: + using FunctionType = Ret (*)(Args...); + static const int MAX_DEVICE_TYPES = + int8_t(at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); + + void Register(at::DeviceType device, FunctionType function) { + funcs_[int8_t(device)] = function; + } + + FunctionType Find(at::DeviceType device) const { + return funcs_[int8_t(device)]; + } + + static DeviceRegistry& instance() { + static DeviceRegistry inst; + return inst; + } + + private: + DeviceRegistry() { + for (size_t i = 0; i < MAX_DEVICE_TYPES; ++i) { + funcs_[i] = nullptr; + } + }; + FunctionType funcs_[MAX_DEVICE_TYPES]; +}; + +// get device of first tensor param + +template , at::Tensor>::value, + bool> = true> +at::Device GetFirstTensorDevice(T&& t, Args&&... args) { + return std::forward(t).device(); +} +template , at::Tensor>::value, + bool> = true> +at::Device GetFirstTensorDevice(T&& t, Args&&... args) { + return GetFirstTensorDevice(std::forward(args)...); +} + +// check device consistency + +inline std::pair CheckDeviceConsistency( + const at::Device& device, int index) { + return {index, device}; +} + +template , at::Tensor>::value, + bool> = true> +std::pair CheckDeviceConsistency(const at::Device& device, + int index, T&& t, + Args&&... args); + +template , at::Tensor>::value, + bool> = true> +std::pair CheckDeviceConsistency(const at::Device& device, + int index, T&& t, + Args&&... 
args) { + auto new_device = std::forward(t).device(); + if (new_device.type() != device.type() || + new_device.index() != device.index()) { + return {index, new_device}; + } + return CheckDeviceConsistency(device, index + 1, std::forward(args)...); +} + +template < + typename T, typename... Args, + std::enable_if_t, at::Tensor>::value, bool>> +std::pair CheckDeviceConsistency(const at::Device& device, + int index, T&& t, + Args&&... args) { + return CheckDeviceConsistency(device, index + 1, std::forward(args)...); +} + +// dispatch + +template +auto Dispatch(const R& registry, const char* name, Args&&... args) { + auto device = GetFirstTensorDevice(std::forward(args)...); + auto inconsist = + CheckDeviceConsistency(device, 0, std::forward(args)...); + TORCH_CHECK(inconsist.first >= int(sizeof...(Args)), name, ": at param ", + inconsist.first, + ", inconsistent device: ", GetDeviceStr(inconsist.second).c_str(), + " vs ", GetDeviceStr(device).c_str(), "\n") + auto f_ptr = registry.Find(device.type()); + TORCH_CHECK(f_ptr != nullptr, name, ": implementation for device ", + GetDeviceStr(device).c_str(), " not found.\n") + return f_ptr(std::forward(args)...); +} + +// helper macro + +#define DEVICE_REGISTRY(key) DeviceRegistry::instance() + +#define REGISTER_DEVICE_IMPL(key, device, value) \ + struct key##_##device##_registerer { \ + key##_##device##_registerer() { \ + DEVICE_REGISTRY(key).Register(at::k##device, value); \ + } \ + }; \ + static key##_##device##_registerer _##key##_##device##_registerer; + +#define DISPATCH_DEVICE_IMPL(key, ...) \ + Dispatch(DEVICE_REGISTRY(key), #key, __VA_ARGS__) + +#endif // PYTORCH_DEVICE_REGISTRY diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/cudabind.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/cudabind.cpp new file mode 100755 index 00000000..479786c7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -0,0 +1,372 @@ +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + + + +void deformable_im2col_cuda(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col); + +void deformable_col2im_cuda(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im); + +void deformable_col2im_coord_cuda( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset); + +void deformable_im2col_impl(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col); + +void deformable_col2im_impl(Tensor data_col, 
Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im); + +void deformable_col2im_coord_impl( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset); + +REGISTER_DEVICE_IMPL(deformable_im2col_impl, CUDA, deformable_im2col_cuda); +REGISTER_DEVICE_IMPL(deformable_col2im_impl, CUDA, deformable_col2im_cuda); +REGISTER_DEVICE_IMPL(deformable_col2im_coord_impl, CUDA, + deformable_col2im_coord_cuda); + +void DeformRoIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, + Tensor offset, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale, + int sampling_ratio, float gamma); + +void DeformRoIPoolBackwardCUDAKernelLauncher( + Tensor grad_output, Tensor input, Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, float gamma); + +void deform_roi_pool_forward_cuda(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + DeformRoIPoolForwardCUDAKernelLauncher(input, rois, offset, output, + pooled_height, pooled_width, + spatial_scale, sampling_ratio, gamma); +} + +void deform_roi_pool_backward_cuda(Tensor grad_output, Tensor input, + Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, + int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma) { + DeformRoIPoolBackwardCUDAKernelLauncher( + grad_output, input, rois, offset, grad_input, grad_offset, pooled_height, + pooled_width, spatial_scale, sampling_ratio, gamma); +} + +void deform_roi_pool_forward_impl(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma); + +void deform_roi_pool_backward_impl(Tensor grad_output, Tensor input, + Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, + int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma); + +REGISTER_DEVICE_IMPL(deform_roi_pool_forward_impl, CUDA, + deform_roi_pool_forward_cuda); +REGISTER_DEVICE_IMPL(deform_roi_pool_backward_impl, CUDA, + deform_roi_pool_backward_cuda); + + + +void modulated_deformable_im2col_cuda( + const Tensor data_im, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor data_col); + +void modulated_deformable_col2im_cuda( + const Tensor data_col, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int 
pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const Tensor data_col, const Tensor data_im, const Tensor data_offset, + const Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + Tensor grad_offset, Tensor grad_mask); + +void modulated_deformable_im2col_impl( + const Tensor data_im, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor data_col); + +void modulated_deformable_col2im_impl( + const Tensor data_col, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor grad_im); + +void modulated_deformable_col2im_coord_impl( + const Tensor data_col, const Tensor data_im, const Tensor data_offset, + const Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + Tensor grad_offset, Tensor grad_mask); + +REGISTER_DEVICE_IMPL(modulated_deformable_im2col_impl, CUDA, + modulated_deformable_im2col_cuda); +REGISTER_DEVICE_IMPL(modulated_deformable_col2im_impl, CUDA, + modulated_deformable_col2im_cuda); +REGISTER_DEVICE_IMPL(modulated_deformable_col2im_coord_impl, CUDA, + modulated_deformable_col2im_coord_cuda); + +Tensor ms_deform_attn_cuda_forward(const Tensor& value, + const Tensor& spatial_shapes, + const Tensor& level_start_index, + const Tensor& sampling_loc, + const Tensor& attn_weight, + const int im2col_step); + +void ms_deform_attn_cuda_backward( + const Tensor& value, const Tensor& spatial_shapes, + const Tensor& level_start_index, const Tensor& sampling_loc, + const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, + Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step); + +Tensor ms_deform_attn_impl_forward(const Tensor& value, + const Tensor& spatial_shapes, + const Tensor& level_start_index, + const Tensor& sampling_loc, + const Tensor& attn_weight, + const int im2col_step); + +void ms_deform_attn_impl_backward( + const Tensor& value, const Tensor& spatial_shapes, + const Tensor& level_start_index, const Tensor& sampling_loc, + const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, + Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step); + +REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, CUDA, + ms_deform_attn_cuda_forward); 
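+// Forward and backward of ms_deform_attn are wired up the same way: REGISTER_DEVICE_IMPL records the CUDA function in the device registry declared in pytorch_device_registry.hpp, and the corresponding *_impl entry points reach it through DISPATCH_DEVICE_IMPL, which dispatches on the device of the first Tensor argument.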
+REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, CUDA, + ms_deform_attn_cuda_backward); + + + +void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned); + +void ROIAlignBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor argmax_y, Tensor argmax_x, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned); + +void roi_align_forward_cuda(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + ROIAlignForwardCUDAKernelLauncher( + input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width, + spatial_scale, sampling_ratio, pool_mode, aligned); +} + +void roi_align_backward_cuda(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + ROIAlignBackwardCUDAKernelLauncher( + grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height, + aligned_width, spatial_scale, sampling_ratio, pool_mode, aligned); +} + +void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned); + +void roi_align_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned); + +REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda); +REGISTER_DEVICE_IMPL(roi_align_backward_impl, CUDA, roi_align_backward_cuda); + + + +void ROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, + int pooled_width, float spatial_scale); + +void ROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor argmax, Tensor grad_input, + int pooled_height, int pooled_width, + float spatial_scale); + +void roi_pool_forward_cuda(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, int pooled_width, + float spatial_scale) { + ROIPoolForwardCUDAKernelLauncher(input, rois, output, argmax, pooled_height, + pooled_width, spatial_scale); +} + +void roi_pool_backward_cuda(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale) { + ROIPoolBackwardCUDAKernelLauncher(grad_output, rois, argmax, grad_input, + pooled_height, pooled_width, spatial_scale); +} + +void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, int pooled_width, + float spatial_scale); +void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale); +REGISTER_DEVICE_IMPL(roi_pool_forward_impl, CUDA, roi_pool_forward_cuda); +REGISTER_DEVICE_IMPL(roi_pool_backward_impl, CUDA, roi_pool_backward_cuda); + + + +void SyncBNForwardMeanCUDAKernelLauncher(const Tensor input, Tensor mean); + +void SyncBNForwardVarCUDAKernelLauncher(const Tensor input, const Tensor mean, + Tensor var); + +void 
SyncBNForwardOutputCUDAKernelLauncher( + const Tensor input, const Tensor mean, const Tensor var, + Tensor running_mean, Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps, + float momentum, int group_size); + +void SyncBNBackwardParamCUDAKernelLauncher(const Tensor grad_output, + const Tensor norm, + Tensor grad_weight, + Tensor grad_bias); + +void SyncBNBackwardDataCUDAKernelLauncher(const Tensor grad_output, + const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, + const Tensor norm, const Tensor std, + Tensor grad_input); + +void sync_bn_forward_mean_cuda(const Tensor input, Tensor mean) { + SyncBNForwardMeanCUDAKernelLauncher(input, mean); +} + +void sync_bn_forward_var_cuda(const Tensor input, const Tensor mean, + Tensor var) { + SyncBNForwardVarCUDAKernelLauncher(input, mean, var); +} + +void sync_bn_forward_output_cuda(const Tensor input, const Tensor mean, + const Tensor var, Tensor running_mean, + Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, + Tensor output, float eps, float momentum, + int group_size) { + SyncBNForwardOutputCUDAKernelLauncher(input, mean, var, running_mean, + running_var, weight, bias, norm, std, + output, eps, momentum, group_size); +} + +void sync_bn_backward_param_cuda(const Tensor grad_output, const Tensor norm, + Tensor grad_weight, Tensor grad_bias) { + SyncBNBackwardParamCUDAKernelLauncher(grad_output, norm, grad_weight, + grad_bias); +} + +void sync_bn_backward_data_cuda(const Tensor grad_output, const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, const Tensor norm, + const Tensor std, Tensor grad_input) { + SyncBNBackwardDataCUDAKernelLauncher(grad_output, weight, grad_weight, + grad_bias, norm, std, grad_input); +} + +void sync_bn_forward_mean_impl(const Tensor input, Tensor mean); + +void sync_bn_forward_var_impl(const Tensor input, const Tensor mean, + Tensor var); + +void sync_bn_forward_output_impl(const Tensor input, const Tensor mean, + const Tensor var, Tensor running_mean, + Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, + Tensor output, float eps, float momentum, + int group_size); + +void sync_bn_backward_param_impl(const Tensor grad_output, const Tensor norm, + Tensor grad_weight, Tensor grad_bias); + +void sync_bn_backward_data_impl(const Tensor grad_output, const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, const Tensor norm, + const Tensor std, Tensor grad_input); + +REGISTER_DEVICE_IMPL(sync_bn_forward_mean_impl, CUDA, + sync_bn_forward_mean_cuda); +REGISTER_DEVICE_IMPL(sync_bn_forward_var_impl, CUDA, sync_bn_forward_var_cuda); +REGISTER_DEVICE_IMPL(sync_bn_forward_output_impl, CUDA, + sync_bn_forward_output_cuda); +REGISTER_DEVICE_IMPL(sync_bn_backward_param_impl, CUDA, + sync_bn_backward_param_cuda); +REGISTER_DEVICE_IMPL(sync_bn_backward_data_impl, CUDA, + sync_bn_backward_data_cuda); + + diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu new file mode 100755 index 00000000..05fc08b7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu @@ -0,0 +1,105 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "deform_conv_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void deformable_im2col_cuda(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col) { + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, + deformable_group, height_col, width_col, data_col_); + })); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void deformable_col2im_cuda(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im) { + // todo: make sure parallel_imgs is passed in correctly + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, + dilation_w, channel_per_deformable_group, parallel_imgs, + deformable_group, height_col, width_col, grad_im_); + })); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void deformable_col2im_coord_cuda( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset) { + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * + deformable_group * parallel_imgs; + int channel_per_deformable_group = + channels * ksize_h * ksize_w / 
deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, + at::cuda::getCurrentCUDAStream()>>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, + width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs, + 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_roi_pool_cuda.cu b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_roi_pool_cuda.cu new file mode 100755 index 00000000..d4439982 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/deform_roi_pool_cuda.cu @@ -0,0 +1,55 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "deform_roi_pool_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void DeformRoIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, + Tensor offset, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale, + int sampling_ratio, float gamma) { + int output_size = output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "deform_roi_pool_forward_cuda_kernel", [&] { + deform_roi_pool_forward_cuda_kernel + <<>>( + output_size, input.data_ptr(), + rois.data_ptr(), offset.data_ptr(), + output.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), sampling_ratio, + static_cast(gamma), channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void DeformRoIPoolBackwardCUDAKernelLauncher( + Tensor grad_output, Tensor input, Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, float gamma) { + int output_size = grad_output.numel(); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + + at::cuda::CUDAGuard device_guard(grad_output.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "deform_roi_pool_backward_cuda_kernel", [&] { + deform_roi_pool_backward_cuda_kernel + <<>>( + output_size, grad_output.data_ptr(), + input.data_ptr(), rois.data_ptr(), + offset.data_ptr(), grad_input.data_ptr(), + grad_offset.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), sampling_ratio, + static_cast(gamma), channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu new file mode 100755 index 00000000..2b52796e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu @@ -0,0 +1,96 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "modulated_deform_conv_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void modulated_deformable_im2col_cuda( + const Tensor data_im, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<< + GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, + at::cuda::getCurrentCUDAStream()>>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, + width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, batch_size, + channels, deformable_group, height_col, width_col, data_col_); + })); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void modulated_deformable_col2im_cuda( + const Tensor data_col, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor grad_im) { + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = + channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<< + GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, + at::cuda::getCurrentCUDAStream()>>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, + height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void modulated_deformable_col2im_coord_cuda( + const Tensor data_col, const Tensor data_im, const Tensor data_offset, + const Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + Tensor grad_offset, Tensor grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * + kernel_w * deformable_group; + const int channel_per_deformable_group = 
+ channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, + at::cuda::getCurrentCUDAStream()>>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, + channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, batch_size, + 2 * kernel_h * kernel_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_, grad_mask_); + })); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu new file mode 100755 index 00000000..fd191ee9 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,351 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include +#include +#include + +#include +#include + +#include "ms_deform_attn_cuda_kernel.cuh" + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, + const int num_heads, const int channels, + const int num_levels, const int num_query, + const int num_point, scalar_t *data_col) { + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = THREADS_PER_BLOCK; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, + data_sampling_loc, data_attn_weight, batch_size, spatial_size, + num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +template +void ms_deformable_col2im_cuda( + cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + const int num_threads = 
+ (channels > THREADS_PER_BLOCK) ? THREADS_PER_BLOCK : channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > THREADS_PER_BLOCK) { + if ((channels & THREADS_PER_BLOCK - 1) == 0) { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, + num_query, num_point, grad_value, grad_sampling_loc, + grad_attn_weight); + } else { + ms_deformable_col2im_gpu_kernel_gm + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + } + } else { + switch (channels) { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + 
data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + default: + if (channels < 64) { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, + num_query, num_point, grad_value, grad_sampling_loc, + grad_attn_weight); + } else { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, + num_query, num_point, grad_value, grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) { + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), + "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), + "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), + "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), + "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.is_cuda(), + "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", + batch, im2col_step_); + + auto output = + at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view( + {batch / im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * 
channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch / im2col_step_; ++n) { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES( + value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda( + at::cuda::getCurrentCUDAStream(), + value.data_ptr() + n * im2col_step_ * per_value_size, + spatial_shapes.data_ptr(), + level_start_index.data_ptr(), + sampling_loc.data_ptr() + + n * im2col_step_ * per_sample_loc_size, + attn_weight.data_ptr() + + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, + num_point, columns.data_ptr()); + })); + } + + output = output.view({batch, num_query, num_heads * channels}); + + return output; +} + +void ms_deform_attn_cuda_backward( + const at::Tensor &value, const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, const at::Tensor &grad_output, + at::Tensor &grad_value, at::Tensor &grad_sampling_loc, + at::Tensor &grad_attn_weight, const int im2col_step) { + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), + "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), + "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), + "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), + "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), + "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.is_cuda(), + "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", + batch, im2col_step_); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view( + {batch / im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch / im2col_step_; ++n) { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES( + value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda( + at::cuda::getCurrentCUDAStream(), + grad_output_g.data_ptr(), + value.data_ptr() + n * im2col_step_ * per_value_size, + spatial_shapes.data_ptr(), + level_start_index.data_ptr(), + sampling_loc.data_ptr() + + n * im2col_step_ * per_sample_loc_size, + attn_weight.data_ptr() + + 
n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, + num_point, + grad_value.data_ptr() + + n * im2col_step_ * per_value_size, + grad_sampling_loc.data_ptr() + + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data_ptr() + + n * im2col_step_ * per_attn_weight_size); + })); + } +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_align_cuda.cu b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_align_cuda.cu new file mode 100755 index 00000000..3d4f7614 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_align_cuda.cu @@ -0,0 +1,58 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cuda_helper.hpp" +#include "roi_align_cuda_kernel.cuh" + +void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + int output_size = output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_align_forward_cuda_kernel", [&] { + roi_align_forward_cuda_kernel + <<>>( + output_size, input.data_ptr(), + rois.data_ptr(), output.data_ptr(), + argmax_y.data_ptr(), argmax_x.data_ptr(), + aligned_height, aligned_width, + static_cast(spatial_scale), sampling_ratio, pool_mode, + aligned, channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void ROIAlignBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor argmax_y, Tensor argmax_x, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned) { + int output_size = grad_output.numel(); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + + at::cuda::CUDAGuard device_guard(grad_output.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "roi_align_backward_cuda_kernel", [&] { + roi_align_backward_cuda_kernel + <<>>( + output_size, grad_output.data_ptr(), + rois.data_ptr(), argmax_y.data_ptr(), + argmax_x.data_ptr(), grad_input.data_ptr(), + aligned_height, aligned_width, + static_cast(spatial_scale), sampling_ratio, pool_mode, + aligned, channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_pool_cuda.cu b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_pool_cuda.cu new file mode 100755 index 00000000..d9cdf305 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/roi_pool_cuda.cu @@ -0,0 +1,50 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "pytorch_cuda_helper.hpp" +#include "roi_pool_cuda_kernel.cuh" + +void ROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, + int pooled_width, float spatial_scale) { + int output_size = output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_pool_forward_cuda_kernel", [&] { + roi_pool_forward_cuda_kernel + <<>>( + output_size, input.data_ptr(), + rois.data_ptr(), output.data_ptr(), + argmax.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void ROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor argmax, Tensor grad_input, + int pooled_height, int pooled_width, + float spatial_scale) { + int output_size = grad_output.numel(); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + + at::cuda::CUDAGuard device_guard(grad_output.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "roi_pool_backward_cuda_kernel", [&] { + roi_pool_backward_cuda_kernel + <<>>( + output_size, grad_output.data_ptr(), + rois.data_ptr(), argmax.data_ptr(), + grad_input.data_ptr(), pooled_height, pooled_width, + channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/sync_bn_cuda.cu b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/sync_bn_cuda.cu new file mode 100755 index 00000000..657c8170 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/cuda/sync_bn_cuda.cu @@ -0,0 +1,110 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "pytorch_cuda_helper.hpp" +#include "sync_bn_cuda_kernel.cuh" + +void SyncBNForwardMeanCUDAKernelLauncher(const Tensor input, Tensor mean) { + int num = input.size(0); + int channels = input.size(1); + int spatial = input.size(2); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "sync_bn_forward_mean_cuda_kernel", [&] { + sync_bn_forward_mean_cuda_kernel + <<>>( + input.data_ptr(), mean.data_ptr(), num, + channels, spatial); + }); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void SyncBNForwardVarCUDAKernelLauncher(const Tensor input, const Tensor mean, + Tensor var) { + int num = input.size(0); + int channels = input.size(1); + int spatial = input.size(2); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "sync_bn_forward_mean_cuda_kernel", [&] { + sync_bn_forward_var_cuda_kernel + <<>>( + input.data_ptr(), mean.data_ptr(), + var.data_ptr(), num, channels, spatial); + }); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void SyncBNForwardOutputCUDAKernelLauncher( + const Tensor input, const Tensor mean, const Tensor var, + Tensor running_mean, Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps, + float momentum, int group_size) { + int num = input.size(0); + int channels = input.size(1); + int spatial = input.size(2); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "sync_bn_forward_mean_cuda_kernel", [&] { + sync_bn_forward_output_cuda_kernel + <<>>( + input.data_ptr(), mean.data_ptr(), + var.data_ptr(), running_mean.data_ptr(), + running_var.data_ptr(), weight.data_ptr(), + bias.data_ptr(), norm.data_ptr(), + std.data_ptr(), output.data_ptr(), num, + channels, spatial, eps, momentum, group_size); + }); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void SyncBNBackwardParamCUDAKernelLauncher(const Tensor grad_output, + const Tensor norm, + Tensor grad_weight, + Tensor grad_bias) { + int num = grad_output.size(0); + int channels = grad_output.size(1); + int spatial = grad_output.size(2); + + at::cuda::CUDAGuard device_guard(grad_output.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "sync_bn_backward_param_cuda_kernel", [&] { + sync_bn_backward_param_cuda_kernel + <<>>( + grad_output.data_ptr(), norm.data_ptr(), + grad_weight.data_ptr(), grad_bias.data_ptr(), num, + channels, spatial); + }); + AT_CUDA_CHECK(cudaGetLastError()); +} + +void SyncBNBackwardDataCUDAKernelLauncher(const Tensor grad_output, + const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, + const Tensor norm, const Tensor std, + Tensor grad_input) { + int output_size = grad_input.numel(); + int num = grad_input.size(0); + int channels = grad_input.size(1); + int spatial = grad_input.size(2); + + at::cuda::CUDAGuard device_guard(grad_input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "sync_bn_backward_data_cuda_kernel", [&] { + sync_bn_backward_data_cuda_kernel + <<>>( + output_size, grad_output.data_ptr(), + weight.data_ptr(), grad_weight.data_ptr(), + grad_bias.data_ptr(), 
norm.data_ptr(), + std.data_ptr(), grad_input.data_ptr(), num, + channels, spatial); + }); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_conv.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_conv.cpp new file mode 100755 index 00000000..6ee4895a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_conv.cpp @@ -0,0 +1,517 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void deformable_im2col_impl(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col) { + DISPATCH_DEVICE_IMPL(deformable_im2col_impl, data_im, data_offset, channels, + height, width, ksize_h, ksize_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, parallel_imgs, + deformable_group, data_col); +} + +void deformable_col2im_impl(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im) { + DISPATCH_DEVICE_IMPL(deformable_col2im_impl, data_col, data_offset, channels, + height, width, ksize_h, ksize_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, parallel_imgs, + deformable_group, grad_im); +} + +void deformable_col2im_coord_impl( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset) { + DISPATCH_DEVICE_IMPL(deformable_col2im_coord_impl, data_col, data_im, + data_offset, channels, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + parallel_imgs, deformable_group, grad_offset); +} + +void deform_conv_shape_check(at::Tensor input, at::Tensor offset, + at::Tensor *gradOutput, at::Tensor weight, int kH, + int kW, int dH, int dW, int padH, int padW, + int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK( + weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", + kH, kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", + kH, kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, + dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + 
TORCH_CHECK(ndim == 3 || ndim == 4, + "3D or 4D input tensor expected but got: %s", ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK( + (offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK( + gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK( + (gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +void deform_conv_forward(Tensor input, Tensor weight, Tensor offset, + Tensor output, Tensor columns, Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef DBNET_CV_WITH_CUDA + CHECK_CUDA_INPUT(input); + CHECK_CUDA_INPUT(offset); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(output); + CHECK_CUDA_INPUT(columns); + CHECK_CUDA_INPUT(ones); +#else + AT_ERROR("DeformConv is not compiled with GPU support"); +#endif + } else { + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(weight); + CHECK_CPU_INPUT(output); + CHECK_CPU_INPUT(columns); + CHECK_CPU_INPUT(ones); + } + + deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + 
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col_impl(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } +} + +void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput, + Tensor gradInput, Tensor gradOffset, + Tensor weight, Tensor columns, int kW, int kH, + int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef DBNET_CV_WITH_CUDA + CHECK_CUDA_INPUT(input); + CHECK_CUDA_INPUT(offset); + CHECK_CUDA_INPUT(gradOutput); + CHECK_CUDA_INPUT(gradInput); + CHECK_CUDA_INPUT(gradOffset); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(columns); +#else + AT_ERROR("DeformConv is not compiled with GPU support"); +#endif + } else { + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(gradOutput); + CHECK_CPU_INPUT(gradInput); + CHECK_CPU_INPUT(gradOffset); + CHECK_CPU_INPUT(weight); + CHECK_CPU_INPUT(columns); + } + deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, + padH, padW, dilationH, dilationW, group, + deformable_group); + + at::DeviceGuard guard(input.device()); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = 
input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord_impl(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, + dH, dW, dilationH, dilationW, im2col_step, + deformable_group, gradOffset[elt]); + + deformable_col2im_impl(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, + gradInput[elt]); + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = 
gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } +} + +void deform_conv_backward_parameters(Tensor input, Tensor offset, + Tensor gradOutput, Tensor gradWeight, + Tensor columns, Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, float scale, + int im2col_step) { + if (input.device().is_cuda()) { +#ifdef DBNET_CV_WITH_CUDA + CHECK_CUDA_INPUT(input); + CHECK_CUDA_INPUT(offset); + CHECK_CUDA_INPUT(gradOutput); + CHECK_CUDA_INPUT(gradWeight); + CHECK_CUDA_INPUT(columns); + CHECK_CUDA_INPUT(ones); +#else + AT_ERROR("DeformConv is not compiled with GPU support"); +#endif + } else { + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(gradOutput); + CHECK_CPU_INPUT(gradWeight); + CHECK_CPU_INPUT(columns); + CHECK_CPU_INPUT(ones); + } + + deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, + dW, padH, padW, dilationH, dilationW, group, + deformable_group); + at::DeviceGuard guard(input.device()); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer = gradOutputBuffer.contiguous(); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col_impl(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + 
gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_roi_pool.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_roi_pool.cpp new file mode 100755 index 00000000..4fb78a96 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/deform_roi_pool.cpp @@ -0,0 +1,42 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void deform_roi_pool_forward_impl(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + DISPATCH_DEVICE_IMPL(deform_roi_pool_forward_impl, input, rois, offset, + output, pooled_height, pooled_width, spatial_scale, + sampling_ratio, gamma); +} + +void deform_roi_pool_backward_impl(Tensor grad_output, Tensor input, + Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, + int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma) { + DISPATCH_DEVICE_IMPL(deform_roi_pool_backward_impl, grad_output, input, rois, + offset, grad_input, grad_offset, pooled_height, + pooled_width, spatial_scale, sampling_ratio, gamma); +} + +void deform_roi_pool_forward(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma) { + deform_roi_pool_forward_impl(input, rois, offset, output, pooled_height, + pooled_width, spatial_scale, sampling_ratio, + gamma); +} + +void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois, + Tensor offset, Tensor grad_input, + Tensor grad_offset, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + deform_roi_pool_backward_impl(grad_output, input, rois, offset, grad_input, + grad_offset, pooled_height, pooled_width, + spatial_scale, sampling_ratio, gamma); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/info.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/info.cpp new file mode 100755 index 00000000..4c53598c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/info.cpp @@ -0,0 +1,56 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +// modified from +// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/vision.cpp +#include "pytorch_cpp_helper.hpp" + +#ifdef DBNET_CV_WITH_CUDA +#ifndef HIP_DIFF +#include +int get_cudart_version() { return CUDART_VERSION; } +#endif +#endif + +std::string get_compiling_cuda_version() { +#ifdef DBNET_CV_WITH_CUDA +#ifndef HIP_DIFF + std::ostringstream oss; + // copied from + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 + auto printCudaStyleVersion = [&](int v) { + oss << (v / 1000) << "." << (v / 10 % 100); + if (v % 10 != 0) { + oss << "." << (v % 10); + } + }; + printCudaStyleVersion(get_cudart_version()); + return oss.str(); +#else + return std::string("rocm not available"); +#endif +#else + return std::string("not available"); +#endif +} + +// similar to +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp +std::string get_compiler_version() { + std::ostringstream ss; +#if defined(__GNUC__) +#ifndef __clang__ + { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } +#endif +#endif + +#if defined(__clang_major__) + { + ss << "clang " << __clang_major__ << "." << __clang_minor__ << "." + << __clang_patchlevel__; + } +#endif + +#if defined(_MSC_VER) + { ss << "MSVC " << _MSC_FULL_VER; } +#endif + return ss.str(); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/modulated_deform_conv.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/modulated_deform_conv.cpp new file mode 100755 index 00000000..12b538a0 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/modulated_deform_conv.cpp @@ -0,0 +1,237 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void modulated_deformable_im2col_impl( + const Tensor data_im, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor data_col) { + DISPATCH_DEVICE_IMPL(modulated_deformable_im2col_impl, data_im, data_offset, + data_mask, batch_size, channels, height_im, width_im, + height_col, width_col, kernel_h, kernel_w, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, + deformable_group, data_col); +} + +void modulated_deformable_col2im_impl( + const Tensor data_col, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor grad_im) { + DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_impl, data_col, data_offset, + data_mask, batch_size, channels, height_im, width_im, + height_col, width_col, kernel_h, kernel_w, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, + deformable_group, grad_im); +} + +void modulated_deformable_col2im_coord_impl( + const Tensor data_col, const Tensor data_im, const Tensor data_offset, + const Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const 
int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + Tensor grad_offset, Tensor grad_mask) { + DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_coord_impl, data_col, + data_im, data_offset, data_mask, batch_size, channels, + height_im, width_im, height_col, width_col, kernel_h, + kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, + dilation_w, deformable_group, grad_offset, grad_mask); +} + +void modulated_deform_conv_forward( + Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, + Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, const int group, + const int deformable_group, const bool with_bias) { + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... 
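+    // Note: the forward pass itself never reads `ones`; the bias term (when
+    // with_bias is set) is added by broadcasting further below. `ones` is only
+    // kept correctly sized here so the backward pass can reuse it when
+    // accumulating the bias gradient.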
+ ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_impl( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_backward( + Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, + Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight, + Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... 
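+    // `ones` must cover one output plane (height_out * width_out): the
+    // per-group bias gradient below is grad_output[b][g].flatten(1) multiplied
+    // by ones.view({-1, 1}), i.e. a sum of grad_output over all output
+    // positions for each output channel.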
+ ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_impl( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_impl( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_impl( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/ms_deform_attn.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/ms_deform_attn.cpp new file mode 100755 index 00000000..25c8f620 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/ms_deform_attn.cpp @@ -0,0 +1,60 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +Tensor ms_deform_attn_impl_forward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step) { + return DISPATCH_DEVICE_IMPL(ms_deform_attn_impl_forward, value, + spatial_shapes, level_start_index, sampling_loc, + attn_weight, im2col_step); +} + +void ms_deform_attn_impl_backward( + const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, + const int im2col_step) { + DISPATCH_DEVICE_IMPL(ms_deform_attn_impl_backward, value, spatial_shapes, + level_start_index, sampling_loc, attn_weight, + grad_output, grad_value, grad_sampling_loc, + grad_attn_weight, im2col_step); +} + +Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step) { + at::DeviceGuard guard(value.device()); + return ms_deform_attn_impl_forward(value, spatial_shapes, level_start_index, + sampling_loc, attn_weight, im2col_step); +} + +void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, + Tensor &grad_attn_weight, const int im2col_step) { + at::DeviceGuard guard(value.device()); + ms_deform_attn_impl_backward(value, spatial_shapes, level_start_index, + sampling_loc, attn_weight, grad_output, + grad_value, grad_sampling_loc, grad_attn_weight, + im2col_step); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/pybind.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/pybind.cpp new file mode 100755 index 00000000..031797d7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/pybind.cpp @@ -0,0 +1,218 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include + +#include "pytorch_cpp_helper.hpp" + +std::string get_compiler_version(); +std::string get_compiling_cuda_version(); + + + +void deform_conv_forward(Tensor input, Tensor weight, Tensor offset, + Tensor output, Tensor columns, Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput, + Tensor gradInput, Tensor gradOffset, + Tensor weight, Tensor columns, int kW, int kH, + int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +void deform_conv_backward_parameters(Tensor input, Tensor offset, + Tensor gradOutput, Tensor gradWeight, + Tensor columns, Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, float scale, + int im2col_step); + +void deform_roi_pool_forward(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma); + +void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois, + Tensor offset, Tensor grad_input, + Tensor grad_offset, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma); + + + +void modulated_deform_conv_forward( + Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, + Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, const int group, + const int deformable_group, const bool with_bias); + +void modulated_deform_conv_backward( + Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, + Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight, + Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); + +Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, const int im2col_step); + +void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, + Tensor &grad_attn_weight, const int im2col_step); + + + +void roi_align_forward(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, bool aligned); + +void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, bool aligned); + +void roi_pool_forward(Tensor input, Tensor rois, Tensor output, Tensor argmax, + int pooled_height, int pooled_width, float spatial_scale); + +void roi_pool_backward(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, int pooled_width, + float spatial_scale); + +void sync_bn_forward_mean(const Tensor input, Tensor mean); 
+ +void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var); + +void sync_bn_forward_output(const Tensor input, const Tensor mean, + const Tensor var, const Tensor weight, + const Tensor bias, Tensor running_mean, + Tensor running_var, Tensor norm, Tensor std, + Tensor output, float eps, float momentum, + int group_size); + +void sync_bn_backward_param(const Tensor grad_output, const Tensor norm, + Tensor grad_weight, Tensor grad_bias); + +void sync_bn_backward_data(const Tensor grad_output, const Tensor weight, + const Tensor grad_weight, const Tensor grad_bias, + const Tensor norm, const Tensor std, + Tensor grad_input); + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + + m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); + m.def("get_compiling_cuda_version", &get_compiling_cuda_version, + "get_compiling_cuda_version"); + + m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward", + py::arg("input"), py::arg("weight"), py::arg("offset"), + py::arg("output"), py::arg("columns"), py::arg("ones"), py::arg("kW"), + py::arg("kH"), py::arg("dW"), py::arg("dH"), py::arg("padW"), + py::arg("padH"), py::arg("dilationW"), py::arg("dilationH"), + py::arg("group"), py::arg("deformable_group"), py::arg("im2col_step")); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input", py::arg("input"), py::arg("offset"), + py::arg("gradOutput"), py::arg("gradInput"), py::arg("gradOffset"), + py::arg("weight"), py::arg("columns"), py::arg("kW"), py::arg("kH"), + py::arg("dW"), py::arg("dH"), py::arg("padW"), py::arg("padH"), + py::arg("dilationW"), py::arg("dilationH"), py::arg("group"), + py::arg("deformable_group"), py::arg("im2col_step")); + m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, + "deform_conv_backward_parameters", py::arg("input"), py::arg("offset"), + py::arg("gradOutput"), py::arg("gradWeight"), py::arg("columns"), + py::arg("ones"), py::arg("kW"), py::arg("kH"), py::arg("dW"), + py::arg("dH"), py::arg("padW"), py::arg("padH"), py::arg("dilationW"), + py::arg("dilationH"), py::arg("group"), py::arg("deformable_group"), + py::arg("scale"), py::arg("im2col_step")); + m.def("deform_roi_pool_forward", &deform_roi_pool_forward, + "deform roi pool forward", py::arg("input"), py::arg("rois"), + py::arg("offset"), py::arg("output"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("gamma")); + m.def("deform_roi_pool_backward", &deform_roi_pool_backward, + "deform roi pool backward", py::arg("grad_output"), py::arg("input"), + py::arg("rois"), py::arg("offset"), py::arg("grad_input"), + py::arg("grad_offset"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("gamma")); + + m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, + "modulated deform conv forward", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), + py::arg("output"), py::arg("columns"), py::arg("kernel_h"), + py::arg("kernel_w"), py::arg("stride_h"), py::arg("stride_w"), + py::arg("pad_h"), py::arg("pad_w"), py::arg("dilation_h"), + py::arg("dilation_w"), py::arg("group"), py::arg("deformable_group"), + py::arg("with_bias")); + m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, + "modulated deform conv backward", py::arg("input"), py::arg("weight"), + py::arg("bias"), 
py::arg("ones"), py::arg("offset"), py::arg("mask"), + py::arg("columns"), py::arg("grad_input"), py::arg("grad_weight"), + py::arg("grad_bias"), py::arg("grad_offset"), py::arg("grad_mask"), + py::arg("grad_output"), py::arg("kernel_h"), py::arg("kernel_w"), + py::arg("stride_h"), py::arg("stride_w"), py::arg("pad_h"), + py::arg("pad_w"), py::arg("dilation_h"), py::arg("dilation_w"), + py::arg("group"), py::arg("deformable_group"), py::arg("with_bias")); + + m.def("roi_align_forward", &roi_align_forward, "roi_align forward", + py::arg("input"), py::arg("rois"), py::arg("output"), + py::arg("argmax_y"), py::arg("argmax_x"), py::arg("aligned_height"), + py::arg("aligned_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); + m.def("roi_align_backward", &roi_align_backward, "roi_align backward", + py::arg("grad_output"), py::arg("rois"), py::arg("argmax_y"), + py::arg("argmax_x"), py::arg("grad_input"), py::arg("aligned_height"), + py::arg("aligned_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); + m.def("roi_pool_forward", &roi_pool_forward, "roi_pool forward", + py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("argmax"), + py::arg("pooled_height"), py::arg("pooled_width"), + py::arg("spatial_scale")); + m.def("roi_pool_backward", &roi_pool_backward, "roi_pool backward", + py::arg("grad_output"), py::arg("rois"), py::arg("argmax"), + py::arg("grad_input"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale")); + m.def("sync_bn_forward_mean", &sync_bn_forward_mean, "sync_bn forward_mean", + py::arg("input"), py::arg("mean")); + m.def("sync_bn_forward_var", &sync_bn_forward_var, "sync_bn forward_var", + py::arg("input"), py::arg("mean"), py::arg("var")); + m.def("sync_bn_forward_output", &sync_bn_forward_output, + "sync_bn forward_output", py::arg("input"), py::arg("mean"), + py::arg("var"), py::arg("weight"), py::arg("bias"), + py::arg("running_mean"), py::arg("running_var"), py::arg("norm"), + py::arg("std"), py::arg("output"), py::arg("eps"), py::arg("momentum"), + py::arg("group_size")); + m.def("sync_bn_backward_param", &sync_bn_backward_param, + "sync_bn backward_param", py::arg("grad_output"), py::arg("norm"), + py::arg("grad_weight"), py::arg("grad_bias")); + m.def("sync_bn_backward_data", &sync_bn_backward_data, + "sync_bn backward_data", py::arg("grad_output"), py::arg("weight"), + py::arg("grad_weight"), py::arg("grad_bias"), py::arg("norm"), + py::arg("std"), py::arg("grad_input")); + + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, + "forward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("im2col_step")); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, + "backward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("grad_output"), + py::arg("grad_value"), py::arg("grad_sampling_loc"), + py::arg("grad_attn_weight"), py::arg("im2col_step")); + +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_align.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_align.cpp new file mode 100755 index 00000000..6e707739 --- /dev/null +++ 
b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_align.cpp @@ -0,0 +1,41 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y, + argmax_x, aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); +} + +void roi_align_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + DISPATCH_DEVICE_IMPL(roi_align_backward_impl, grad_output, rois, argmax_y, + argmax_x, grad_input, aligned_height, aligned_width, + spatial_scale, sampling_ratio, pool_mode, aligned); +} + +void roi_align_forward(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, bool aligned) { + roi_align_forward_impl(input, rois, output, argmax_y, argmax_x, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); +} + +void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, bool aligned) { + roi_align_backward_impl(grad_output, rois, argmax_y, argmax_x, grad_input, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_pool.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_pool.cpp new file mode 100755 index 00000000..bba90b80 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/roi_pool.cpp @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, int pooled_width, + float spatial_scale) { + DISPATCH_DEVICE_IMPL(roi_pool_forward_impl, input, rois, output, argmax, + pooled_height, pooled_width, spatial_scale); +} + +void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale) { + DISPATCH_DEVICE_IMPL(roi_pool_backward_impl, grad_output, rois, argmax, + grad_input, pooled_height, pooled_width, spatial_scale); +} + +void roi_pool_forward(Tensor input, Tensor rois, Tensor output, Tensor argmax, + int pooled_height, int pooled_width, + float spatial_scale) { + roi_pool_forward_impl(input, rois, output, argmax, pooled_height, + pooled_width, spatial_scale); +} + +void roi_pool_backward(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, int pooled_width, + float spatial_scale) { + roi_pool_backward_impl(grad_output, rois, argmax, grad_input, pooled_height, + pooled_width, spatial_scale); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/sync_bn.cpp b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/sync_bn.cpp new file mode 100755 index 00000000..fd5a5132 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/csrc/pytorch/sync_bn.cpp @@ -0,0 +1,69 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void sync_bn_forward_mean_impl(const Tensor input, Tensor mean) { + DISPATCH_DEVICE_IMPL(sync_bn_forward_mean_impl, input, mean); +} + +void sync_bn_forward_var_impl(const Tensor input, const Tensor mean, + Tensor var) { + DISPATCH_DEVICE_IMPL(sync_bn_forward_var_impl, input, mean, var); +} + +void sync_bn_forward_output_impl(const Tensor input, const Tensor mean, + const Tensor var, Tensor running_mean, + Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, + Tensor output, float eps, float momentum, + int group_size) { + DISPATCH_DEVICE_IMPL(sync_bn_forward_output_impl, input, mean, var, + running_mean, running_var, weight, bias, norm, std, + output, eps, momentum, group_size); +} + +void sync_bn_backward_param_impl(const Tensor grad_output, const Tensor norm, + Tensor grad_weight, Tensor grad_bias) { + DISPATCH_DEVICE_IMPL(sync_bn_backward_param_impl, grad_output, norm, + grad_weight, grad_bias); +} + +void sync_bn_backward_data_impl(const Tensor grad_output, const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, const Tensor norm, + const Tensor std, Tensor grad_input) { + DISPATCH_DEVICE_IMPL(sync_bn_backward_data_impl, grad_output, weight, + grad_weight, grad_bias, norm, std, grad_input); +} + +void sync_bn_forward_mean(const Tensor input, Tensor mean) { + sync_bn_forward_mean_impl(input, mean); +} + +void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var) { + sync_bn_forward_var_impl(input, mean, var); +} + +void sync_bn_forward_output(const Tensor input, const Tensor mean, + const Tensor var, const Tensor weight, + const Tensor bias, Tensor running_mean, + Tensor running_var, Tensor norm, Tensor std, + Tensor output, float eps, float momentum, + int group_size) { + sync_bn_forward_output_impl(input, mean, var, running_mean, running_var, + weight, bias, norm, std, output, eps, momentum, + group_size); +} + +void sync_bn_backward_param(const Tensor 
grad_output, const Tensor norm, + Tensor grad_weight, Tensor grad_bias) { + sync_bn_backward_param_impl(grad_output, norm, grad_weight, grad_bias); +} + +void sync_bn_backward_data(const Tensor grad_output, const Tensor weight, + const Tensor grad_weight, const Tensor grad_bias, + const Tensor norm, const Tensor std, + Tensor grad_input) { + sync_bn_backward_data_impl(grad_output, weight, grad_weight, grad_bias, norm, + std, grad_input); +} diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_conv.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_conv.py new file mode 100755 index 00000000..5d3cf655 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_conv.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair, _single + +from dbnet_cv.utils import deprecated_api_warning +from ..cnn import CONV_LAYERS +from ..utils import ext_loader, print_log + +ext_module = ext_loader.load_ext('_ext', [ + 'deform_conv_forward', 'deform_conv_backward_input', + 'deform_conv_backward_parameters' +]) + + +class DeformConv2dFunction(Function): + + @staticmethod + def symbolic(g, + input, + offset, + weight, + stride, + padding, + dilation, + groups, + deform_groups, + bias=False, + im2col_step=32): + return g.op( + 'dbnet_cv::DBNET_CVDeformConv2d', + input, + offset, + weight, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups, + bias_i=bias, + im2col_step_i=im2col_step) + + @staticmethod + def forward(ctx, + input: Tensor, + offset: Tensor, + weight: Tensor, + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + deform_groups: int = 1, + bias: bool = False, + im2col_step: int = 32) -> Tensor: + if input is not None and input.dim() != 4: + raise ValueError( + f'Expected 4D tensor as input, got {input.dim()}D tensor \ + instead.') + assert bias is False, 'Only support bias is False.' + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deform_groups = deform_groups + ctx.im2col_step = im2col_step + + # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; + # amp won't cast the type of model (float32), but "offset" is cast + # to float16 by nn.Conv2d automatically, leading to the type + # mismatch with input (when it is float32) or weight. + # The flag for whether to use fp16 or amp is the type of "offset", + # we cast weight and input to temporarily support fp16 and amp + # whatever the pytorch version is. 
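As a concrete illustration of the comment above (not part of the patch): under `torch.cuda.amp`, the offset produced by an `nn.Conv2d` can be float16 while the raw input and the module weights are still float32, so the casts that follow keep every operand in one dtype before the extension kernel is called. A minimal sketch with assumed shapes:

```python
import torch

offset = torch.randn(1, 18, 8, 8, dtype=torch.float16)   # as produced under amp
x = torch.randn(1, 16, 8, 8)                              # still float32
weight = torch.randn(32, 16, 3, 3)                        # still float32

x = x.type_as(offset)        # -> float16, matching offset
weight = weight.type_as(x)   # -> float16, so the kernel sees consistent dtypes
```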
+ input = input.type_as(offset) + weight = weight.type_as(input) + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + DeformConv2dFunction._output_size(ctx, input, weight)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + cur_im2col_step = min(ctx.im2col_step, input.size(0)) + assert (input.size(0) % cur_im2col_step + ) == 0, 'batch size must be divisible by im2col_step' + ext_module.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + im2col_step=cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward( + ctx, grad_output: Tensor + ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None, + None, None, None, None, None, None]: + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + cur_im2col_step = min(ctx.im2col_step, input.size(0)) + assert (input.size(0) % cur_im2col_step + ) == 0, 'batch size must be divisible by im2col_step' + + grad_output = grad_output.contiguous() + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + ext_module.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + im2col_step=cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + ext_module.deform_conv_backward_parameters( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + scale=1, + im2col_step=cur_im2col_step) + + return grad_input, grad_offset, grad_weight, \ + None, None, None, None, None, None, None + + @staticmethod + def _output_size(ctx, input, weight): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = ctx.padding[d] + kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = ctx.stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + 'convolution input is too small (output would be ' + + 'x'.join(map(str, output_size)) + ')') + return output_size + + +deform_conv2d = DeformConv2dFunction.apply + + +class DeformConv2d(nn.Module): + r"""Deformable 2D convolution. + + Applies a deformable 2D convolution over an input signal composed of + several input planes. DeformConv2d was described in the paper + `Deformable Convolutional Networks + `_ + + Note: + The argument ``im2col_step`` was added in version 1.3.17, which means + number of samples processed by the ``im2col_cuda_kernel`` per call. + It enables users to define ``batch_size`` and ``im2col_step`` more + flexibly and solved `issue dbnet_cv#1440 + `_. 
+ + Args: + in_channels (int): Number of channels in the input image. + out_channels (int): Number of channels produced by the convolution. + kernel_size(int, tuple): Size of the convolving kernel. + stride(int, tuple): Stride of the convolution. Default: 1. + padding (int or tuple): Zero-padding added to both sides of the input. + Default: 0. + dilation (int or tuple): Spacing between kernel elements. Default: 1. + groups (int): Number of blocked connections from input. + channels to output channels. Default: 1. + deform_groups (int): Number of deformable group partitions. + bias (bool): If True, adds a learnable bias to the output. + Default: False. + im2col_step (int): Number of samples processed by im2col_cuda_kernel + per call. It will work when ``batch_size`` > ``im2col_step``, but + ``batch_size`` must be divisible by ``im2col_step``. Default: 32. + `New in version 1.3.17.` + """ + + @deprecated_api_warning({'deformable_groups': 'deform_groups'}, + cls_name='DeformConv2d') + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + deform_groups: int = 1, + bias: bool = False, + im2col_step: int = 32) -> None: + super().__init__() + + assert not bias, \ + f'bias={bias} is not supported in DeformConv2d.' + assert in_channels % groups == 0, \ + f'in_channels {in_channels} cannot be divisible by groups {groups}' + assert out_channels % groups == 0, \ + f'out_channels {out_channels} cannot be divisible by groups \ + {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + self.im2col_step = im2col_step + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + # only weight, no bias + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, + *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + # switch the initialization of `self.weight` to the standard kaiming + # method described in `Delving deep into rectifiers: Surpassing + # human-level performance on ImageNet classification` - He, K. et al. + # (2015), using a uniform distribution + nn.init.kaiming_uniform_(self.weight, nonlinearity='relu') + + def forward(self, x: Tensor, offset: Tensor) -> Tensor: + """Deformable Convolutional forward function. + + Args: + x (Tensor): Input feature, shape (B, C_in, H_in, W_in) + offset (Tensor): Offset for deformable convolution, shape + (B, deform_groups*kernel_size[0]*kernel_size[1]*2, + H_out, W_out), H_out, W_out are equal to the output's. + + An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. + The spatial arrangement is like: + + .. code:: text + + (x0, y0) (x1, y1) (x2, y2) + (x3, y3) (x4, y4) (x5, y5) + (x6, y6) (x7, y7) (x8, y8) + + Returns: + Tensor: Output of the layer. 
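To make the documented shapes concrete, a minimal usage sketch (illustrative only; it assumes the compiled `_ext` extension is available on a CUDA device and that `DeformConv2d` is re-exported from `dbnet_cv.ops` as in upstream mmcv):

```python
import torch
from dbnet_cv.ops import DeformConv2d  # assumed re-export

conv = DeformConv2d(16, 32, kernel_size=3, padding=1).cuda()
x = torch.randn(2, 16, 64, 64).cuda()
# offset channels = deform_groups * 2 * kH * kW = 1 * 2 * 3 * 3 = 18
offset = torch.randn(2, 18, 64, 64).cuda()
out = conv(x, offset)          # -> (2, 32, 64, 64)
```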
+ """ + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) < + self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0) + offset = offset.contiguous() + out = deform_conv2d(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deform_groups, + False, self.im2col_step) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - + pad_w].contiguous() + return out + + def __repr__(self): + s = self.__class__.__name__ + s += f'(in_channels={self.in_channels},\n' + s += f'out_channels={self.out_channels},\n' + s += f'kernel_size={self.kernel_size},\n' + s += f'stride={self.stride},\n' + s += f'padding={self.padding},\n' + s += f'dilation={self.dilation},\n' + s += f'groups={self.groups},\n' + s += f'deform_groups={self.deform_groups},\n' + # bias is not supported in DeformConv2d. + s += 'bias=False)' + return s + + +@CONV_LAYERS.register_module('DCN') +class DeformConv2dPack(DeformConv2d): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. + The spatial arrangement is like: + + .. code:: text + + (x0, y0) (x1, y1) (x2, y2) + (x3, y3) (x4, y4) (x5, y5) + (x6, y6) (x7, y7) (x8, y8) + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x: Tensor) -> Tensor: # type: ignore + offset = self.conv_offset(x) + return deform_conv2d(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deform_groups, + False, self.im2col_step) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, DeformConvPack loads previous benchmark models. 
+ if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + print_log( + f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to ' + 'version 2.', + logger='root') + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_roi_pool.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_roi_pool.py new file mode 100755 index 00000000..321b4303 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deform_roi_pool.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward']) + + +class DeformRoIPoolFunction(Function): + + @staticmethod + def symbolic(g, input, rois, offset, output_size, spatial_scale, + sampling_ratio, gamma): + return g.op( + 'dbnet_cv::DBNET_CVDeformRoIPool', + input, + rois, + offset, + pooled_height_i=output_size[0], + pooled_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_f=sampling_ratio, + gamma_f=gamma) + + @staticmethod + def forward(ctx, + input: Tensor, + rois: Tensor, + offset: Optional[Tensor], + output_size: Tuple[int, ...], + spatial_scale: float = 1.0, + sampling_ratio: int = 0, + gamma: float = 0.1) -> Tensor: + if offset is None: + offset = input.new_zeros(0) + ctx.output_size = _pair(output_size) + ctx.spatial_scale = float(spatial_scale) + ctx.sampling_ratio = int(sampling_ratio) + ctx.gamma = float(gamma) + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' 
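The assertion above encodes the usual RoI convention; for illustration, a valid `rois` tensor looks like this (the values are made up):

```python
import torch

# each row: (batch index into `input`, x1, y1, x2, y2) in input coordinates
rois = torch.tensor([
    [0., 10., 20., 50., 60.],
    [0., 30., 30., 80., 90.],
    [1.,  5.,  5., 40., 45.],
])
assert rois.size(1) == 5
```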
+ + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + + ext_module.deform_roi_pool_forward( + input, + rois, + offset, + output, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + gamma=ctx.gamma) + + ctx.save_for_backward(input, rois, offset) + return output + + @staticmethod + @once_differentiable + def backward( + ctx, grad_output: Tensor + ) -> Tuple[Tensor, None, Tensor, None, None, None, None]: + input, rois, offset = ctx.saved_tensors + grad_input = grad_output.new_zeros(input.shape) + grad_offset = grad_output.new_zeros(offset.shape) + + ext_module.deform_roi_pool_backward( + grad_output, + input, + rois, + offset, + grad_input, + grad_offset, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + gamma=ctx.gamma) + if grad_offset.numel() == 0: + grad_offset = None + return grad_input, None, grad_offset, None, None, None, None + + +deform_roi_pool = DeformRoIPoolFunction.apply + + +class DeformRoIPool(nn.Module): + + def __init__(self, + output_size: Tuple[int, ...], + spatial_scale: float = 1.0, + sampling_ratio: int = 0, + gamma: float = 0.1): + super().__init__() + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + self.sampling_ratio = int(sampling_ratio) + self.gamma = float(gamma) + + def forward(self, + input: Tensor, + rois: Tensor, + offset: Optional[Tensor] = None) -> Tensor: + return deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + + +class DeformRoIPoolPack(DeformRoIPool): + + def __init__(self, + output_size: Tuple[int, ...], + output_channels: int, + deform_fc_channels: int = 1024, + spatial_scale: float = 1.0, + sampling_ratio: int = 0, + gamma: float = 0.1): + super().__init__(output_size, spatial_scale, sampling_ratio, gamma) + + self.output_channels = output_channels + self.deform_fc_channels = deform_fc_channels + + self.offset_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + + def forward(self, input: Tensor, rois: Tensor) -> Tensor: # type: ignore + assert input.size(1) == self.output_channels + x = deform_roi_pool(input, rois, None, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + rois_num = rois.size(0) + offset = self.offset_fc(x.view(rois_num, -1)) + offset = offset.view(rois_num, 2, self.output_size[0], + self.output_size[1]) + return deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + + +class ModulatedDeformRoIPoolPack(DeformRoIPool): + + def __init__(self, + output_size: Tuple[int, ...], + output_channels: int, + deform_fc_channels: int = 1024, + spatial_scale: float = 1.0, + sampling_ratio: int = 0, + gamma: float = 0.1): + super().__init__(output_size, spatial_scale, sampling_ratio, gamma) + + self.output_channels = output_channels + self.deform_fc_channels = deform_fc_channels + + self.offset_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * 
self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + + self.mask_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 1), + nn.Sigmoid()) + self.mask_fc[2].weight.data.zero_() + self.mask_fc[2].bias.data.zero_() + + def forward(self, input: Tensor, rois: Tensor) -> Tensor: # type: ignore + assert input.size(1) == self.output_channels + x = deform_roi_pool(input, rois, None, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + rois_num = rois.size(0) + offset = self.offset_fc(x.view(rois_num, -1)) + offset = offset.view(rois_num, 2, self.output_size[0], + self.output_size[1]) + mask = self.mask_fc(x.view(rois_num, -1)) + mask = mask.view(rois_num, 1, self.output_size[0], self.output_size[1]) + d = deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + return d * mask diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deprecated_wrappers.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deprecated_wrappers.py new file mode 100755 index 00000000..37b8031b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/deprecated_wrappers.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This file is for backward compatibility. +# Module wrappers for empty tensor have been moved to dbnet_cv.cnn.bricks. +import warnings + +from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d + + +class Conv2d_deprecated(Conv2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing Conv2d wrapper from "dbnet_cv.ops" will be deprecated in' + ' the future. Please import them from "dbnet_cv.cnn" instead', + DeprecationWarning) + + +class ConvTranspose2d_deprecated(ConvTranspose2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing ConvTranspose2d wrapper from "dbnet_cv.ops" will be ' + 'deprecated in the future. Please import them from "dbnet_cv.cnn" ' + 'instead', DeprecationWarning) + + +class MaxPool2d_deprecated(MaxPool2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing MaxPool2d wrapper from "dbnet_cv.ops" will be deprecated in' + ' the future. Please import them from "dbnet_cv.cnn" instead', + DeprecationWarning) + + +class Linear_deprecated(Linear): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing Linear wrapper from "dbnet_cv.ops" will be deprecated in' + ' the future. Please import them from "dbnet_cv.cnn" instead', + DeprecationWarning) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/info.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/info.py new file mode 100755 index 00000000..29f2e559 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/info.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
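The `info.py` module started here exposes build-environment helpers for the compiled extension; querying them might look like the following (illustrative; it assumes the `_ext` module has been built and imports straight from the module path added by this patch):

```python
from dbnet_cv.ops.info import get_compiler_version, get_compiling_cuda_version

print(get_compiler_version())         # e.g. 'GCC 9.4'
print(get_compiling_cuda_version())   # e.g. '11.3'
```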
+import glob +import os + +import torch + +if torch.__version__ == 'parrots': + import parrots + + def get_compiler_version(): + return 'GCC ' + parrots.version.compiler + + def get_compiling_cuda_version(): + return parrots.version.cuda +else: + from ..utils import ext_loader + ext_module = ext_loader.load_ext( + '_ext', ['get_compiler_version', 'get_compiling_cuda_version']) + + def get_compiler_version(): + return ext_module.get_compiler_version() + + def get_compiling_cuda_version(): + return ext_module.get_compiling_cuda_version() + + +def get_onnxruntime_op_path(): + wildcard = os.path.join( + os.path.abspath(os.path.dirname(os.path.dirname(__file__))), + '_ext_ort.*.so') + + paths = glob.glob(wildcard) + if len(paths) > 0: + return paths[0] + else: + return '' diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/modulated_deform_conv.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/modulated_deform_conv.py new file mode 100755 index 00000000..3a035627 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/modulated_deform_conv.py @@ -0,0 +1,283 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair, _single + +from dbnet_cv.utils import deprecated_api_warning +from ..cnn import CONV_LAYERS +from ..utils import ext_loader, print_log + +ext_module = ext_loader.load_ext( + '_ext', + ['modulated_deform_conv_forward', 'modulated_deform_conv_backward']) + + +class ModulatedDeformConv2dFunction(Function): + + @staticmethod + def symbolic(g, input, offset, mask, weight, bias, stride, padding, + dilation, groups, deform_groups): + input_tensors = [input, offset, mask, weight] + if bias is not None: + input_tensors.append(bias) + return g.op( + 'dbnet_cv::DBNET_CVModulatedDeformConv2d', + *input_tensors, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups) + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1): + if input is not None and input.dim() != 4: + raise ValueError( + f'Expected 4D tensor as input, got {input.dim()}D tensor \ + instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deform_groups = deform_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(0) # fake tensor + # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; + # amp won't cast the type of model (float32), but "offset" is cast + # to float16 by nn.Conv2d automatically, leading to the type + # mismatch with input (when it is float32) or weight. + # The flag for whether to use fp16 or amp is the type of "offset", + # we cast weight and input to temporarily support fp16 and amp + # whatever the pytorch version is. 
+ input = input.type_as(offset) + weight = weight.type_as(input) + bias = bias.type_as(input) + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty( + ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + ext_module.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + kernel_h=weight.size(2), + kernel_w=weight.size(3), + stride_h=ctx.stride[0], + stride_w=ctx.stride[1], + pad_h=ctx.padding[0], + pad_w=ctx.padding[1], + dilation_h=ctx.dilation[0], + dilation_w=ctx.dilation[1], + group=ctx.groups, + deformable_group=ctx.deform_groups, + with_bias=ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + grad_output = grad_output.contiguous() + ext_module.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + kernel_h=weight.size(2), + kernel_w=weight.size(3), + stride_h=ctx.stride[0], + stride_w=ctx.stride[1], + pad_h=ctx.padding[0], + pad_w=ctx.padding[1], + dilation_h=ctx.dilation[0], + dilation_w=ctx.dilation[1], + group=ctx.groups, + deformable_group=ctx.deform_groups, + with_bias=ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None) + + @staticmethod + def _output_size(ctx, input, weight): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = ctx.padding[d] + kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = ctx.stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + 'convolution input is too small (output would be ' + + 'x'.join(map(str, output_size)) + ')') + return output_size + + +modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply + + +class ModulatedDeformConv2d(nn.Module): + + @deprecated_api_warning({'deformable_groups': 'deform_groups'}, + cls_name='ModulatedDeformConv2d') + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1, + bias=True): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, + *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. 
/ math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) + + +@CONV_LAYERS.register_module('DCNv2') +class ModulatedDeformConv2dPack(ModulatedDeformConv2d): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv + layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int): Same as nn.Conv2d, while tuple is not supported. + padding (int): Same as nn.Conv2d, while tuple is not supported. + dilation (int): Same as nn.Conv2d, while tuple is not supported. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=True) + self.init_weights() + + def init_weights(self): + super().init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, ModulatedDeformConvPack + # loads previous benchmark models. + if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + print_log( + f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' + 'version 2.', + logger='root') + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/multi_scale_deform_attn.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/multi_scale_deform_attn.py new file mode 100755 index 00000000..19b941dd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/multi_scale_deform_attn.py @@ -0,0 +1,357 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
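Before moving to the attention op, a usage sketch of the `ModulatedDeformConv2dPack` (DCNv2) wrapper defined just above, which generates its own offsets and modulation masks from the input (illustrative; assumes the compiled extension, a CUDA device, and the usual `dbnet_cv.ops` re-export):

```python
import torch
from dbnet_cv.ops import ModulatedDeformConv2dPack  # assumed re-export

dcn = ModulatedDeformConv2dPack(16, 32, kernel_size=3, padding=1).cuda()
x = torch.randn(2, 16, 64, 64).cuda()
# conv_offset emits 3 * kH * kW = 27 channels: 18 offsets + 9 sigmoid masks
out = dcn(x)                   # -> (2, 32, 64, 64)
```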
+import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd.function import Function, once_differentiable + +from dbnet_cv import deprecated_api_warning +from dbnet_cv.cnn import constant_init, xavier_init +from dbnet_cv.cnn.bricks.registry import ATTENTION +from dbnet_cv.runner import BaseModule +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) + + +class MultiScaleDeformableAttnFunction(Function): + + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights, im2col_step): + """GPU version of multi-scale deformable attention. + + Args: + value (torch.Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (torch.Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (torch.Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (torch.Tensor): The weight of sampling points + used when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + im2col_step (Tensor): The step used in image to column. + + Returns: + torch.Tensor: has shape (bs, num_queries, embed_dims) + """ + + ctx.im2col_step = im2col_step + output = ext_module.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step=ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, + value_level_start_index, sampling_locations, + attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + """GPU version of backward function. + + Args: + grad_output (torch.Tensor): Gradient of output tensor of forward. + + Returns: + tuple[Tensor]: Gradient of input tensors in forward. + """ + value, value_spatial_shapes, value_level_start_index,\ + sampling_locations, attention_weights = ctx.saved_tensors + grad_value = torch.zeros_like(value) + grad_sampling_loc = torch.zeros_like(sampling_locations) + grad_attn_weight = torch.zeros_like(attention_weights) + + ext_module.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output.contiguous(), + grad_value, + grad_sampling_loc, + grad_attn_weight, + im2col_step=ctx.im2col_step) + + return grad_value, None, None, \ + grad_sampling_loc, grad_attn_weight, None + + +def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, + sampling_locations, attention_weights): + """CPU version of multi-scale deformable attention. + + Args: + value (torch.Tensor): The value has shape + (bs, num_keys, num_heads, embed_dims//num_heads) + value_spatial_shapes (torch.Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (torch.Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). 
+ attention_weights (torch.Tensor): The weight of sampling points used + when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + + Returns: + torch.Tensor: has shape (bs, num_queries, embed_dims) + """ + + bs, _, num_heads, embed_dims = value.shape + _, num_queries, num_heads, num_levels, num_points, _ =\ + sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], + dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (H_, W_) in enumerate(value_spatial_shapes): + # bs, H_*W_, num_heads, embed_dims -> + # bs, H_*W_, num_heads*embed_dims -> + # bs, num_heads*embed_dims, H_*W_ -> + # bs*num_heads, embed_dims, H_, W_ + value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape( + bs * num_heads, embed_dims, H_, W_) + # bs, num_queries, num_heads, num_points, 2 -> + # bs, num_heads, num_queries, num_points, 2 -> + # bs*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, + level].transpose(1, 2).flatten(0, 1) + # bs*num_heads, embed_dims, num_queries, num_points + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (bs, num_queries, num_heads, num_levels, num_points) -> + # (bs, num_heads, num_queries, num_levels, num_points) -> + # (bs, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * + attention_weights).sum(-1).view(bs, num_heads * embed_dims, + num_queries) + return output.transpose(1, 2).contiguous() + + +@ATTENTION.register_module() +class MultiScaleDeformableAttention(BaseModule): + """An attention module used in Deformable-Detr. + + `Deformable DETR: Deformable Transformers for End-to-End Object Detection. + `_. + + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_heads (int): Parallel attention heads. Default: 64. + num_levels (int): The number of feature map used in + Attention. Default: 4. + num_points (int): The number of sampling points for + each query in each head. Default: 4. + im2col_step (int): The step used in image_to_column. + Default: 64. + dropout (float): A Dropout layer on `inp_identity`. + Default: 0.1. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default to False. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + init_cfg (obj:`dbnet_cv.ConfigDict`): The Config for initialization. + Default: None. 
+ """ + + def __init__(self, + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=4, + im2col_step=64, + dropout=0.1, + batch_first=False, + norm_cfg=None, + init_cfg=None): + super().__init__(init_cfg) + if embed_dims % num_heads != 0: + raise ValueError(f'embed_dims must be divisible by num_heads, ' + f'but got {embed_dims} and {num_heads}') + dim_per_head = embed_dims // num_heads + self.norm_cfg = norm_cfg + self.dropout = nn.Dropout(dropout) + self.batch_first = batch_first + + # you'd better set dim_per_head to a power of 2 + # which is more efficient in the CUDA implementation + def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + if not _is_power_of_2(dim_per_head): + warnings.warn( + "You'd better set embed_dims in " + 'MultiScaleDeformAttention to make ' + 'the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2) + self.attention_weights = nn.Linear(embed_dims, + num_heads * num_levels * num_points) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.init_weights() + + def init_weights(self): + """Default initialization for Parameters of Module.""" + constant_init(self.sampling_offsets, 0.) + thetas = torch.arange( + self.num_heads, + dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / + grid_init.abs().max(-1, keepdim=True)[0]).view( + self.num_heads, 1, 1, + 2).repeat(1, self.num_levels, self.num_points, 1) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + constant_init(self.attention_weights, val=0., bias=0.) + xavier_init(self.value_proj, distribution='uniform', bias=0.) + xavier_init(self.output_proj, distribution='uniform', bias=0.) + self._is_init = True + + @deprecated_api_warning({'residual': 'identity'}, + cls_name='MultiScaleDeformableAttention') + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + **kwargs): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (torch.Tensor): Query of Transformer with shape + (num_query, bs, embed_dims). + key (torch.Tensor): The key tensor with shape + `(num_key, bs, embed_dims)`. + value (torch.Tensor): The value tensor with shape + `(num_key, bs, embed_dims)`. + identity (torch.Tensor): The tensor used for addition, with the + same shape as `query`. Default None. If None, + `query` will be used. + query_pos (torch.Tensor): The positional encoding for `query`. + Default: None. + key_pos (torch.Tensor): The positional encoding for `key`. Default + None. + reference_points (torch.Tensor): The normalized reference + points with shape (bs, num_query, num_levels, 2), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + or (N, Length_{query}, num_levels, 4), add + additional two dimensions is (w, h) to + form reference boxes. 
+ key_padding_mask (torch.Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + spatial_shapes (torch.Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + level_start_index (torch.Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + + Returns: + torch.Tensor: forwarded results with shape + [num_query, bs, embed_dims]. + """ + + if value is None: + value = query + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_levels, + self.num_points) + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets \ + / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.num_points \ + * reference_points[:, :, None, :, None, 2:] \ + * 0.5 + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + if torch.cuda.is_available() and value.is_cuda: + output = MultiScaleDeformableAttnFunction.apply( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, self.im2col_step) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + + output = self.output_proj(output) + + if not self.batch_first: + # (num_query, bs ,embed_dims) + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_align.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_align.py new file mode 100755 index 00000000..7f094230 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_align.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
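+
+# A minimal usage sketch for the `RoIAlign` module defined later in this
+# file. This is illustrative only (the helper name, shapes and scales are
+# assumptions, not upstream API) and it requires the compiled `_ext` ops.
+def _roi_align_example():
+    import torch
+    # two feature maps with 256 channels at 1/4 of the input resolution
+    feats = torch.rand(2, 256, 32, 32)
+    # each RoI is (batch_index, x1, y1, x2, y2) in input-image coordinates
+    rois = torch.tensor([[0., 0., 0., 64., 64.],
+                         [1., 16., 16., 96., 96.]])
+    layer = RoIAlign(output_size=7, spatial_scale=0.25, sampling_ratio=2)
+    return layer(feats, rois)  # -> (2, 256, 7, 7)
+
+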
+import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import deprecated_api_warning, ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['roi_align_forward', 'roi_align_backward']) + + +class RoIAlignFunction(Function): + + @staticmethod + def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio, + pool_mode, aligned): + from ..onnx import is_custom_op_loaded + has_custom_op = is_custom_op_loaded() + if has_custom_op: + return g.op( + 'dbnet_cv::DBNET_CVRoiAlign', + input, + rois, + output_height_i=output_size[0], + output_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_i=sampling_ratio, + mode_s=pool_mode, + aligned_i=aligned) + else: + from torch.onnx import TensorProtoDataType + from torch.onnx.symbolic_helper import _slice_helper + from torch.onnx.symbolic_opset9 import squeeze, sub + + # batch_indices = rois[:, 0].long() + batch_indices = _slice_helper( + g, rois, axes=[1], starts=[0], ends=[1]) + batch_indices = squeeze(g, batch_indices, 1) + batch_indices = g.op( + 'Cast', batch_indices, to_i=TensorProtoDataType.INT64) + # rois = rois[:, 1:] + rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5]) + if aligned: + # rois -= 0.5/spatial_scale + aligned_offset = g.op( + 'Constant', + value_t=torch.tensor([0.5 / spatial_scale], + dtype=torch.float32)) + rois = sub(g, rois, aligned_offset) + # roi align + return g.op( + 'RoiAlign', + input, + rois, + batch_indices, + output_height_i=output_size[0], + output_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_i=max(0, sampling_ratio), + mode_s=pool_mode) + + @staticmethod + def forward(ctx, + input, + rois, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + pool_mode='avg', + aligned=True): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.sampling_ratio = sampling_ratio + assert pool_mode in ('max', 'avg') + ctx.pool_mode = 0 if pool_mode == 'max' else 1 + ctx.aligned = aligned + ctx.input_shape = input.size() + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + if ctx.pool_mode == 0: + argmax_y = input.new_zeros(output_shape) + argmax_x = input.new_zeros(output_shape) + else: + argmax_y = input.new_zeros(0) + argmax_x = input.new_zeros(0) + + ext_module.roi_align_forward( + input, + rois, + output, + argmax_y, + argmax_x, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + pool_mode=ctx.pool_mode, + aligned=ctx.aligned) + + ctx.save_for_backward(rois, argmax_y, argmax_x) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, argmax_y, argmax_x = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + # complex head architecture may cause grad_output uncontiguous. 
+ grad_output = grad_output.contiguous() + ext_module.roi_align_backward( + grad_output, + rois, + argmax_y, + argmax_x, + grad_input, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + pool_mode=ctx.pool_mode, + aligned=ctx.aligned) + return grad_input, None, None, None, None, None, None + + +roi_align = RoIAlignFunction.apply + + +class RoIAlign(nn.Module): + """RoI align pooling layer. + + Args: + output_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sampling_ratio (int): number of inputs samples to take for each + output sample. 0 to take samples densely for current models. + pool_mode (str, 'avg' or 'max'): pooling mode in each bin. + aligned (bool): if False, use the legacy implementation in + dbnet_detection. If True, align the results more perfectly. + use_torchvision (bool): whether to use roi_align from torchvision. + + Note: + The implementation of RoIAlign when aligned=True is modified from + https://github.com/facebookresearch/detectron2/ + + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel + indices (in our pixel model) are computed by floor(c - 0.5) and + ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete + indices [0] and [1] (which are sampled from the underlying signal + at continuous coordinates 0.5 and 1.5). But the original roi_align + (aligned=False) does not subtract the 0.5 when computing + neighboring pixel indices and therefore it uses pixels with a + slightly incorrect alignment (relative to our pixel model) when + performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors; + + The difference does not make a difference to the model's + performance if ROIAlign is used together with conv layers. + """ + + @deprecated_api_warning( + { + 'out_size': 'output_size', + 'sample_num': 'sampling_ratio' + }, + cls_name='RoIAlign') + def __init__(self, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + pool_mode='avg', + aligned=True, + use_torchvision=False): + super().__init__() + + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + self.sampling_ratio = int(sampling_ratio) + self.pool_mode = pool_mode + self.aligned = aligned + self.use_torchvision = use_torchvision + + def forward(self, input, rois): + """ + Args: + input: NCHW images + rois: Bx5 boxes. First column is the index into N.\ + The other 4 columns are xyxy. + """ + if self.use_torchvision: + from torchvision.ops import roi_align as tv_roi_align + if 'aligned' in tv_roi_align.__code__.co_varnames: + return tv_roi_align(input, rois, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.aligned) + else: + if self.aligned: + rois -= rois.new_tensor([0.] 
+ + [0.5 / self.spatial_scale] * 4) + return tv_roi_align(input, rois, self.output_size, + self.spatial_scale, self.sampling_ratio) + else: + return roi_align(input, rois, self.output_size, self.spatial_scale, + self.sampling_ratio, self.pool_mode, self.aligned) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(output_size={self.output_size}, ' + s += f'spatial_scale={self.spatial_scale}, ' + s += f'sampling_ratio={self.sampling_ratio}, ' + s += f'pool_mode={self.pool_mode}, ' + s += f'aligned={self.aligned}, ' + s += f'use_torchvision={self.use_torchvision})' + return s diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_pool.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_pool.py new file mode 100755 index 00000000..f63faabb --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/roi_pool.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['roi_pool_forward', 'roi_pool_backward']) + + +class RoIPoolFunction(Function): + + @staticmethod + def symbolic(g, input, rois, output_size, spatial_scale): + return g.op( + 'MaxRoiPool', + input, + rois, + pooled_shape_i=output_size, + spatial_scale_f=spatial_scale) + + @staticmethod + def forward(ctx, input, rois, output_size, spatial_scale=1.0): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + argmax = input.new_zeros(output_shape, dtype=torch.int) + + ext_module.roi_pool_forward( + input, + rois, + output, + argmax, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale) + + ctx.save_for_backward(rois, argmax) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, argmax = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + + ext_module.roi_pool_backward( + grad_output, + rois, + argmax, + grad_input, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale) + + return grad_input, None, None, None + + +roi_pool = RoIPoolFunction.apply + + +class RoIPool(nn.Module): + + def __init__(self, output_size, spatial_scale=1.0): + super().__init__() + + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + + def forward(self, input, rois): + return roi_pool(input, rois, self.output_size, self.spatial_scale) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(output_size={self.output_size}, ' + s += f'spatial_scale={self.spatial_scale})' + return s diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/ops/sync_bn.py b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/sync_bn.py new file mode 100755 index 00000000..4a0dc20b --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/ops/sync_bn.py @@ -0,0 +1,279 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
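+
+# A hedged construction sketch for the `SyncBatchNorm` ('MMSyncBN') layer
+# defined below; the helper name and sizes are illustrative assumptions.
+# It requires torch.distributed to be initialized first (a single-process
+# 'gloo' group is enough); training-mode forward additionally needs the
+# compiled `_ext` sync_bn ops.
+def _mmsyncbn_example(num_channels=64):
+    import torch
+    import torch.distributed as dist
+    assert dist.is_available() and dist.is_initialized(), \
+        'call dist.init_process_group(...) before constructing the layer'
+    bn = SyncBatchNorm(num_channels, stats_mode='default')
+    x = torch.rand(4, num_channels, 16, 16)
+    # eval mode falls back to F.batch_norm on the running statistics, so
+    # this call works without the CUDA extension; bn.train() would use the
+    # synchronized path implemented by SyncBatchNormFunction
+    return bn.eval()(x)
+
+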
+import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.module import Module +from torch.nn.parameter import Parameter + +from dbnet_cv.cnn import NORM_LAYERS +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output', + 'sync_bn_backward_param', 'sync_bn_backward_data' +]) + + +class SyncBatchNormFunction(Function): + + @staticmethod + def symbolic(g, input, running_mean, running_var, weight, bias, momentum, + eps, group, group_size, stats_mode): + return g.op( + 'dbnet_cv::DBNET_CVSyncBatchNorm', + input, + running_mean, + running_var, + weight, + bias, + momentum_f=momentum, + eps_f=eps, + group_i=group, + group_size_i=group_size, + stats_mode=stats_mode) + + @staticmethod + def forward(self, input, running_mean, running_var, weight, bias, momentum, + eps, group, group_size, stats_mode): + self.momentum = momentum + self.eps = eps + self.group = group + self.group_size = group_size + self.stats_mode = stats_mode + + assert isinstance( + input, (torch.HalfTensor, torch.FloatTensor, + torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \ + f'only support Half or Float Tensor, but {input.type()}' + output = torch.zeros_like(input) + input3d = input.flatten(start_dim=2) + output3d = output.view_as(input3d) + num_channels = input3d.size(1) + + # ensure mean/var/norm/std are initialized as zeros + # ``torch.empty()`` does not guarantee that + mean = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + var = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + norm = torch.zeros_like( + input3d, dtype=torch.float, device=input3d.device) + std = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + + batch_size = input3d.size(0) + if batch_size > 0: + ext_module.sync_bn_forward_mean(input3d, mean) + batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype) + else: + # skip updating mean and leave it as zeros when the input is empty + batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype) + + # synchronize mean and the batch flag + vec = torch.cat([mean, batch_flag]) + if self.stats_mode == 'N': + vec *= batch_size + if self.group_size > 1: + dist.all_reduce(vec, group=self.group) + total_batch = vec[-1].detach() + mean = vec[:num_channels] + + if self.stats_mode == 'default': + mean = mean / self.group_size + elif self.stats_mode == 'N': + mean = mean / total_batch.clamp(min=1) + else: + raise NotImplementedError + + # leave var as zeros when the input is empty + if batch_size > 0: + ext_module.sync_bn_forward_var(input3d, mean, var) + + if self.stats_mode == 'N': + var *= batch_size + if self.group_size > 1: + dist.all_reduce(var, group=self.group) + + if self.stats_mode == 'default': + var /= self.group_size + elif self.stats_mode == 'N': + var /= total_batch.clamp(min=1) + else: + raise NotImplementedError + + # if the total batch size over all the ranks is zero, + # we should not update the statistics in the current batch + update_flag = total_batch.clamp(max=1) + momentum = update_flag * self.momentum + ext_module.sync_bn_forward_output( + input3d, + mean, + var, + weight, + bias, + running_mean, + running_var, + norm, + std, + output3d, + eps=self.eps, + momentum=momentum, + group_size=self.group_size) + self.save_for_backward(norm, std, weight) + return output + + @staticmethod + 
@once_differentiable + def backward(self, grad_output): + norm, std, weight = self.saved_tensors + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(weight) + grad_input = torch.zeros_like(grad_output) + grad_output3d = grad_output.flatten(start_dim=2) + grad_input3d = grad_input.view_as(grad_output3d) + + batch_size = grad_input3d.size(0) + if batch_size > 0: + ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight, + grad_bias) + + # all reduce + if self.group_size > 1: + dist.all_reduce(grad_weight, group=self.group) + dist.all_reduce(grad_bias, group=self.group) + grad_weight /= self.group_size + grad_bias /= self.group_size + + if batch_size > 0: + ext_module.sync_bn_backward_data(grad_output3d, weight, + grad_weight, grad_bias, norm, std, + grad_input3d) + + return grad_input, None, None, grad_weight, grad_bias, \ + None, None, None, None, None + + +@NORM_LAYERS.register_module(name='MMSyncBN') +class SyncBatchNorm(Module): + """Synchronized Batch Normalization. + + Args: + num_features (int): number of features/channels in the input tensor + eps (float, optional): a value added to the denominator for numerical + stability. Defaults to 1e-5. + momentum (float, optional): the value used for the running_mean and + running_var computation. Defaults to 0.1. + affine (bool, optional): whether to use learnable affine parameters. + Defaults to True. + track_running_stats (bool, optional): whether to track the running + mean and variance during training. When set to False, this + module does not track such statistics, and initializes statistics + buffers ``running_mean`` and ``running_var`` as ``None``. When + these buffers are ``None``, this module always uses batch + statistics in both training and eval modes. Defaults to True. + group (int, optional): synchronization of stats happens within + each process group individually. By default it is synchronization + across the whole world. Defaults to None. + stats_mode (str, optional): The statistical mode. Available options + include ``'default'`` and ``'N'``. Defaults to 'default'. + When ``stats_mode=='default'``, it computes the overall statistics + using those from each worker with equal weight, i.e., the + statistics are synchronized and simply divided by ``group``. This + mode will produce inaccurate statistics when empty tensors occur. + When ``stats_mode=='N'``, it computes the overall statistics using + the total number of batches in each worker ignoring the number of + group, i.e., the statistics are synchronized and then divided by + the total batch ``N``. This mode is beneficial when empty tensors + occur during training, as it averages the total mean by the real + number of batches.
+ """ + + def __init__(self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + group=None, + stats_mode='default'): + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + group = dist.group.WORLD if group is None else group + self.group = group + self.group_size = dist.get_world_size(group) + assert stats_mode in ['default', 'N'], \ + f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"' + self.stats_mode = stats_mode + if self.affine: + self.weight = Parameter(torch.Tensor(num_features)) + self.bias = Parameter(torch.Tensor(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + if self.track_running_stats: + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.register_buffer('num_batches_tracked', + torch.tensor(0, dtype=torch.long)) + else: + self.register_buffer('running_mean', None) + self.register_buffer('running_var', None) + self.register_buffer('num_batches_tracked', None) + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + self.weight.data.uniform_() # pytorch use ones_() + self.bias.data.zero_() + + def forward(self, input): + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input, got {input.dim()}D input') + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float( + self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training or not self.track_running_stats: + return SyncBatchNormFunction.apply( + input, self.running_mean, self.running_var, self.weight, + self.bias, exponential_average_factor, self.eps, self.group, + self.group_size, self.stats_mode) + else: + return F.batch_norm(input, self.running_mean, self.running_var, + self.weight, self.bias, False, + exponential_average_factor, self.eps) + + def __repr__(self): + s = self.__class__.__name__ + s += f'({self.num_features}, ' + s += f'eps={self.eps}, ' + s += f'momentum={self.momentum}, ' + s += f'affine={self.affine}, ' + s += f'track_running_stats={self.track_running_stats}, ' + s += f'group_size={self.group_size},' + s += f'stats_mode={self.stats_mode})' + return s diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/__init__.py new file mode 100755 index 00000000..2ed2c17a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
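+
+# A minimal wrapping sketch using the classes re-exported below; the helper
+# name is an illustrative assumption. `model` is any nn.Module (for the
+# runner APIs it is expected to implement train_step()/val_step()).
+def _wrap_model_example(model):
+    import torch
+    if torch.cuda.is_available():
+        # single-GPU training/inference: scatter DataContainer batches to GPU 0
+        return MMDataParallel(model.cuda(), device_ids=[0])
+    # CPU-only: MMDataParallel still unwraps DataContainers in forward()
+    return MMDataParallel(model)
+
+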
+from .collate import collate +from .data_container import DataContainer +from .data_parallel import MMDataParallel +from .distributed import MMDistributedDataParallel +from .registry import MODULE_WRAPPERS +from .scatter_gather import scatter, scatter_kwargs +from .utils import is_module_wrapper + +__all__ = [ + 'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel', + 'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/_functions.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/_functions.py new file mode 100755 index 00000000..95c58bf1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/_functions.py @@ -0,0 +1,76 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel._functions import _get_stream + + +def scatter(input, devices, streams=None): + """Scatters tensor across multiple GPUs.""" + if streams is None: + streams = [None] * len(devices) + + if isinstance(input, list): + chunk_size = (len(input) - 1) // len(devices) + 1 + outputs = [ + scatter(input[i], [devices[i // chunk_size]], + [streams[i // chunk_size]]) for i in range(len(input)) + ] + return outputs + elif isinstance(input, torch.Tensor): + output = input.contiguous() + # TODO: copy to a pinned buffer first (if copying from CPU) + stream = streams[0] if output.numel() > 0 else None + if devices != [-1]: + with torch.cuda.device(devices[0]), torch.cuda.stream(stream): + output = output.cuda(devices[0], non_blocking=True) + + return output + else: + raise Exception(f'Unknown type {type(input)}.') + + +def synchronize_stream(output, devices, streams): + if isinstance(output, list): + chunk_size = len(output) // len(devices) + for i in range(len(devices)): + for j in range(chunk_size): + synchronize_stream(output[i * chunk_size + j], [devices[i]], + [streams[i]]) + elif isinstance(output, torch.Tensor): + if output.numel() != 0: + with torch.cuda.device(devices[0]): + main_stream = torch.cuda.current_stream() + main_stream.wait_stream(streams[0]) + output.record_stream(main_stream) + else: + raise Exception(f'Unknown type {type(output)}.') + + +def get_input_device(input): + if isinstance(input, list): + for item in input: + input_device = get_input_device(item) + if input_device != -1: + return input_device + return -1 + elif isinstance(input, torch.Tensor): + return input.get_device() if input.is_cuda else -1 + else: + raise Exception(f'Unknown type {type(input)}.') + + +class Scatter: + + @staticmethod + def forward(target_gpus, input): + input_device = get_input_device(input) + streams = None + if input_device == -1 and target_gpus != [-1]: + # Perform CPU to GPU copies in a background stream + streams = [_get_stream(device) for device in target_gpus] + + outputs = scatter(input, target_gpus, streams) + # Synchronize with the copy stream + if streams is not None: + synchronize_stream(outputs, target_gpus, streams) + + return tuple(outputs) if isinstance(outputs, list) else (outputs, ) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/collate.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/collate.py new file mode 100755 index 00000000..b34a2f59 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/collate.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
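+
+# A usage sketch for the `collate` function defined below: it is normally
+# passed to a DataLoader via functools.partial so that every group of
+# `samples_per_gpu` samples is collated into one unit for scattering.
+# The helper name and arguments are illustrative assumptions.
+def _build_dataloader_example(dataset, samples_per_gpu=2, num_workers=2):
+    from functools import partial
+
+    from torch.utils.data import DataLoader
+    return DataLoader(
+        dataset,
+        batch_size=samples_per_gpu,
+        num_workers=num_workers,
+        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu))
+
+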
+from collections.abc import Mapping, Sequence + +import torch +import torch.nn.functional as F +from torch.utils.data.dataloader import default_collate + +from .data_container import DataContainer + + +def collate(batch, samples_per_gpu=1): + """Puts each data field into a tensor/DataContainer with outer dimension + batch size. + + Extend default_collate to add support for + :type:`~dbnet_cv.parallel.DataContainer`. There are 3 cases. + + 1. cpu_only = True, e.g., meta data + 2. cpu_only = False, stack = True, e.g., images tensors + 3. cpu_only = False, stack = False, e.g., gt bboxes + """ + + if not isinstance(batch, Sequence): + raise TypeError(f'{batch.dtype} is not supported.') + + if isinstance(batch[0], DataContainer): + stacked = [] + if batch[0].cpu_only: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer( + stacked, batch[0].stack, batch[0].padding_value, cpu_only=True) + elif batch[0].stack: + for i in range(0, len(batch), samples_per_gpu): + assert isinstance(batch[i].data, torch.Tensor) + + if batch[i].pad_dims is not None: + ndim = batch[i].dim() + assert ndim > batch[i].pad_dims + max_shape = [0 for _ in range(batch[i].pad_dims)] + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = batch[i].size(-dim) + for sample in batch[i:i + samples_per_gpu]: + for dim in range(0, ndim - batch[i].pad_dims): + assert batch[i].size(dim) == sample.size(dim) + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = max(max_shape[dim - 1], + sample.size(-dim)) + padded_samples = [] + for sample in batch[i:i + samples_per_gpu]: + pad = [0 for _ in range(batch[i].pad_dims * 2)] + for dim in range(1, batch[i].pad_dims + 1): + pad[2 * dim - + 1] = max_shape[dim - 1] - sample.size(-dim) + padded_samples.append( + F.pad( + sample.data, pad, value=sample.padding_value)) + stacked.append(default_collate(padded_samples)) + elif batch[i].pad_dims is None: + stacked.append( + default_collate([ + sample.data + for sample in batch[i:i + samples_per_gpu] + ])) + else: + raise ValueError( + 'pad_dims should be either None or integers (1-3)') + + else: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer(stacked, batch[0].stack, batch[0].padding_value) + elif isinstance(batch[0], Sequence): + transposed = zip(*batch) + return [collate(samples, samples_per_gpu) for samples in transposed] + elif isinstance(batch[0], Mapping): + return { + key: collate([d[key] for d in batch], samples_per_gpu) + for key in batch[0] + } + else: + return default_collate(batch) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_container.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_container.py new file mode 100755 index 00000000..cedb0d32 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_container.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch + + +def assert_tensor_type(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not isinstance(args[0].data, torch.Tensor): + raise AttributeError( + f'{args[0].__class__.__name__} has no attribute ' + f'{func.__name__} for type {args[0].datatype}') + return func(*args, **kwargs) + + return wrapper + + +class DataContainer: + """A container for any type of objects. 
+ + Typically tensors will be stacked in the collate function and sliced along + some dimension in the scatter function. This behavior has some limitations. + 1. All tensors have to be the same size. + 2. Types are limited (numpy array or Tensor). + + We design `DataContainer` and `MMDataParallel` to overcome these + limitations. The behavior can be either of the following. + + - copy to GPU, pad all tensors to the same size and stack them + - copy to GPU without stacking + - leave the objects as is and pass it to the model + - pad_dims specifies the number of last few dimensions to do padding + """ + + def __init__(self, + data, + stack=False, + padding_value=0, + cpu_only=False, + pad_dims=2): + self._data = data + self._cpu_only = cpu_only + self._stack = stack + self._padding_value = padding_value + assert pad_dims in [None, 1, 2, 3] + self._pad_dims = pad_dims + + def __repr__(self): + return f'{self.__class__.__name__}({repr(self.data)})' + + def __len__(self): + return len(self._data) + + @property + def data(self): + return self._data + + @property + def datatype(self): + if isinstance(self.data, torch.Tensor): + return self.data.type() + else: + return type(self.data) + + @property + def cpu_only(self): + return self._cpu_only + + @property + def stack(self): + return self._stack + + @property + def padding_value(self): + return self._padding_value + + @property + def pad_dims(self): + return self._pad_dims + + @assert_tensor_type + def size(self, *args, **kwargs): + return self.data.size(*args, **kwargs) + + @assert_tensor_type + def dim(self): + return self.data.dim() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_parallel.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_parallel.py new file mode 100755 index 00000000..e86b47a5 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/data_parallel.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import chain + +from torch.nn.parallel import DataParallel + +from .scatter_gather import scatter_kwargs + + +class MMDataParallel(DataParallel): + """The DataParallel module that supports DataContainer. + + MMDataParallel has two main differences with PyTorch DataParallel: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data during both GPU and CPU inference. + - It implement two more APIs ``train_step()`` and ``val_step()``. + + .. warning:: + MMDataParallel only supports single GPU training, if you need to + train with multiple GPUs, please use MMDistributedDataParallel + instead. If you have multiple GPUs and you just want to use + MMDataParallel, you can set the environment variable + ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with + ``device_ids=[0]``. + + Args: + module (:class:`nn.Module`): Module to be encapsulated. + device_ids (list[int]): Device IDS of modules to be scattered to. + Defaults to None when GPU is not available. + output_device (str | int): Device ID for output. Defaults to None. + dim (int): Dimension used to scatter the data. Defaults to 0. + """ + + def __init__(self, *args, dim=0, **kwargs): + super().__init__(*args, dim=dim, **kwargs) + self.dim = dim + + def forward(self, *inputs, **kwargs): + """Override the original forward function. + + The main difference lies in the CPU inference where the data in + :class:`DataContainers` will still be gathered. 
+ """ + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module(*inputs[0], **kwargs[0]) + else: + return super().forward(*inputs, **kwargs) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module.train_step(*inputs[0], **kwargs[0]) + + assert len(self.device_ids) == 1, \ + ('MMDataParallel only supports single GPU training, if you need to' + ' train with multiple GPUs, please use MMDistributedDataParallel' + ' instead.') + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + 'module must have its parameters and buffers ' + f'on device {self.src_device_obj} (device_ids[0]) but ' + f'found one of them on device: {t.device}') + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + return self.module.train_step(*inputs[0], **kwargs[0]) + + def val_step(self, *inputs, **kwargs): + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module.val_step(*inputs[0], **kwargs[0]) + + assert len(self.device_ids) == 1, \ + ('MMDataParallel only supports single GPU training, if you need to' + ' train with multiple GPUs, please use MMDistributedDataParallel' + ' instead.') + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + 'module must have its parameters and buffers ' + f'on device {self.src_device_obj} (device_ids[0]) but ' + f'found one of them on device: {t.device}') + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + return self.module.val_step(*inputs[0], **kwargs[0]) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed.py new file mode 100755 index 00000000..85204178 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel.distributed import (DistributedDataParallel, + _find_tensors) + +from dbnet_cv import print_log +from dbnet_cv.utils import TORCH_VERSION, digit_version +from .scatter_gather import scatter_kwargs + + +class MMDistributedDataParallel(DistributedDataParallel): + """The DDP module that supports DataContainer. + + MMDDP has two main differences with PyTorch DDP: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data. + - It implement two APIs ``train_step()`` and ``val_step()``. + """ + + def to_kwargs(self, inputs, kwargs, device_id): + # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8 + # to move all tensors to device_id + return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + """train_step() API for module wrapped by DistributedDataParallel. 
+ + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.train_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='dbnet_cv') + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + else: + if (getattr(self, 'require_forward_param_sync', False) + and self.require_forward_param_sync): + self._sync_params() + + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.train_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.train_step(*inputs, **kwargs) + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if (torch.is_grad_enabled() + and getattr(self, 'require_backward_grad_sync', False) + and self.require_backward_grad_sync): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output + + def val_step(self, *inputs, **kwargs): + """val_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.val_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. 
+ if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='dbnet_cv') + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + else: + if (getattr(self, 'require_forward_param_sync', False) + and self.require_forward_param_sync): + self._sync_params() + + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.val_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.val_step(*inputs, **kwargs) + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if (torch.is_grad_enabled() + and getattr(self, 'require_backward_grad_sync', False) + and self.require_backward_grad_sync): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed_deprecated.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed_deprecated.py new file mode 100755 index 00000000..0cfbcebb --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/distributed_deprecated.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
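+
+# A hedged sketch for the deprecated wrapper defined below (one process per
+# GPU, process group already initialized); the helper name and arguments are
+# illustrative assumptions, not upstream API.
+def _wrap_model_deprecated_example(model, local_rank):
+    import torch
+    # select the device for this process; __init__ then broadcasts the
+    # parameters and buffers of rank 0 to all other ranks
+    torch.cuda.set_device(local_rank)
+    return MMDistributedDataParallel(model.cuda(), broadcast_buffers=True)
+
+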
+import torch +import torch.distributed as dist +import torch.nn as nn +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + +from dbnet_cv.utils import TORCH_VERSION, digit_version +from .registry import MODULE_WRAPPERS +from .scatter_gather import scatter_kwargs + + +@MODULE_WRAPPERS.register_module() +class MMDistributedDataParallel(nn.Module): + + def __init__(self, + module, + dim=0, + broadcast_buffers=True, + bucket_cap_mb=25): + super().__init__() + self.module = module + self.dim = dim + self.broadcast_buffers = broadcast_buffers + + self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024 + self._sync_params() + + def _dist_broadcast_coalesced(self, tensors, buffer_size): + for tensors in _take_tensors(tensors, buffer_size): + flat_tensors = _flatten_dense_tensors(tensors) + dist.broadcast(flat_tensors, 0) + for tensor, synced in zip( + tensors, _unflatten_dense_tensors(flat_tensors, tensors)): + tensor.copy_(synced) + + def _sync_params(self): + module_states = list(self.module.state_dict().values()) + if len(module_states) > 0: + self._dist_broadcast_coalesced(module_states, + self.broadcast_bucket_size) + if self.broadcast_buffers: + if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) < digit_version('1.0')): + buffers = [b.data for b in self.module._all_buffers()] + else: + buffers = [b.data for b in self.module.buffers()] + if len(buffers) > 0: + self._dist_broadcast_coalesced(buffers, + self.broadcast_bucket_size) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def forward(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + return self.module(*inputs[0], **kwargs[0]) + + def train_step(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.train_step(*inputs[0], **kwargs[0]) + return output + + def val_step(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.val_step(*inputs[0], **kwargs[0]) + return output diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/registry.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/registry.py new file mode 100755 index 00000000..6f248cd6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/registry.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from dbnet_cv.utils import Registry + +MODULE_WRAPPERS = Registry('module wrapper') +MODULE_WRAPPERS.register_module(module=DataParallel) +MODULE_WRAPPERS.register_module(module=DistributedDataParallel) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/scatter_gather.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/scatter_gather.py new file mode 100755 index 00000000..1af533c0 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/scatter_gather.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel._functions import Scatter as OrigScatter + +from ._functions import Scatter +from .data_container import DataContainer + + +def scatter(inputs, target_gpus, dim=0): + """Scatter inputs to target gpus. + + The only difference from original :func:`scatter` is to add support for + :type:`~dbnet_cv.parallel.DataContainer`. 
+ """ + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + if target_gpus != [-1]: + return OrigScatter.apply(target_gpus, None, dim, obj) + else: + # for CPU inference we use self-implemented scatter + return Scatter.forward(target_gpus, obj) + if isinstance(obj, DataContainer): + if obj.cpu_only: + return obj.data + else: + return Scatter.forward(target_gpus, obj.data) + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + out = list(map(list, zip(*map(scatter_map, obj)))) + return out + if isinstance(obj, dict) and len(obj) > 0: + out = list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return out + return [obj for targets in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + return scatter_map(inputs) + finally: + scatter_map = None + + +def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): + """Scatter with support for kwargs dictionary.""" + inputs = scatter(inputs, target_gpus, dim) if inputs else [] + kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/utils.py b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/utils.py new file mode 100755 index 00000000..940e0aa3 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/parallel/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .registry import MODULE_WRAPPERS + + +def is_module_wrapper(module): + """Check if a module is a module wrapper. + + The following 3 modules in DBNET_CV (and their subclasses) are regarded as + module wrappers: DataParallel, DistributedDataParallel, + MMDistributedDataParallel (the deprecated version). You may add you own + module wrapper by registering it to dbnet_cv.parallel.MODULE_WRAPPERS or + its children registries. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: True if the input module is a module wrapper. + """ + + def is_module_in_wrapper(module, module_wrapper): + module_wrappers = tuple(module_wrapper.module_dict.values()) + if isinstance(module, module_wrappers): + return True + for child in module_wrapper.children.values(): + if is_module_in_wrapper(module, child): + return True + return False + + return is_module_in_wrapper(module, MODULE_WRAPPERS) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/__init__.py new file mode 100755 index 00000000..8e2e05d7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/__init__.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
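+
+# A hedged end-to-end sketch of driving training with the EpochBasedRunner
+# re-exported below. The helper name, hook choices and epoch count are
+# illustrative assumptions; `model` must implement train_step() and
+# `train_loader` should use the parallel collate function.
+def _runner_example(model, optimizer, train_loader, work_dir):
+    import logging
+    logger = logging.getLogger('dbnet')
+    runner = EpochBasedRunner(
+        model=model,
+        optimizer=optimizer,
+        work_dir=work_dir,
+        logger=logger,
+        max_epochs=12)
+    # OptimizerHook performs loss.backward()/optimizer.step(); the LR hook
+    # keeps the learning rate fixed across epochs
+    runner.register_hook_from_cfg(dict(type='OptimizerHook', grad_clip=None))
+    runner.register_hook_from_cfg(dict(type='FixedLrUpdaterHook'))
+    runner.run([train_loader], workflow=[('train', 1)])
+    return runner
+
+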
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential +from .base_runner import BaseRunner +from .builder import RUNNERS, build_runner +from .checkpoint import (CheckpointLoader, _load_checkpoint, + _load_checkpoint_with_prefix, load_checkpoint, + load_state_dict, save_checkpoint, weights_to_cpu) +from .default_constructor import DefaultRunnerConstructor +from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info, + init_dist, master_only) +from .epoch_based_runner import EpochBasedRunner, Runner +from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model + +from .hooks import (HOOKS, CheckpointHook, ClearMLLoggerHook, ClosureHook, + DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook, + EMAHook, EvalHook, Fp16OptimizerHook, + GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, Hook, IterTimerHook, + LoggerHook, MlflowLoggerHook, NeptuneLoggerHook, + OptimizerHook, PaviLoggerHook, SegmindLoggerHook, + SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) +from .hooks.lr_updater import StepLrUpdaterHook # noqa +from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook, + CosineRestartLrUpdaterHook, CyclicLrUpdaterHook, + ExpLrUpdaterHook, FixedLrUpdaterHook, + FlatCosineAnnealingLrUpdaterHook, + InvLrUpdaterHook, LinearAnnealingLrUpdaterHook, + LrUpdaterHook, OneCycleLrUpdaterHook, + PolyLrUpdaterHook) +from .hooks.momentum_updater import (CosineAnnealingMomentumUpdaterHook, + CyclicMomentumUpdaterHook, + LinearAnnealingMomentumUpdaterHook, + MomentumUpdaterHook, + OneCycleMomentumUpdaterHook, + StepMomentumUpdaterHook) +from .iter_based_runner import IterBasedRunner, IterLoader +from .log_buffer import LogBuffer +from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS, + DefaultOptimizerConstructor, build_optimizer, + build_optimizer_constructor) +from .priority import Priority, get_priority +from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed + +# initialize ipu to registor ipu runner to RUNNERS +# from dbnet_cv.device import ipu # isort:skip # noqa + +# __all__ = [ +# 'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer', +# 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', +# 'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook', +# 'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook', +# 'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook', +# 'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'MomentumUpdaterHook', +# 'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook', +# 'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook', +# 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', +# 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', +# 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook', +# 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict', +# 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority', +# 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict', +# 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS', +# 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer', +# 'build_optimizer_constructor', 'IterLoader', 'set_random_seed', +# 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook', +# 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', +# 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', +# '_load_checkpoint_with_prefix', 
'EvalHook', 'DistEvalHook', 'Sequential', +# 'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook', +# 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor', +# 'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook', +# 'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_module.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_module.py new file mode 100755 index 00000000..2c800eae --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_module.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from abc import ABCMeta +from collections import defaultdict +from logging import FileHandler +from typing import Iterable, Optional + +import torch.nn as nn + +from dbnet_cv.runner.dist_utils import master_only +from dbnet_cv.utils.logging import get_logger, logger_initialized, print_log + + +class BaseModule(nn.Module, metaclass=ABCMeta): + """Base module for all modules in openmmlab. + + ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional + functionality of parameter initialization. Compared with + ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes. + + - ``init_cfg``: the config to control the initialization. + - ``init_weights``: The function of parameter initialization and recording + initialization information. + - ``_params_init_info``: Used to track the parameter initialization + information. This attribute only exists during executing the + ``init_weights``. + + Args: + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, init_cfg: Optional[dict] = None): + """Initialize BaseModule, inherited from `torch.nn.Module`""" + + # NOTE init_cfg can be defined in different levels, but init_cfg + # in low levels has a higher priority. + + super().__init__() + # define default value of init_cfg instead of hard code + # in init_weights() function + self._is_init = False + + self.init_cfg = copy.deepcopy(init_cfg) + + # Backward compatibility in derived classes + # if pretrained is not None: + # warnings.warn('DeprecationWarning: pretrained is a deprecated \ + # key, please consider using init_cfg') + # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + @property + def is_init(self): + return self._is_init + + def init_weights(self): + """Initialize the weights.""" + + is_top_level_module = False + # check if it is top-level module + if not hasattr(self, '_params_init_info'): + # The `_params_init_info` is used to record the initialization + # information of the parameters + # the key should be the obj:`nn.Parameter` of model and the value + # should be a dict containing + # - init_info (str): The string that describes the initialization. + # - tmp_mean_value (FloatTensor): The mean of the parameter, + # which indicates whether the parameter has been modified. + # this attribute would be deleted after all parameters + # is initialized. 
+ self._params_init_info = defaultdict(dict) + is_top_level_module = True + + # Initialize the `_params_init_info`, + # When detecting the `tmp_mean_value` of + # the corresponding parameter is changed, update related + # initialization information + for name, param in self.named_parameters(): + self._params_init_info[param][ + 'init_info'] = f'The value is the same before and ' \ + f'after calling `init_weights` ' \ + f'of {self.__class__.__name__} ' + self._params_init_info[param][ + 'tmp_mean_value'] = param.data.mean() + + # pass `params_init_info` to all submodules + # All submodules share the same `params_init_info`, + # so it will be updated when parameters are + # modified at any level of the model. + for sub_module in self.modules(): + sub_module._params_init_info = self._params_init_info + + # Get the initialized logger, if not exist, + # create a logger named `dbnet_cv` + logger_names = list(logger_initialized.keys()) + logger_name = logger_names[0] if logger_names else 'dbnet_cv' + + from ..cnn import initialize + from ..cnn.utils.weight_init import update_init_info + module_name = self.__class__.__name__ + if not self._is_init: + if self.init_cfg: + print_log( + f'initialize {module_name} with init_cfg {self.init_cfg}', + logger=logger_name) + initialize(self, self.init_cfg) + if isinstance(self.init_cfg, dict): + # prevent the parameters of + # the pre-trained model + # from being overwritten by + # the `init_weights` + if self.init_cfg['type'] == 'Pretrained': + return + + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights() + # users may overload the `init_weights` + update_init_info( + m, + init_info=f'Initialized by ' + f'user-defined `init_weights`' + f' in {m.__class__.__name__} ') + + self._is_init = True + else: + warnings.warn(f'init_weights of {self.__class__.__name__} has ' + f'been called more than once.') + + if is_top_level_module: + self._dump_init_info(logger_name) + + for sub_module in self.modules(): + del sub_module._params_init_info + + @master_only + def _dump_init_info(self, logger_name: str): + """Dump the initialization information to a file named + `initialization.log.json` in workdir. + + Args: + logger_name (str): The name of logger. + """ + + logger = get_logger(logger_name) + + with_file_handler = False + # dump the information to the logger file if there is a `FileHandler` + for handler in logger.handlers: + if isinstance(handler, FileHandler): + handler.stream.write( + 'Name of parameter - Initialization information\n') + for name, param in self.named_parameters(): + handler.stream.write( + f'\n{name} - {param.shape}: ' + f"\n{self._params_init_info[param]['init_info']} \n") + handler.stream.flush() + with_file_handler = True + if not with_file_handler: + for name, param in self.named_parameters(): + print_log( + f'\n{name} - {param.shape}: ' + f"\n{self._params_init_info[param]['init_info']} \n ", + logger=logger_name) + + def __repr__(self): + s = super().__repr__() + if self.init_cfg: + s += f'\ninit_cfg={self.init_cfg}' + return s + + +class Sequential(BaseModule, nn.Sequential): + """Sequential module in openmmlab. + + Args: + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, *args, init_cfg: Optional[dict] = None): + BaseModule.__init__(self, init_cfg) + nn.Sequential.__init__(self, *args) + + +class ModuleList(BaseModule, nn.ModuleList): + """ModuleList in openmmlab. + + Args: + modules (iterable, optional): an iterable of modules to add. 
+ init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, + modules: Optional[Iterable] = None, + init_cfg: Optional[dict] = None): + BaseModule.__init__(self, init_cfg) + nn.ModuleList.__init__(self, modules) + + +class ModuleDict(BaseModule, nn.ModuleDict): + """ModuleDict in openmmlab. + + Args: + modules (dict, optional): a mapping (dictionary) of (string: module) + or an iterable of key-value pairs of type (string, module). + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, + modules: Optional[dict] = None, + init_cfg: Optional[dict] = None): + BaseModule.__init__(self, init_cfg) + nn.ModuleDict.__init__(self, modules) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_runner.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_runner.py new file mode 100755 index 00000000..0ccc20dc --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/base_runner.py @@ -0,0 +1,544 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging +import os.path as osp +import warnings +from abc import ABCMeta, abstractmethod + +import torch +from torch.optim import Optimizer + +import dbnet_cv +from ..parallel import is_module_wrapper +from .checkpoint import load_checkpoint +from .dist_utils import get_dist_info +from .hooks import HOOKS, Hook +from .log_buffer import LogBuffer +from .priority import Priority, get_priority +from .utils import get_time_str + + +class BaseRunner(metaclass=ABCMeta): + """The base class of Runner, a training helper for PyTorch. + + All subclasses should implement the following APIs: + + - ``run()`` + - ``train()`` + - ``val()`` + - ``save_checkpoint()`` + + Args: + model (:obj:`torch.nn.Module`): The model to be run. + batch_processor (callable): A callable method that process a data + batch. The interface of this method should be + `batch_processor(model, data, train_mode) -> dict` + optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an + optimizer (in most cases) or a dict of optimizers (in models that + requires more than one optimizer, e.g., GAN). + work_dir (str, optional): The working directory to save checkpoints + and logs. Defaults to None. + logger (:obj:`logging.Logger`): Logger used during training. + Defaults to None. (The default value is just for backward + compatibility) + meta (dict | None): A dict records some import information such as + environment info and seed, which will be logged in logger hook. + Defaults to None. + max_epochs (int, optional): Total training epochs. + max_iters (int, optional): Total training iterations. + """ + + def __init__(self, + model, + batch_processor=None, + optimizer=None, + work_dir=None, + logger=None, + meta=None, + max_iters=None, + max_epochs=None): + if batch_processor is not None: + if not callable(batch_processor): + raise TypeError('batch_processor must be callable, ' + f'but got {type(batch_processor)}') + warnings.warn( + 'batch_processor is deprecated, please implement ' + 'train_step() and val_step() in the model instead.', + DeprecationWarning) + # raise an error is `batch_processor` is not None and + # `model.train_step()` exists. 
+ if is_module_wrapper(model): + _model = model.module + else: + _model = model + if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'): + raise RuntimeError( + 'batch_processor and model.train_step()/model.val_step() ' + 'cannot be both available.') + else: + assert hasattr(model, 'train_step') + + # check the type of `optimizer` + if isinstance(optimizer, dict): + for name, optim in optimizer.items(): + if not isinstance(optim, Optimizer): + raise TypeError( + f'optimizer must be a dict of torch.optim.Optimizers, ' + f'but optimizer["{name}"] is a {type(optim)}') + elif not isinstance(optimizer, Optimizer) and optimizer is not None: + raise TypeError( + f'optimizer must be a torch.optim.Optimizer object ' + f'or dict or None, but got {type(optimizer)}') + + # check the type of `logger` + if not isinstance(logger, logging.Logger): + raise TypeError(f'logger must be a logging.Logger object, ' + f'but got {type(logger)}') + + # check the type of `meta` + if meta is not None and not isinstance(meta, dict): + raise TypeError( + f'meta must be a dict or None, but got {type(meta)}') + + self.model = model + self.batch_processor = batch_processor + self.optimizer = optimizer + self.logger = logger + self.meta = meta + # create work_dir + if dbnet_cv.is_str(work_dir): + self.work_dir = osp.abspath(work_dir) + dbnet_cv.mkdir_or_exist(self.work_dir) + elif work_dir is None: + self.work_dir = None + else: + raise TypeError('"work_dir" must be a str or None') + + # get model name from the model class + if hasattr(self.model, 'module'): + self._model_name = self.model.module.__class__.__name__ + else: + self._model_name = self.model.__class__.__name__ + + self._rank, self._world_size = get_dist_info() + self.timestamp = get_time_str() + self.mode = None + self._hooks = [] + self._epoch = 0 + self._iter = 0 + self._inner_iter = 0 + + if max_epochs is not None and max_iters is not None: + raise ValueError( + 'Only one of `max_epochs` or `max_iters` can be set.') + + self._max_epochs = max_epochs + self._max_iters = max_iters + # TODO: Redesign LogBuffer, it is not flexible and elegant enough + self.log_buffer = LogBuffer() + + @property + def model_name(self): + """str: Name of the model, usually the module class name.""" + return self._model_name + + @property + def rank(self): + """int: Rank of current process. (distributed training)""" + return self._rank + + @property + def world_size(self): + """int: Number of processes participating in the job. + (distributed training)""" + return self._world_size + + @property + def hooks(self): + """list[:obj:`Hook`]: A list of registered hooks.""" + return self._hooks + + @property + def epoch(self): + """int: Current epoch.""" + return self._epoch + + @property + def iter(self): + """int: Current iteration.""" + return self._iter + + @property + def inner_iter(self): + """int: Iteration in an epoch.""" + return self._inner_iter + + @property + def max_epochs(self): + """int: Maximum training epochs.""" + return self._max_epochs + + @property + def max_iters(self): + """int: Maximum training iterations.""" + return self._max_iters + + @abstractmethod + def train(self): + pass + + @abstractmethod + def val(self): + pass + + @abstractmethod + def run(self, data_loaders, workflow, **kwargs): + pass + + @abstractmethod + def save_checkpoint(self, + out_dir, + filename_tmpl, + save_optimizer=True, + meta=None, + create_symlink=True): + pass + + def current_lr(self): + """Get current learning rates. 
+ + Returns: + list[float] | dict[str, list[float]]: Current learning rates of all + param groups. If the runner has a dict of optimizers, this method + will return a dict. + """ + if isinstance(self.optimizer, torch.optim.Optimizer): + lr = [group['lr'] for group in self.optimizer.param_groups] + elif isinstance(self.optimizer, dict): + lr = dict() + for name, optim in self.optimizer.items(): + lr[name] = [group['lr'] for group in optim.param_groups] + else: + raise RuntimeError( + 'lr is not applicable because optimizer does not exist.') + return lr + + def current_momentum(self): + """Get current momentums. + + Returns: + list[float] | dict[str, list[float]]: Current momentums of all + param groups. If the runner has a dict of optimizers, this method + will return a dict. + """ + + def _get_momentum(optimizer): + momentums = [] + for group in optimizer.param_groups: + if 'momentum' in group.keys(): + momentums.append(group['momentum']) + elif 'betas' in group.keys(): + momentums.append(group['betas'][0]) + else: + momentums.append(0) + return momentums + + if self.optimizer is None: + raise RuntimeError( + 'momentum is not applicable because optimizer does not exist.') + elif isinstance(self.optimizer, torch.optim.Optimizer): + momentums = _get_momentum(self.optimizer) + elif isinstance(self.optimizer, dict): + momentums = dict() + for name, optim in self.optimizer.items(): + momentums[name] = _get_momentum(optim) + return momentums + + def register_hook(self, hook, priority='NORMAL'): + """Register a hook into the hook list. + + The hook will be inserted into a priority queue, with the specified + priority (See :class:`Priority` for details of priorities). + For hooks with the same priority, they will be triggered in the same + order as they are registered. + + Args: + hook (:obj:`Hook`): The hook to be registered. + priority (int or str or :obj:`Priority`): Hook priority. + Lower value means higher priority. + """ + assert isinstance(hook, Hook) + if hasattr(hook, 'priority'): + raise ValueError('"priority" is a reserved attribute for hooks') + priority = get_priority(priority) + hook.priority = priority + # insert the hook to a sorted list + inserted = False + for i in range(len(self._hooks) - 1, -1, -1): + if priority >= self._hooks[i].priority: + self._hooks.insert(i + 1, hook) + inserted = True + break + if not inserted: + self._hooks.insert(0, hook) + + def register_hook_from_cfg(self, hook_cfg): + """Register a hook from its cfg. + + Args: + hook_cfg (dict): Hook config. It should have at least keys 'type' + and 'priority' indicating its type and priority. + + Note: + The specific hook class to register should not use 'type' and + 'priority' arguments during initialization. + """ + hook_cfg = hook_cfg.copy() + priority = hook_cfg.pop('priority', 'NORMAL') + hook = dbnet_cv.build_from_cfg(hook_cfg, HOOKS) + self.register_hook(hook, priority=priority) + + def call_hook(self, fn_name): + """Call all hooks. + + Args: + fn_name (str): The function name in each hook to be called, such as + "before_train_epoch". 
+ """ + for hook in self._hooks: + getattr(hook, fn_name)(self) + + def get_hook_info(self): + # Get hooks info in each stage + stage_hook_map = {stage: [] for stage in Hook.stages} + for hook in self.hooks: + try: + priority = Priority(hook.priority).name + except ValueError: + priority = hook.priority + classname = hook.__class__.__name__ + hook_info = f'({priority:<12}) {classname:<35}' + for trigger_stage in hook.get_triggered_stages(): + stage_hook_map[trigger_stage].append(hook_info) + + stage_hook_infos = [] + for stage in Hook.stages: + hook_infos = stage_hook_map[stage] + if len(hook_infos) > 0: + info = f'{stage}:\n' + info += '\n'.join(hook_infos) + info += '\n -------------------- ' + stage_hook_infos.append(info) + return '\n'.join(stage_hook_infos) + + def load_checkpoint(self, + filename, + map_location='cpu', + strict=False, + revise_keys=[(r'^module.', '')]): + return load_checkpoint( + self.model, + filename, + map_location, + strict, + self.logger, + revise_keys=revise_keys) + + def resume(self, + checkpoint, + resume_optimizer=True, + map_location='default'): + if map_location == 'default': + if torch.cuda.is_available(): + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint(checkpoint) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + if self.meta is None: + self.meta = {} + self.meta.setdefault('hook_msgs', {}) + # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages + self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {})) + + # Re-calculate the number of iterations when resuming + # models with different number of GPUs + if 'config' in checkpoint['meta']: + config = dbnet_cv.Config.fromstring( + checkpoint['meta']['config'], file_format='.py') + previous_gpu_ids = config.get('gpu_ids', None) + if previous_gpu_ids and len(previous_gpu_ids) > 0 and len( + previous_gpu_ids) != self.world_size: + self._iter = int(self._iter * len(previous_gpu_ids) / + self.world_size) + self.logger.info('the iteration number is changed due to ' + 'change of GPU number') + + # resume meta information meta + self.meta = checkpoint['meta'] + + if 'optimizer' in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict( + checkpoint['optimizer'][k]) + else: + raise TypeError( + 'Optimizer should be dict or torch.optim.Optimizer ' + f'but got {type(self.optimizer)}') + + self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter) + + def register_lr_hook(self, lr_config): + if lr_config is None: + return + elif isinstance(lr_config, dict): + assert 'policy' in lr_config + policy_type = lr_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of Lr updater. + # Since this is not applicable for ` + # CosineAnnealingLrUpdater`, + # the string will not be changed if it contains capital letters. 
+ if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'LrUpdaterHook' + lr_config['type'] = hook_type + hook = dbnet_cv.build_from_cfg(lr_config, HOOKS) + else: + hook = lr_config + self.register_hook(hook, priority='VERY_HIGH') + + def register_momentum_hook(self, momentum_config): + if momentum_config is None: + return + if isinstance(momentum_config, dict): + assert 'policy' in momentum_config + policy_type = momentum_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of momentum updater. + # Since this is not applicable for + # `CosineAnnealingMomentumUpdater`, + # the string will not be changed if it contains capital letters. + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'MomentumUpdaterHook' + momentum_config['type'] = hook_type + hook = dbnet_cv.build_from_cfg(momentum_config, HOOKS) + else: + hook = momentum_config + self.register_hook(hook, priority='HIGH') + + def register_optimizer_hook(self, optimizer_config): + if optimizer_config is None: + return + if isinstance(optimizer_config, dict): + optimizer_config.setdefault('type', 'OptimizerHook') + hook = dbnet_cv.build_from_cfg(optimizer_config, HOOKS) + else: + hook = optimizer_config + self.register_hook(hook, priority='ABOVE_NORMAL') + + def register_checkpoint_hook(self, checkpoint_config): + if checkpoint_config is None: + return + if isinstance(checkpoint_config, dict): + checkpoint_config.setdefault('type', 'CheckpointHook') + hook = dbnet_cv.build_from_cfg(checkpoint_config, HOOKS) + else: + hook = checkpoint_config + self.register_hook(hook, priority='NORMAL') + + def register_logger_hooks(self, log_config): + if log_config is None: + return + log_interval = log_config['interval'] + for info in log_config['hooks']: + logger_hook = dbnet_cv.build_from_cfg( + info, HOOKS, default_args=dict(interval=log_interval)) + self.register_hook(logger_hook, priority='VERY_LOW') + + def register_timer_hook(self, timer_config): + if timer_config is None: + return + if isinstance(timer_config, dict): + timer_config_ = copy.deepcopy(timer_config) + hook = dbnet_cv.build_from_cfg(timer_config_, HOOKS) + else: + hook = timer_config + self.register_hook(hook, priority='LOW') + + def register_custom_hooks(self, custom_config): + if custom_config is None: + return + + if not isinstance(custom_config, list): + custom_config = [custom_config] + + for item in custom_config: + if isinstance(item, dict): + self.register_hook_from_cfg(item) + else: + self.register_hook(item, priority='NORMAL') + + def register_profiler_hook(self, profiler_config): + if profiler_config is None: + return + if isinstance(profiler_config, dict): + profiler_config.setdefault('type', 'ProfilerHook') + hook = dbnet_cv.build_from_cfg(profiler_config, HOOKS) + else: + hook = profiler_config + self.register_hook(hook) + + def register_training_hooks(self, + lr_config, + optimizer_config=None, + checkpoint_config=None, + log_config=None, + momentum_config=None, + timer_config=dict(type='IterTimerHook'), + custom_hooks_config=None): + """Register default and custom hooks for training. 
+ + Default and custom hooks include: + + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | LrUpdaterHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | MomentumUpdaterHook | HIGH (30) | + +----------------------+-------------------------+ + | OptimizerStepperHook | ABOVE_NORMAL (40) | + +----------------------+-------------------------+ + | CheckpointSaverHook | NORMAL (50) | + +----------------------+-------------------------+ + | IterTimerHook | LOW (70) | + +----------------------+-------------------------+ + | LoggerHook(s) | VERY_LOW (90) | + +----------------------+-------------------------+ + | CustomHook(s) | defaults to NORMAL (50) | + +----------------------+-------------------------+ + + If custom hooks have same priority with default hooks, custom hooks + will be triggered after default hooks. + """ + self.register_lr_hook(lr_config) + self.register_momentum_hook(momentum_config) + self.register_optimizer_hook(optimizer_config) + self.register_checkpoint_hook(checkpoint_config) + self.register_timer_hook(timer_config) + self.register_logger_hooks(log_config) + self.register_custom_hooks(custom_hooks_config) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/builder.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/builder.py new file mode 100755 index 00000000..008da32a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/builder.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional + +from ..utils import Registry + +RUNNERS = Registry('runner') +RUNNER_BUILDERS = Registry('runner builder') + + +def build_runner_constructor(cfg: dict): + return RUNNER_BUILDERS.build(cfg) + + +def build_runner(cfg: dict, default_args: Optional[dict] = None): + runner_cfg = copy.deepcopy(cfg) + constructor_type = runner_cfg.pop('constructor', + 'DefaultRunnerConstructor') + runner_constructor = build_runner_constructor( + dict( + type=constructor_type, + runner_cfg=runner_cfg, + default_args=default_args)) + runner = runner_constructor() + return runner diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/checkpoint.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/checkpoint.py new file mode 100755 index 00000000..83a4796e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/checkpoint.py @@ -0,0 +1,800 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
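
`register_training_hooks` above turns each config into the corresponding hook and registers it at the priority listed in its table, so LR updates always run before the optimizer step and loggers always fire last within a stage. Roughly, an already-constructed runner gets wired up like the following; the config values are placeholders, and each hook type must of course be registered in this vendored copy:

    # Placeholder configs, only to show where each entry lands in the priority table.
    runner.register_training_hooks(
        lr_config=dict(policy='poly', power=0.9),      # -> PolyLrUpdaterHook, VERY_HIGH
        optimizer_config=dict(grad_clip=None),         # -> OptimizerHook, ABOVE_NORMAL
        checkpoint_config=dict(interval=100),          # -> CheckpointHook, NORMAL
        log_config=dict(
            interval=5,
            hooks=[dict(type='TextLoggerHook')]))      # logger hooks, VERY_LOW
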
+import io +import logging +import os +import os.path as osp +import pkgutil +import re +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torchvision +from torch.optim import Optimizer + +import dbnet_cv +from ..fileio import FileClient +from ..fileio import load as load_file +from ..parallel import is_module_wrapper +from ..utils import digit_version, load_url, mkdir_or_exist +from .dist_utils import get_dist_info + +ENV_DBNET_CV_HOME = 'DBNET_CV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_dbnet_cv_home(): + dbnet_cv_home = os.path.expanduser( + os.getenv( + ENV_DBNET_CV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'dbnet_cv'))) + + mkdir_or_exist(dbnet_cv_home) + return dbnet_cv_home + + +def load_state_dict(module: torch.nn.Module, + state_dict: Union[dict, OrderedDict], + strict: bool = False, + logger: Optional[logging.Logger] = None) -> None: + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys: List = [] + all_missing_keys: List = [] + err_msg: List = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata # type: ignore + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + # break load->load reference cycle + load = None # type: ignore + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) # type: ignore + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def get_torchvision_models(): + if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): + model_urls = dict() + # When the version of torchvision is lower than 
0.13, the model url is + # not declared in `torchvision.model.__init__.py`, so we need to + # iterate through `torchvision.models.__path__` to get the url for each + # model. + for _, name, ispkg in pkgutil.walk_packages( + torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + else: + # Since torchvision bumps to v0.13, the weight loading logic, + # model keys and model urls have been changed. Here the URLs of old + # version is loaded to avoid breaking back compatibility. If the + # torchvision version>=0.13.0, new URLs will be added. Users can get + # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', + # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. + json_path = osp.join(dbnet_cv.__path__[0], + 'model_zoo/torchvision_0.12.json') + model_urls = dbnet_cv.load(json_path) + for cls_name, cls in torchvision.models.__dict__.items(): + # The name of torchvision model weights classes ends with + # `_Weights` such as `ResNet18_Weights`. However, some model weight + # classes, such as `MNASNet0_75_Weights` does not have any urls in + # torchvision 0.13.0 and cannot be iterated. Here we simply check + # `DEFAULT` attribute to ensure the class is not empty. + if (not cls_name.endswith('_Weights') + or not hasattr(cls, 'DEFAULT')): + continue + # Since `cls.DEFAULT` can not be accessed by iterating cls, we set + # default urls explicitly. + cls_key = cls_name.replace('_Weights', '').lower() + model_urls[f'{cls_key}.default'] = cls.DEFAULT.url + for weight_enum in cls: + cls_key = cls_name.replace('_Weights', '').lower() + cls_key = f'{cls_key}.{weight_enum.name.lower()}' + model_urls[cls_key] = weight_enum.url + + return model_urls + + +def get_external_models(): + dbnet_cv_home = _get_dbnet_cv_home() + default_json_path = osp.join(dbnet_cv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(dbnet_cv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(dbnet_cv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(dbnet_cv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + # Some checkpoints converted from 3rd-party repo don't + # have the "state_dict" key. 
+ state_dict = checkpoint + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +class CheckpointLoader: + """A general checkpoint loader to manage all schemes.""" + + _schemes: dict = {} + + @classmethod + def _register_scheme(cls, prefixes, loader, force=False): + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if (prefix not in cls._schemes) or force: + cls._schemes[prefix] = loader + else: + raise KeyError( + f'{prefix} is already registered as a loader backend, ' + 'add "force=True" if you want to override it') + # sort, longer prefixes take priority + cls._schemes = OrderedDict( + sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) + + @classmethod + def register_scheme(cls, + prefixes: Union[str, Sequence[str]], + loader: Optional[Callable] = None, + force: bool = False): + """Register a loader to CheckpointLoader. + + This method can be used as a normal class method or a decorator. + + Args: + prefixes (str or Sequence[str]): + The prefix of the registered loader. + loader (function, optional): The loader function to be registered. + When this method is used as a decorator, loader is None. + Defaults to None. + force (bool, optional): Whether to override the loader + if the prefix has already been registered. Defaults to False. + """ + + if loader is not None: + cls._register_scheme(prefixes, loader, force=force) + return + + def _register(loader_cls): + cls._register_scheme(prefixes, loader_cls, force=force) + return loader_cls + + return _register + + @classmethod + def _get_checkpoint_loader(cls, path): + """Finds a loader that supports the given path. Falls back to the local + loader if no other loader is found. + + Args: + path (str): checkpoint path + + Returns: + callable: checkpoint loader + """ + for p in cls._schemes: + # use regular match to handle some cases that where the prefix of + # loader has a prefix. For example, both 's3://path' and + # 'open-mmlab:s3://path' should return `load_from_ceph` + if re.match(p, path) is not None: + return cls._schemes[p] + + @classmethod + def load_checkpoint( + cls, + filename: str, + map_location: Optional[str] = None, + logger: Optional[logging.Logger] = None + ) -> Union[dict, OrderedDict]: + """load checkpoint through URL scheme path. + + Args: + filename (str): checkpoint file name with given prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + logger (:mod:`logging.Logger`, optional): The logger for message. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint_loader = cls._get_checkpoint_loader(filename) + class_name = checkpoint_loader.__name__ + dbnet_cv.print_log( + f'load checkpoint from {class_name[10:]} path: {filename}', logger) + return checkpoint_loader(filename, map_location) + + +@CheckpointLoader.register_scheme(prefixes='') +def load_from_local( + filename: str, + map_location: Optional[str] = None) -> Union[dict, OrderedDict]: + """load checkpoint by local file path. + + Args: + filename (str): local checkpoint file path + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
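
`CheckpointLoader` picks a loader by matching the path against the registered prefixes (longest first), and `register_scheme` doubles as a decorator for plugging in new storage backends. A hedged sketch, where the `myfs://` scheme and its loader are invented purely for illustration:

    import torch
    from dbnet_cv.runner.checkpoint import CheckpointLoader

    @CheckpointLoader.register_scheme(prefixes='myfs://')
    def load_from_myfs(filename, map_location=None):
        # Strip the invented scheme and fall back to a plain local read.
        return torch.load(filename[len('myfs://'):], map_location=map_location)

    # load_checkpoint()/resume() now route 'myfs://...' paths through the new loader.
    ckpt = CheckpointLoader.load_checkpoint('myfs:///tmp/epoch_1200.pth')
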
+ """ + filename = osp.expanduser(filename) + if not osp.isfile(filename): + raise FileNotFoundError(f'{filename} can not be found.') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) +def load_from_http( + filename: str, + map_location: Optional[str] = None, + model_dir: Optional[str] = None) -> Union[dict, OrderedDict]: + """load checkpoint through HTTP or HTTPS scheme path. In distributed + setting, this function only download checkpoint at local rank 0. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + model_dir (str, optional): directory in which to save the object, + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + rank, world_size = get_dist_info() + if rank == 0: + checkpoint = load_url( + filename, model_dir=model_dir, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = load_url( + filename, model_dir=model_dir, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='pavi://') +def load_from_pavi( + filename: str, + map_location: Optional[str] = None) -> Union[dict, OrderedDict]: + """load checkpoint through the file path prefixed with pavi. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Args: + filename (str): checkpoint file path with pavi prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + assert filename.startswith('pavi://'), \ + f'Expected filename startswith `pavi://`, but get {filename}' + model_path = filename[7:] + + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://') +def load_from_ceph(filename: str, + map_location: Optional[str] = None, + backend: str = 'petrel') -> Union[dict, OrderedDict]: + """load checkpoint through the file path prefixed with s3. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Note: + Since v1.4.1, the registered scheme prefixes have been enhanced to + support bucket names in the path prefix, e.g. 's3://xx.xx/xx.path', + 'bucket1:s3://xx.xx/xx.path'. + + Args: + filename (str): checkpoint file path with s3 prefix + map_location (str, optional): Same as :func:`torch.load`. + backend (str): The storage backend type. Options are 'ceph', + 'petrel'. Default: 'petrel'. + + .. warning:: + :class:`dbnet_cv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`dbnet_cv.fileio.file_client.PetrelBackend` instead. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
+ """ + allowed_backends = ['ceph', 'petrel'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + + if backend == 'ceph': + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead', + DeprecationWarning) + + # CephClient and PetrelBackend have the same prefix 's3://' and the latter + # will be chosen as default. If PetrelBackend can not be instantiated + # successfully, the CephClient will be chosen. + try: + file_client = FileClient(backend=backend) + except ImportError: + allowed_backends.remove(backend) + file_client = FileClient(backend=allowed_backends[0]) + + with io.BytesIO(file_client.get(filename)) as buffer: + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) +def load_from_torchvision( + filename: str, + map_location: Optional[str] = None) -> Union[dict, OrderedDict]: + """load checkpoint through the file path prefixed with modelzoo or + torchvision. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + model_urls = get_torchvision_models() + if filename.startswith('modelzoo://'): + warnings.warn( + 'The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead', DeprecationWarning) + model_name = filename[11:] + else: + model_name = filename[14:] + + # Support getting model urls in the same way as torchvision + # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to + # resnet50.imagenet1k_v1. + model_name = model_name.lower().replace('_weights', '') + return load_from_http(model_urls[model_name], map_location=map_location) + + +@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) +def load_from_openmmlab( + filename: str, + map_location: Optional[str] = None) -> Union[dict, OrderedDict]: + """load checkpoint through the file path prefixed with open-mmlab or + openmmlab. + + Args: + filename (str): checkpoint file path with open-mmlab or + openmmlab prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_external_models() + prefix_str = 'open-mmlab://' + if filename.startswith(prefix_str): + model_name = filename[13:] + else: + model_name = filename[12:] + prefix_str = 'openmmlab://' + + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn( + f'{prefix_str}{model_name} is deprecated in favor ' + f'of {prefix_str}{deprecated_urls[model_name]}', + DeprecationWarning) + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_from_http(model_url, map_location=map_location) + else: + filename = osp.join(_get_dbnet_cv_home(), model_url) + if not osp.isfile(filename): + raise FileNotFoundError(f'{filename} can not be found.') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='mmcls://') +def load_from_mmcls( + filename: str, + map_location: Optional[str] = None) -> Union[dict, OrderedDict]: + """load checkpoint through the file path prefixed with mmcls. 
+ + Args: + filename (str): checkpoint file path with mmcls prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_from_http( + model_urls[model_name], map_location=map_location) + checkpoint = _process_mmcls_checkpoint(checkpoint) + return checkpoint + + +def _load_checkpoint( + filename: str, + map_location: Optional[str] = None, + logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]: + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str, optional): Same as :func:`torch.load`. + Default: None. + logger (:mod:`logging.Logger`, optional): The logger for error message. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + return CheckpointLoader.load_checkpoint(filename, map_location, logger) + + +def _load_checkpoint_with_prefix( + prefix: str, + filename: str, + map_location: Optional[str] = None) -> Union[dict, OrderedDict]: + """Load partial pretrained model with specific prefix. + + Args: + prefix (str): The prefix of sub-module. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint = _load_checkpoint(filename, map_location=map_location) + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if not prefix.endswith('.'): + prefix += '.' + prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict + + +def load_checkpoint( + model: torch.nn.Module, + filename: str, + map_location: Optional[str] = None, + strict: bool = False, + logger: Optional[logging.Logger] = None, + revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]: + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Default: strip + the prefix 'module.' by [(r'^module\\.', '')]. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
+ """ + checkpoint = _load_checkpoint(filename, map_location, logger) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + # strip prefix of state_dict + metadata = getattr(state_dict, '_metadata', OrderedDict()) + for p, r in revise_keys: + state_dict = OrderedDict( + {re.sub(p, r, k): v + for k, v in state_dict.items()}) + # Keep metadata in state_dict + state_dict._metadata = metadata + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict: + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + # Keep metadata in state_dict + state_dict_cpu._metadata = getattr( # type: ignore + state_dict, '_metadata', OrderedDict()) + return state_dict_cpu + + +def _save_to_state_dict(module: torch.nn.Module, destination: dict, + prefix: str, keep_vars: bool) -> None: + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module: torch.nn.Module, + destination: Optional[OrderedDict] = None, + prefix: str = '', + keep_vars: bool = False) -> OrderedDict: + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. 
+ """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() # type: ignore + destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination # type: ignore + + +def save_checkpoint(model: torch.nn.Module, + filename: str, + optimizer: Optional[Optimizer] = None, + meta: Optional[dict] = None, + file_client_args: Optional[dict] = None) -> None: + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(dbnet_cv_version=dbnet_cv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + if file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" if filename starts with' + f'"pavi://", but got {file_client_args}') + try: + from pavi import exception, modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except exception.NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + file_client = FileClient.infer_client(file_client_args, filename) + with io.BytesIO() as f: + torch.save(checkpoint, f) + file_client.put(f.getvalue(), filename) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/default_constructor.py 
b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/default_constructor.py new file mode 100755 index 00000000..d4ff4cdf --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/default_constructor.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +from .builder import RUNNER_BUILDERS, RUNNERS + + +@RUNNER_BUILDERS.register_module() +class DefaultRunnerConstructor: + """Default constructor for runners. + + Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`. + For example, We can inject some new properties and functions for `Runner`. + + Example: + >>> from dbnet_cv.runner import RUNNER_BUILDERS, build_runner + >>> # Define a new RunnerReconstructor + >>> @RUNNER_BUILDERS.register_module() + >>> class MyRunnerConstructor: + ... def __init__(self, runner_cfg, default_args=None): + ... if not isinstance(runner_cfg, dict): + ... raise TypeError('runner_cfg should be a dict', + ... f'but got {type(runner_cfg)}') + ... self.runner_cfg = runner_cfg + ... self.default_args = default_args + ... + ... def __call__(self): + ... runner = RUNNERS.build(self.runner_cfg, + ... default_args=self.default_args) + ... # Add new properties for existing runner + ... runner.my_name = 'my_runner' + ... runner.my_function = lambda self: print(self.my_name) + ... ... + >>> # build your runner + >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40, + ... constructor='MyRunnerConstructor') + >>> runner = build_runner(runner_cfg) + """ + + def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None): + if not isinstance(runner_cfg, dict): + raise TypeError('runner_cfg should be a dict', + f'but got {type(runner_cfg)}') + self.runner_cfg = runner_cfg + self.default_args = default_args + + def __call__(self): + return RUNNERS.build(self.runner_cfg, default_args=self.default_args) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/dist_utils.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/dist_utils.py new file mode 100755 index 00000000..103e1f08 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/dist_utils.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import functools +import os +import socket +import subprocess +from collections import OrderedDict +from typing import Callable, List, Optional, Tuple + +import torch +import torch.multiprocessing as mp +from torch import distributed as dist +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + +from dbnet_cv.utils import IS_MLU_AVAILABLE + + +def _find_free_port(): + # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(('', 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. 
b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/default_constructor.py
new file mode 100755
index 00000000..d4ff4cdf
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/default_constructor.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from .builder import RUNNER_BUILDERS, RUNNERS
+
+
+@RUNNER_BUILDERS.register_module()
+class DefaultRunnerConstructor:
+    """Default constructor for runners.
+
+    Customize an existing `Runner` like `EpochBasedRunner` through a `RunnerConstructor`.
+    For example, we can inject new properties and functions into the `Runner`.
+
+    Example:
+        >>> from dbnet_cv.runner import RUNNER_BUILDERS, build_runner
+        >>> # Define a new RunnerConstructor
+        >>> @RUNNER_BUILDERS.register_module()
+        >>> class MyRunnerConstructor:
+        ...     def __init__(self, runner_cfg, default_args=None):
+        ...         if not isinstance(runner_cfg, dict):
+        ...             raise TypeError('runner_cfg should be a dict',
+        ...                             f'but got {type(runner_cfg)}')
+        ...         self.runner_cfg = runner_cfg
+        ...         self.default_args = default_args
+        ...
+        ...     def __call__(self):
+        ...         runner = RUNNERS.build(self.runner_cfg,
+        ...                                default_args=self.default_args)
+        ...         # Add new properties for existing runner
+        ...         runner.my_name = 'my_runner'
+        ...         runner.my_function = lambda self: print(self.my_name)
+        ...         ...
+        >>> # build your runner
+        >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
+        ...     constructor='MyRunnerConstructor')
+        >>> runner = build_runner(runner_cfg)
+    """
+
+    def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
+        if not isinstance(runner_cfg, dict):
+            raise TypeError('runner_cfg should be a dict',
+                            f'but got {type(runner_cfg)}')
+        self.runner_cfg = runner_cfg
+        self.default_args = default_args
+
+    def __call__(self):
+        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/dist_utils.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/dist_utils.py
new file mode 100755
index 00000000..103e1f08
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/dist_utils.py
@@ -0,0 +1,209 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import os
+import socket
+import subprocess
+from collections import OrderedDict
+from typing import Callable, List, Optional, Tuple
+
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+from dbnet_cv.utils import IS_MLU_AVAILABLE
+
+
+def _find_free_port():
+    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    # Binding to port 0 will cause the OS to find an available port for us
+    sock.bind(('', 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    # NOTE: there is still a chance the port could be taken by other processes.
+ return port + + +def _is_free_port(port): + ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ips.append('localhost') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return all(s.connect_ex((ip, port)) != 0 for ip in ips) + + +def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend: str, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['RANK']) + if IS_MLU_AVAILABLE: + import torch_mlu # noqa: F401 + torch.mlu.set_device(rank) + dist.init_process_group( + backend='cncl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) + else: + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_mpi(backend: str, **kwargs): + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + torch.cuda.set_device(local_rank) + if 'MASTER_PORT' not in os.environ: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + if 'MASTER_ADDR' not in os.environ: + raise KeyError('The environment variable MASTER_ADDR is not set') + os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] + os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend: str, port: Optional[int] = None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
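
With the `pytorch` launcher, `init_dist` relies on the `RANK`/`WORLD_SIZE`/`MASTER_ADDR`/`MASTER_PORT` variables that `torch.distributed.launch`/`torchrun` export, so an entry point only needs to call it once before building the runner. A sketch assuming the mmcv-style exports from `dbnet_cv.runner`; the script is illustrative and the repo's own training tool remains authoritative:

    from dbnet_cv.runner import get_dist_info, init_dist

    def main():
        init_dist('pytorch', backend='nccl')  # reads RANK/WORLD_SIZE from the environment
        rank, world_size = get_dist_info()
        if rank == 0:
            print(f'initialized {world_size} processes')

    if __name__ == '__main__':
        main()
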
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # if torch.distributed default port(29500) is available + # then use it, else find a free port + if _is_free_port(29500): + os.environ['MASTER_PORT'] = '29500' + else: + os.environ['MASTER_PORT'] = str(_find_free_port()) + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info() -> Tuple[int, int]: + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func: Callable) -> Callable: + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def allreduce_params(params: List[torch.nn.Parameter], + coalesce: bool = True, + bucket_size_mb: int = -1) -> None: + """Allreduce parameters. + + Args: + params (list[torch.nn.Parameter]): List of parameters or buffers + of a model. + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + _, world_size = get_dist_info() + if world_size == 1: + return + params = [param.data for param in params] + if coalesce: + _allreduce_coalesced(params, world_size, bucket_size_mb) + else: + for tensor in params: + dist.all_reduce(tensor.div_(world_size)) + + +def allreduce_grads(params: List[torch.nn.Parameter], + coalesce: bool = True, + bucket_size_mb: int = -1) -> None: + """Allreduce gradients. + + Args: + params (list[torch.nn.Parameter]): List of parameters of a model. + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. 
+ """ + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + _, world_size = get_dist_info() + if world_size == 1: + return + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/epoch_based_runner.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/epoch_based_runner.py new file mode 100755 index 00000000..0abc601d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/epoch_based_runner.py @@ -0,0 +1,191 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import platform +import shutil +import time +import warnings + +import torch + +import dbnet_cv +from .base_runner import BaseRunner +from .builder import RUNNERS +from .checkpoint import save_checkpoint +from .utils import get_host_info + + +@RUNNERS.register_module() +class EpochBasedRunner(BaseRunner): + """Epoch-based Runner. + + This runner train models epoch by epoch. + """ + + def run_iter(self, data_batch, train_mode, **kwargs): + if self.batch_processor is not None: + outputs = self.batch_processor( + self.model, data_batch, train_mode=train_mode, **kwargs) + elif train_mode: + outputs = self.model.train_step(data_batch, self.optimizer, + **kwargs) + else: + outputs = self.model.val_step(data_batch, self.optimizer, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('"batch_processor()" or "model.train_step()"' + 'and "model.val_step()" must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + + def train(self, data_loader, **kwargs): + self.model.train() + self.mode = 'train' + self.data_loader = data_loader + self._max_iters = self._max_epochs * len(self.data_loader) + self.call_hook('before_train_epoch') + time.sleep(2) # Prevent possible deadlock during epoch transition + for i, data_batch in enumerate(self.data_loader): + self.data_batch = data_batch + self._inner_iter = i + self.call_hook('before_train_iter') + self.run_iter(data_batch, train_mode=True, **kwargs) + self.call_hook('after_train_iter') + del self.data_batch + self._iter += 1 + + self.call_hook('after_train_epoch') + self._epoch += 1 + + @torch.no_grad() + def val(self, data_loader, **kwargs): + self.model.eval() + self.mode = 'val' + self.data_loader = data_loader + self.call_hook('before_val_epoch') + time.sleep(2) # Prevent possible deadlock during epoch transition + for i, data_batch in enumerate(self.data_loader): + self.data_batch = data_batch + self._inner_iter = i + self.call_hook('before_val_iter') + self.run_iter(data_batch, train_mode=False) + self.call_hook('after_val_iter') + del self.data_batch + self.call_hook('after_val_epoch') + + def run(self, data_loaders, workflow, max_epochs=None, **kwargs): + """Start running. 
+ + Args: + data_loaders (list[:obj:`DataLoader`]): Dataloaders for training + and validation. + workflow (list[tuple]): A list of (phase, epochs) to specify the + running order and epochs. E.g, [('train', 2), ('val', 1)] means + running 2 epochs for training and 1 epoch for validation, + iteratively. + """ + assert isinstance(data_loaders, list) + assert dbnet_cv.is_list_of(workflow, tuple) + assert len(data_loaders) == len(workflow) + if max_epochs is not None: + warnings.warn( + 'setting max_epochs in run is deprecated, ' + 'please set max_epochs in runner_config', DeprecationWarning) + self._max_epochs = max_epochs + + assert self._max_epochs is not None, ( + 'max_epochs must be specified during instantiation') + + for i, flow in enumerate(workflow): + mode, epochs = flow + if mode == 'train': + self._max_iters = self._max_epochs * len(data_loaders[i]) + break + + work_dir = self.work_dir if self.work_dir is not None else 'NONE' + self.logger.info('Start running, host: %s, work_dir: %s', + get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) + self.logger.info('workflow: %s, max: %d epochs', workflow, + self._max_epochs) + self.call_hook('before_run') + + while self.epoch < self._max_epochs: + for i, flow in enumerate(workflow): + mode, epochs = flow + if isinstance(mode, str): # self.train() + if not hasattr(self, mode): + raise ValueError( + f'runner has no method named "{mode}" to run an ' + 'epoch') + epoch_runner = getattr(self, mode) + else: + raise TypeError( + 'mode in workflow must be a str, but got {}'.format( + type(mode))) + + for _ in range(epochs): + if mode == 'train' and self.epoch >= self._max_epochs: + break + epoch_runner(data_loaders[i], **kwargs) + + time.sleep(1) # wait for some hooks like loggers to finish + self.call_hook('after_run') + + def save_checkpoint(self, + out_dir, + filename_tmpl='epoch_{}.pth', + save_optimizer=True, + meta=None, + create_symlink=True): + """Save the checkpoint. + + Args: + out_dir (str): The directory that checkpoints are saved. + filename_tmpl (str, optional): The checkpoint filename template, + which contains a placeholder for the epoch number. + Defaults to 'epoch_{}.pth'. + save_optimizer (bool, optional): Whether to save the optimizer to + the checkpoint. Defaults to True. + meta (dict, optional): The meta information to be saved in the + checkpoint. Defaults to None. + create_symlink (bool, optional): Whether to create a symlink + "latest.pth" to point to the latest checkpoint. + Defaults to True. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + if self.meta is not None: + meta.update(self.meta) + # Note: meta.update(self.meta) should be done before + # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise + # there will be problems with resumed checkpoints. 
+ # More details in https://github.com/open-mmlab/dbnet_cv/pull/1108 + meta.update(epoch=self.epoch + 1, iter=self.iter) + + filename = filename_tmpl.format(self.epoch + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, 'latest.pth') + if platform.system() != 'Windows': + dbnet_cv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) + + +@RUNNERS.register_module() +class Runner(EpochBasedRunner): + """Deprecated name of EpochBasedRunner.""" + + def __init__(self, *args, **kwargs): + warnings.warn( + 'Runner was deprecated, please use EpochBasedRunner instead', + DeprecationWarning) + super().__init__(*args, **kwargs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/fp16_utils.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/fp16_utils.py new file mode 100755 index 00000000..34f42349 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/fp16_utils.py @@ -0,0 +1,435 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import warnings +from collections import abc +from inspect import getfullargspec +from typing import Callable, Iterable, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +from dbnet_cv.utils import TORCH_VERSION, digit_version +from .dist_utils import allreduce_grads as _allreduce_grads + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported + # and used; otherwise, auto fp16 will adopt dbnet_cv's implementation. + # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16 + # manually, so the behavior may not be consistent with real amp. + from torch.cuda.amp import autocast +except ImportError: + pass + + +def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype): + """Recursively convert Tensor in inputs from src_type to dst_type. + + Note: + In v1.4.4 and later, ``cast_tersor_type`` will only convert the + torch.Tensor which is consistent with ``src_type`` to the ``dst_type``. + Before v1.4.4, it ignores the ``src_type`` argument, leading to some + potential problems. For example, + ``cast_tensor_type(inputs, torch.float, torch.half)`` will convert all + tensors in inputs to ``torch.half`` including those originally in + ``torch.Int`` or other types, which is not expected. + + Args: + inputs: Inputs that to be casted. + src_type (torch.dtype): Source type.. + dst_type (torch.dtype): Destination type. + + Returns: + The same type with inputs, but all contained Tensors have been cast. + """ + if isinstance(inputs, nn.Module): + return inputs + elif isinstance(inputs, torch.Tensor): + # we need to ensure that the type of inputs to be casted are the same + # as the argument `src_type`. 
+ return inputs.to(dst_type) if inputs.dtype == src_type else inputs + elif isinstance(inputs, str): + return inputs + elif isinstance(inputs, np.ndarray): + return inputs + elif isinstance(inputs, abc.Mapping): + return type(inputs)({ # type: ignore + k: cast_tensor_type(v, src_type, dst_type) + for k, v in inputs.items() + }) + elif isinstance(inputs, abc.Iterable): + return type(inputs)( # type: ignore + cast_tensor_type(item, src_type, dst_type) for item in inputs) + else: + return inputs + + +def auto_fp16( + apply_to: Optional[Iterable] = None, + out_fp32: bool = False, + supported_types: tuple = (nn.Module, ), +) -> Callable: + """Decorator to enable fp16 training automatically. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If inputs arguments are fp32 tensors, they will + be converted to fp16 automatically. Arguments other than fp32 tensors are + ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original dbnet_cv implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp32 (bool): Whether to convert the output back to fp32. + supported_types (tuple): Classes can be decorated by ``auto_fp16``. + `New in version 1.5.0.` + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp16 + >>> @auto_fp16() + >>> def forward(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp16 + >>> @auto_fp16(apply_to=('pred', )) + >>> def do_something(self, pred, others): + >>> pass + """ + + def auto_fp16_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. 
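A quick editorial illustration of the `cast_tensor_type` behaviour implemented above, assuming the function is importable from `dbnet_cv.runner.fp16_utils`: only tensors whose dtype matches `src_type` are converted, and containers are rebuilt recursively.

```python
import torch
from dbnet_cv.runner.fp16_utils import cast_tensor_type  # assumed import path

batch = {'img': torch.zeros(2, 3, dtype=torch.float32),
         'label': torch.tensor([1, 0])}                   # int64, not float32
half_batch = cast_tensor_type(batch, torch.float, torch.half)
assert half_batch['img'].dtype == torch.half              # converted
assert half_batch['label'].dtype == torch.int64           # left untouched
```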
+ if not isinstance(args[0], supported_types): + raise TypeError('@auto_fp16 can only be used to decorate the ' + f'method of those classes {supported_types}') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + # NOTE: default args are not taken into consideration + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.float, torch.half)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = {} + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.float, torch.half) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if (TORCH_VERSION != 'parrots' and + digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + with autocast(enabled=True): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp32: + output = cast_tensor_type(output, torch.half, torch.float) + return output + + return new_func + + return auto_fp16_wrapper + + +def force_fp32(apply_to: Optional[Iterable] = None, + out_fp16: bool = False) -> Callable: + """Decorator to convert input arguments to fp32 in force. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If there are some inputs that must be processed + in fp32 mode, then this decorator can handle it. If inputs arguments are + fp16 tensors, they will be converted to fp32 automatically. Arguments other + than fp16 tensors are ignored. If you are using PyTorch >= 1.6, + torch.cuda.amp is used as the backend, otherwise, original dbnet_cv + implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp16 (bool): Whether to convert the output back to fp16. + + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp32 + >>> @force_fp32() + >>> def loss(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp32 + >>> @force_fp32(apply_to=('pred', )) + >>> def post_process(self, pred, others): + >>> pass + """ + + def force_fp32_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. 
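A small hedged example of decorating a module with `auto_fp16` as described above; the import path and the module itself are illustrative. Note that the decorator is inert until `fp16_enabled` is set to `True` on the instance, which `wrap_fp16_model()` (defined later in this file) takes care of.

```python
import torch
import torch.nn as nn
from dbnet_cv.runner.fp16_utils import auto_fp16  # assumed import path

class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.fp16_enabled = False   # wrap_fp16_model() flips this to True
        self.fc = nn.Linear(4, 2)

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        return self.fc(x)

head = TinyHead()
out = head(torch.randn(1, 4))       # fp16_enabled is False, so this stays fp32
```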
+ if not isinstance(args[0], torch.nn.Module): + raise TypeError('@force_fp32 can only be used to decorate the ' + 'method of nn.Module') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.half, torch.float)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = dict() + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.half, torch.float) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if (TORCH_VERSION != 'parrots' and + digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + with autocast(enabled=False): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp16: + output = cast_tensor_type(output, torch.float, torch.half) + return output + + return new_func + + return force_fp32_wrapper + + +def allreduce_grads(params: List[Parameter], + coalesce: bool = True, + bucket_size_mb: int = -1) -> None: + warnings.warn( + '"dbnet_cv.runner.fp16_utils.allreduce_grads" is deprecated, and will be ' + 'removed in v2.8. Please switch to "dbnet_cv.runner.allreduce_grads', + DeprecationWarning) + _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb) + + +def wrap_fp16_model(model: nn.Module) -> None: + """Wrap the FP32 model to FP16. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original dbnet_cv implementation will be adopted. + + For PyTorch >= 1.6, this function will + 1. Set fp16 flag inside the model to True. + + Otherwise: + 1. Convert FP32 model to FP16. + 2. Remain some necessary layers to be FP32, e.g., normalization layers. + 3. Set `fp16_enabled` flag inside the model to True. + + Args: + model (nn.Module): Model in FP32. + """ + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.6.0')): + # convert model to fp16 + model.half() + # patch the normalization layers to make it work in fp32 mode + patch_norm_fp32(model) + # set `fp16_enabled` flag + for m in model.modules(): + if hasattr(m, 'fp16_enabled'): + m.fp16_enabled = True + + +def patch_norm_fp32(module: nn.Module) -> nn.Module: + """Recursively convert normalization layers from FP16 to FP32. + + Args: + module (nn.Module): The modules to be converted in FP16. + + Returns: + nn.Module: The converted module, the normalization layers have been + converted to FP32. + """ + if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): + module.float() + if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3': + module.forward = patch_forward_method(module.forward, torch.half, + torch.float) + for child in module.children(): + patch_norm_fp32(child) + return module + + +def patch_forward_method(func: Callable, + src_type: torch.dtype, + dst_type: torch.dtype, + convert_output: bool = True) -> Callable: + """Patch the forward method of a module. 
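Before moving on, a brief editorial sketch of `wrap_fp16_model()` defined above (import path assumed): on PyTorch >= 1.6 it only switches on the `fp16_enabled` flags that modules declare, while on older versions it additionally halves the weights and patches normalization layers back to fp32 via `patch_norm_fp32()`.

```python
import torch.nn as nn
from dbnet_cv.runner.fp16_utils import wrap_fp16_model  # assumed import path

class FakeHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.fp16_enabled = False
        self.bn = nn.BatchNorm2d(8)

model = FakeHead()
wrap_fp16_model(model)
assert model.fp16_enabled   # decorated forwards will now run in mixed precision
```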
+ + Args: + func (callable): The original forward method. + src_type (torch.dtype): Type of input arguments to be converted from. + dst_type (torch.dtype): Type of input arguments to be converted to. + convert_output (bool): Whether to convert the output back to src_type. + + Returns: + callable: The patched forward method. + """ + + def new_forward(*args, **kwargs): + output = func(*cast_tensor_type(args, src_type, dst_type), + **cast_tensor_type(kwargs, src_type, dst_type)) + if convert_output: + output = cast_tensor_type(output, dst_type, src_type) + return output + + return new_forward + + +class LossScaler: + """Class that manages loss scaling in mixed precision training which + supports both dynamic or static mode. + + The implementation refers to + https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py. + Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling. + It's important to understand how :class:`LossScaler` operates. + Loss scaling is designed to combat the problem of underflowing + gradients encountered at long times when training fp16 networks. + Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. + If overflowing gradients are encountered, :class:`FP16_Optimizer` then + skips the update step for this particular iteration/minibatch, + and :class:`LossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients + detected,:class:`LossScaler` increases the loss scale once more. + In this way :class:`LossScaler` attempts to "ride the edge" of always + using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float): Initial loss scale value, default: 2**32. + scale_factor (float): Factor used when adjusting the loss scale. + Default: 2. + mode (str): Loss scaling mode. 'dynamic' or 'static' + scale_window (int): Number of consecutive iterations without an + overflow to wait before increasing the loss scale. Default: 1000. 
+ """ + + def __init__(self, + init_scale: float = 2**32, + mode: str = 'dynamic', + scale_factor: float = 2., + scale_window: int = 1000): + self.cur_scale = init_scale + self.cur_iter = 0 + assert mode in ('dynamic', + 'static'), 'mode can only be dynamic or static' + self.mode = mode + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + + def has_overflow(self, params: List[Parameter]) -> bool: + """Check if params contain overflow.""" + if self.mode != 'dynamic': + return False + for p in params: + if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data): + return True + return False + + def _has_inf_or_nan(x): + """Check if params contain NaN.""" + try: + cpu_sum = float(x.float().sum()) + except RuntimeError as instance: + if 'value cannot be converted' not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') \ + or cpu_sum != cpu_sum: + return True + return False + + def update_scale(self, overflow: bool) -> None: + """update the current loss scale value when overflow happens.""" + if self.mode != 'dynamic': + return + if overflow: + self.cur_scale = max(self.cur_scale / self.scale_factor, 1) + self.last_overflow_iter = self.cur_iter + else: + if (self.cur_iter - self.last_overflow_iter) % \ + self.scale_window == 0: + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + def state_dict(self): + """Returns the state of the scaler as a :class:`dict`.""" + return dict( + cur_scale=self.cur_scale, + cur_iter=self.cur_iter, + mode=self.mode, + last_overflow_iter=self.last_overflow_iter, + scale_factor=self.scale_factor, + scale_window=self.scale_window) + + def load_state_dict(self, state_dict: dict) -> None: + """Loads the loss_scaler state dict. + + Args: + state_dict (dict): scaler state. + """ + self.cur_scale = state_dict['cur_scale'] + self.cur_iter = state_dict['cur_iter'] + self.mode = state_dict['mode'] + self.last_overflow_iter = state_dict['last_overflow_iter'] + self.scale_factor = state_dict['scale_factor'] + self.scale_window = state_dict['scale_window'] + + @property + def loss_scale(self): + return self.cur_scale diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/__init__.py new file mode 100755 index 00000000..03e2a619 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
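To make the control flow of the `LossScaler` above concrete, here is a self-contained editorial sketch of one manual dynamic-scaling step; the `Fp16OptimizerHook` exported in the hooks package below automates exactly this, so the loop is for illustration only.

```python
import torch
import torch.nn as nn
from dbnet_cv.runner.fp16_utils import LossScaler  # assumed import path

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = LossScaler(init_scale=2**16, mode='dynamic', scale_window=1000)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    (loss * scaler.loss_scale).backward()          # scale up before backward
    params = [p for p in model.parameters() if p.grad is not None]
    overflow = scaler.has_overflow(params)
    if not overflow:
        for p in params:
            p.grad.div_(scaler.loss_scale)         # unscale before the step
        optimizer.step()
    scaler.update_scale(overflow)                  # grow or shrink the scale
```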
+from .checkpoint import CheckpointHook +from .closure import ClosureHook +from .ema import EMAHook +from .evaluation import DistEvalHook, EvalHook +from .hook import HOOKS, Hook +from .iter_timer import IterTimerHook +from .logger import (ClearMLLoggerHook, DvcliveLoggerHook, LoggerHook, + MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, + SegmindLoggerHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) +from .lr_updater import (CosineAnnealingLrUpdaterHook, + CosineRestartLrUpdaterHook, CyclicLrUpdaterHook, + ExpLrUpdaterHook, FixedLrUpdaterHook, + FlatCosineAnnealingLrUpdaterHook, InvLrUpdaterHook, + LinearAnnealingLrUpdaterHook, LrUpdaterHook, + OneCycleLrUpdaterHook, PolyLrUpdaterHook, + StepLrUpdaterHook) +from .memory import EmptyCacheHook +from .momentum_updater import (CosineAnnealingMomentumUpdaterHook, + CyclicMomentumUpdaterHook, + LinearAnnealingMomentumUpdaterHook, + MomentumUpdaterHook, + OneCycleMomentumUpdaterHook, + StepMomentumUpdaterHook) +from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, OptimizerHook) +from .profiler import ProfilerHook +from .sampler_seed import DistSamplerSeedHook +from .sync_buffer import SyncBuffersHook + +__all__ = [ + 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', + 'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook', + 'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook', + 'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook', + 'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'OptimizerHook', + 'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', + 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', + 'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook', + 'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook', + 'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook', + 'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook', + 'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook', + 'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook', + 'SegmindLoggerHook', 'LinearAnnealingLrUpdaterHook', + 'LinearAnnealingMomentumUpdaterHook', 'ClearMLLoggerHook' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/checkpoint.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/checkpoint.py new file mode 100755 index 00000000..83289b1f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/checkpoint.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings + +from dbnet_cv.fileio import FileClient +from ..dist_utils import allreduce_params, master_only +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class CheckpointHook(Hook): + """Save checkpoints periodically. + + Args: + interval (int): The saving period. If ``by_epoch=True``, interval + indicates epochs, otherwise it indicates iterations. + Default: -1, which means "never". + by_epoch (bool): Saving checkpoints by epoch or by iteration. + Default: True. + save_optimizer (bool): Whether to save optimizer state_dict in the + checkpoint. It is usually used for resuming experiments. + Default: True. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, ``runner.work_dir`` will be used by default. If + specified, the ``out_dir`` will be the concatenation of ``out_dir`` + and the last level directory of ``runner.work_dir``. 
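A hypothetical config fragment showing how the `CheckpointHook` arguments introduced here typically appear in a training config (the remaining arguments are documented below); the values and the `checkpoint_config` key follow common downstream convention and are illustrative only.

```python
checkpoint_config = dict(
    interval=100,          # save every 100 epochs, since by_epoch=True
    by_epoch=True,
    save_optimizer=True,
    max_keep_ckpts=3)      # keep only the 3 most recent checkpoints
```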
+ `Changed in version 1.3.16.` + max_keep_ckpts (int, optional): The maximum checkpoints to keep. + In some cases we want only the latest few checkpoints and would + like to delete old ones to save the disk space. + Default: -1, which means unlimited. + save_last (bool, optional): Whether to force the last checkpoint to be + saved regardless of interval. Default: True. + sync_buffer (bool, optional): Whether to synchronize buffers in + different gpus. Default: False. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + + .. warning:: + Before v1.3.16, the ``out_dir`` argument indicates the path where the + checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the + root directory and the final path to save checkpoint is the + concatenation of ``out_dir`` and the last level directory of + ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A" + and the value of ``runner.work_dir`` is "/path/of/B", then the final + path will be "/path/of/A/B". + """ + + def __init__(self, + interval=-1, + by_epoch=True, + save_optimizer=True, + out_dir=None, + max_keep_ckpts=-1, + save_last=True, + sync_buffer=False, + file_client_args=None, + **kwargs): + self.interval = interval + self.by_epoch = by_epoch + self.save_optimizer = save_optimizer + self.out_dir = out_dir + self.max_keep_ckpts = max_keep_ckpts + self.save_last = save_last + self.args = kwargs + self.sync_buffer = sync_buffer + self.file_client_args = file_client_args + + def before_run(self, runner): + if not self.out_dir: + self.out_dir = runner.work_dir + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + + runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by ' + f'{self.file_client.name}.') + + # disable the create_symlink option because some file backends do not + # allow to create a symlink + if 'create_symlink' in self.args: + if self.args[ + 'create_symlink'] and not self.file_client.allow_symlink: + self.args['create_symlink'] = False + warnings.warn( + 'create_symlink is set as True by the user but is changed' + 'to be False because creating symbolic link is not ' + f'allowed in {self.file_client.name}') + else: + self.args['create_symlink'] = self.file_client.allow_symlink + + def after_train_epoch(self, runner): + if not self.by_epoch: + return + + # save checkpoint for following cases: + # 1. every ``self.interval`` epochs + # 2. 
reach the last epoch of training + if self.every_n_epochs( + runner, self.interval) or (self.save_last + and self.is_last_epoch(runner)): + runner.logger.info( + f'Saving checkpoint at {runner.epoch + 1} epochs') + if self.sync_buffer: + allreduce_params(runner.model.buffers()) + self._save_checkpoint(runner) + + @master_only + def _save_checkpoint(self, runner): + """Save the current checkpoint and delete unwanted checkpoint.""" + runner.save_checkpoint( + self.out_dir, save_optimizer=self.save_optimizer, **self.args) + if runner.meta is not None: + if self.by_epoch: + cur_ckpt_filename = self.args.get( + 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1) + else: + cur_ckpt_filename = self.args.get( + 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) + runner.meta.setdefault('hook_msgs', dict()) + runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path( + self.out_dir, cur_ckpt_filename) + # remove other checkpoints + if self.max_keep_ckpts > 0: + if self.by_epoch: + name = 'epoch_{}.pth' + current_ckpt = runner.epoch + 1 + else: + name = 'iter_{}.pth' + current_ckpt = runner.iter + 1 + redundant_ckpts = range( + current_ckpt - self.max_keep_ckpts * self.interval, 0, + -self.interval) + filename_tmpl = self.args.get('filename_tmpl', name) + for _step in redundant_ckpts: + ckpt_path = self.file_client.join_path( + self.out_dir, filename_tmpl.format(_step)) + if self.file_client.isfile(ckpt_path): + self.file_client.remove(ckpt_path) + else: + break + + def after_train_iter(self, runner): + if self.by_epoch: + return + + # save checkpoint for following cases: + # 1. every ``self.interval`` iterations + # 2. reach the last iteration of training + if self.every_n_iters( + runner, self.interval) or (self.save_last + and self.is_last_iter(runner)): + runner.logger.info( + f'Saving checkpoint at {runner.iter + 1} iterations') + if self.sync_buffer: + allreduce_params(runner.model.buffers()) + self._save_checkpoint(runner) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/closure.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/closure.py new file mode 100755 index 00000000..b955f81f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/closure.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class ClosureHook(Hook): + + def __init__(self, fn_name, fn): + assert hasattr(self, fn_name) + assert callable(fn) + setattr(self, fn_name, fn) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/ema.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/ema.py new file mode 100755 index 00000000..6ed77b84 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/ema.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...parallel import is_module_wrapper +from ..hooks.hook import HOOKS, Hook + + +@HOOKS.register_module() +class EMAHook(Hook): + r"""Exponential Moving Average Hook. + + Use Exponential Moving Average on all parameters of model in training + process. All parameters have a ema backup, which update by the formula + as below. EMAHook takes priority over EvalHook and CheckpointSaverHook. + + .. math:: + + Xema\_{t+1} = (1 - \text{momentum}) \times + Xema\_{t} + \text{momentum} \times X_t + + Args: + momentum (float): The momentum used for updating ema parameter. + Defaults to 0.0002. + interval (int): Update ema parameter every interval iteration. + Defaults to 1. 
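A short numeric check (editorial) of the EMA update rule in the docstring above; with the default `interval=1` the effective momentum equals the configured value, and the warm-up clamp applied in `after_train_iter` below leaves it unchanged for these defaults.

```python
momentum, warm_up, curr_step = 0.0002, 100, 0
m = min(momentum, (1 + curr_step) / (warm_up + curr_step))  # -> 0.0002
ema_prev, x_t = 1.0, 2.0
ema_next = (1 - m) * ema_prev + m * x_t                     # -> 1.0002
```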
+ warm_up (int): During first warm_up steps, we may use smaller momentum + to update ema parameters more slowly. Defaults to 100. + resume_from (str): The checkpoint path. Defaults to None. + """ + + def __init__(self, + momentum=0.0002, + interval=1, + warm_up=100, + resume_from=None): + assert isinstance(interval, int) and interval > 0 + self.warm_up = warm_up + self.interval = interval + assert momentum > 0 and momentum < 1 + self.momentum = momentum**interval + self.checkpoint = resume_from + + def before_run(self, runner): + """To resume model with it's ema parameters more friendly. + + Register ema parameter as ``named_buffer`` to model + """ + model = runner.model + if is_module_wrapper(model): + model = model.module + self.param_ema_buffer = {} + self.model_parameters = dict(model.named_parameters(recurse=True)) + for name, value in self.model_parameters.items(): + # "." is not allowed in module's buffer name + buffer_name = f"ema_{name.replace('.', '_')}" + self.param_ema_buffer[name] = buffer_name + model.register_buffer(buffer_name, value.data.clone()) + self.model_buffers = dict(model.named_buffers(recurse=True)) + if self.checkpoint is not None: + runner.resume(self.checkpoint) + + def after_train_iter(self, runner): + """Update ema parameter every self.interval iterations.""" + curr_step = runner.iter + # We warm up the momentum considering the instability at beginning + momentum = min(self.momentum, + (1 + curr_step) / (self.warm_up + curr_step)) + if curr_step % self.interval != 0: + return + for name, parameter in self.model_parameters.items(): + buffer_name = self.param_ema_buffer[name] + buffer_parameter = self.model_buffers[buffer_name] + buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data) + + def after_train_epoch(self, runner): + """We load parameter values from ema backup to model before the + EvalHook.""" + self._swap_ema_parameters() + + def before_train_epoch(self, runner): + """We recover model's parameter from ema backup after last epoch's + EvalHook.""" + self._swap_ema_parameters() + + def _swap_ema_parameters(self): + """Swap the parameter of model with parameter in ema_buffer.""" + for name, value in self.model_parameters.items(): + temp = value.data.clone() + ema_buffer = self.model_buffers[self.param_ema_buffer[name]] + value.data.copy_(ema_buffer.data) + ema_buffer.data.copy_(temp) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/evaluation.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/evaluation.py new file mode 100755 index 00000000..3e73c32e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/evaluation.py @@ -0,0 +1,511 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from math import inf + +import torch.distributed as dist +from torch.nn.modules.batchnorm import _BatchNorm +from torch.utils.data import DataLoader + +from dbnet_cv.fileio import FileClient +from dbnet_cv.utils import is_seq_of +from .hook import Hook +from .logger import LoggerHook + + +class EvalHook(Hook): + """Non-Distributed evaluation hook. + + This hook will regularly perform evaluation in a given interval when + performing in non-distributed environment. + + Args: + dataloader (DataLoader): A PyTorch dataloader, whose dataset has + implemented ``evaluate`` function. + start (int | None, optional): Evaluation starting epoch. It enables + evaluation before the training starts if ``start`` <= the resuming + epoch. If None, whether to evaluate is merely decided by + ``interval``. Default: None. 
+ interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: True. + save_best (str, optional): If a metric is specified, it would measure + the best checkpoint during evaluation. The information about best + checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep + best score value and best checkpoint path, which will be also + loaded when resume checkpoint. Options are the evaluation metrics + on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox + detection and instance segmentation. ``AR@100`` for proposal + recall. If ``save_best`` is ``auto``, the first key of the returned + ``OrderedDict`` result will be used. Default: None. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader, and return the test results. If ``None``, the default + test function ``dbnet_cv.engine.single_gpu_test`` will be used. + (default: ``None``) + greater_keys (List[str] | None, optional): Metric keys that will be + inferred by 'greater' comparison rule. If ``None``, + _default_greater_keys will be used. (default: ``None``) + less_keys (List[str] | None, optional): Metric keys that will be + inferred by 'less' comparison rule. If ``None``, _default_less_keys + will be used. (default: ``None``) + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + `New in version 1.3.16.` + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`dbnet_cv.fileio.FileClient` for details. Default: None. + `New in version 1.3.16.` + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. + + Note: + If new arguments are added for EvalHook, tools/test.py, + tools/eval_metric.py may be affected. + """ + + # Since the key for determine greater or less is related to the downstream + # tasks, downstream repos may need to overwrite the following inner + # variable accordingly. 
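A hypothetical evaluation config built only from the arguments documented above: evaluate once per epoch and keep the best checkpoint, letting the hook pick the first metric returned by `dataset.evaluate()` and infer the comparison rule.

```python
# Illustrative values; 'auto' defers the metric choice to the first key of the
# evaluation results, and rule=None triggers the inference described below.
evaluation = dict(interval=1, by_epoch=True, save_best='auto', rule=None)
```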
+ + rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} + init_value_map = {'greater': -inf, 'less': inf} + _default_greater_keys = [ + 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', + 'mAcc', 'aAcc' + ] + _default_less_keys = ['loss'] + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=None, + less_keys=None, + out_dir=None, + file_client_args=None, + **eval_kwargs): + if not isinstance(dataloader, DataLoader): + raise TypeError(f'dataloader must be a pytorch DataLoader, ' + f'but got {type(dataloader)}') + + if interval <= 0: + raise ValueError(f'interval must be a positive number, ' + f'but got {interval}') + + assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean' + + if start is not None and start < 0: + raise ValueError(f'The evaluation start epoch {start} is smaller ' + f'than 0') + + self.dataloader = dataloader + self.interval = interval + self.start = start + self.by_epoch = by_epoch + + assert isinstance(save_best, str) or save_best is None, \ + '""save_best"" should be a str or None ' \ + f'rather than {type(save_best)}' + self.save_best = save_best + self.eval_kwargs = eval_kwargs + self.initial_flag = True + + if test_fn is None: + from dbnet_cv.engine import single_gpu_test + self.test_fn = single_gpu_test + else: + self.test_fn = test_fn + + if greater_keys is None: + self.greater_keys = self._default_greater_keys + else: + if not isinstance(greater_keys, (list, tuple)): + greater_keys = (greater_keys, ) + assert is_seq_of(greater_keys, str) + self.greater_keys = greater_keys + + if less_keys is None: + self.less_keys = self._default_less_keys + else: + if not isinstance(less_keys, (list, tuple)): + less_keys = (less_keys, ) + assert is_seq_of(less_keys, str) + self.less_keys = less_keys + + if self.save_best is not None: + self.best_ckpt_path = None + self._init_rule(rule, self.save_best) + + self.out_dir = out_dir + self.file_client_args = file_client_args + + def _init_rule(self, rule, key_indicator): + """Initialize rule, key_indicator, comparison_func, and best score. + + Here is the rule to determine which rule is used for key indicator + when the rule is not specific (note that the key indicator matching + is case-insensitive): + 1. If the key indicator is in ``self.greater_keys``, the rule will be + specified as 'greater'. + 2. Or if the key indicator is in ``self.less_keys``, the rule will be + specified as 'less'. + 3. Or if any one item in ``self.greater_keys`` is a substring of + key_indicator , the rule will be specified as 'greater'. + 4. Or if any one item in ``self.less_keys`` is a substring of + key_indicator , the rule will be specified as 'less'. + + Args: + rule (str | None): Comparison rule for best score. + key_indicator (str | None): Key indicator to determine the + comparison rule. 
+ """ + if rule not in self.rule_map and rule is not None: + raise KeyError(f'rule must be greater, less or None, ' + f'but got {rule}.') + + if rule is None: + if key_indicator != 'auto': + # `_lc` here means we use the lower case of keys for + # case-insensitive matching + key_indicator_lc = key_indicator.lower() + greater_keys = [key.lower() for key in self.greater_keys] + less_keys = [key.lower() for key in self.less_keys] + + if key_indicator_lc in greater_keys: + rule = 'greater' + elif key_indicator_lc in less_keys: + rule = 'less' + elif any(key in key_indicator_lc for key in greater_keys): + rule = 'greater' + elif any(key in key_indicator_lc for key in less_keys): + rule = 'less' + else: + raise ValueError(f'Cannot infer the rule for key ' + f'{key_indicator}, thus a specific rule ' + f'must be specified.') + self.rule = rule + self.key_indicator = key_indicator + if self.rule is not None: + self.compare_func = self.rule_map[self.rule] + + def before_run(self, runner): + if not self.out_dir: + self.out_dir = runner.work_dir + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + f'The best checkpoint will be saved to {self.out_dir} by ' + f'{self.file_client.name}') + + if self.save_best is not None: + if runner.meta is None: + warnings.warn('runner.meta is None. Creating an empty one.') + runner.meta = dict() + runner.meta.setdefault('hook_msgs', dict()) + self.best_ckpt_path = runner.meta['hook_msgs'].get( + 'best_ckpt', None) + + def before_train_iter(self, runner): + """Evaluate the model only at the start of training by iteration.""" + if self.by_epoch or not self.initial_flag: + return + if self.start is not None and runner.iter >= self.start: + self.after_train_iter(runner) + self.initial_flag = False + + def before_train_epoch(self, runner): + """Evaluate the model only at the start of training by epoch.""" + if not (self.by_epoch and self.initial_flag): + return + if self.start is not None and runner.epoch >= self.start: + self.after_train_epoch(runner) + self.initial_flag = False + + def after_train_iter(self, runner): + """Called after every training iter to evaluate the results.""" + if not self.by_epoch and self._should_evaluate(runner): + # Because the priority of EvalHook is higher than LoggerHook, the + # training log and the evaluating log are mixed. Therefore, + # we need to dump the training log and clear it before evaluating + # log is generated. In addition, this problem will only appear in + # `IterBasedRunner` whose `self.by_epoch` is False, because + # `EpochBasedRunner` whose `self.by_epoch` is True calls + # `_do_evaluate` in `after_train_epoch` stage, and at this stage + # the training log has been printed, so it will not cause any + # problem. 
more details at + # https://github.com/open-mmlab/mmsegmentation/issues/694 + for hook in runner._hooks: + if isinstance(hook, LoggerHook): + hook.after_train_iter(runner) + runner.log_buffer.clear() + + self._do_evaluate(runner) + + def after_train_epoch(self, runner): + """Called after every training epoch to evaluate the results.""" + if self.by_epoch and self._should_evaluate(runner): + self._do_evaluate(runner) + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + results = self.test_fn(runner.model, self.dataloader) + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + # the key_score may be `None` so it needs to skip the action to save + # the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) + + def _should_evaluate(self, runner): + """Judge whether to perform evaluation. + + Here is the rule to judge whether to perform evaluation: + 1. It will not perform evaluation during the epoch/iteration interval, + which is determined by ``self.interval``. + 2. It will not perform evaluation if the start time is larger than + current time. + 3. It will not perform evaluation when current time is larger than + the start time but during epoch/iteration interval. + + Returns: + bool: The flag indicating whether to perform evaluation. + """ + if self.by_epoch: + current = runner.epoch + check_time = self.every_n_epochs + else: + current = runner.iter + check_time = self.every_n_iters + + if self.start is None: + if not check_time(runner, self.interval): + # No evaluation during the interval. + return False + elif (current + 1) < self.start: + # No evaluation if start is larger than the current time. + return False + else: + # Evaluation only at epochs/iters 3, 5, 7... + # if start==3 and interval==2 + if (current + 1 - self.start) % self.interval: + return False + return True + + def _save_ckpt(self, runner, key_score): + """Save the best checkpoint. + + It will compare the score according to the compare function, write + related information (best score, best checkpoint path) and save the + best checkpoint into ``work_dir``. + """ + if self.by_epoch: + current = f'epoch_{runner.epoch + 1}' + cur_type, cur_time = 'epoch', runner.epoch + 1 + else: + current = f'iter_{runner.iter + 1}' + cur_type, cur_time = 'iter', runner.iter + 1 + + best_score = runner.meta['hook_msgs'].get( + 'best_score', self.init_value_map[self.rule]) + if self.compare_func(key_score, best_score): + best_score = key_score + runner.meta['hook_msgs']['best_score'] = best_score + + if self.best_ckpt_path and self.file_client.isfile( + self.best_ckpt_path): + self.file_client.remove(self.best_ckpt_path) + runner.logger.info( + f'The previous best checkpoint {self.best_ckpt_path} was ' + 'removed') + + best_ckpt_name = f'best_{self.key_indicator}_{current}.pth' + self.best_ckpt_path = self.file_client.join_path( + self.out_dir, best_ckpt_name) + runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path + + runner.save_checkpoint( + self.out_dir, + filename_tmpl=best_ckpt_name, + create_symlink=False) + runner.logger.info( + f'Now best checkpoint is saved as {best_ckpt_name}.') + runner.logger.info( + f'Best {self.key_indicator} is {best_score:0.4f} ' + f'at {cur_time} {cur_type}.') + + def evaluate(self, runner, results): + """Evaluate the results. + + Args: + runner (:obj:`dbnet_cv.Runner`): The underlined training runner. + results (list): Output results. 
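The scheduling rule in `_should_evaluate()` above is easy to verify with a small editorial sketch (epochs are 1-indexed here to match `current + 1` in the code):

```python
start, interval = 3, 2
fire = [e for e in range(1, 11)
        if e >= start and (e - start) % interval == 0]
assert fire == [3, 5, 7, 9]   # evaluation at epochs 3, 5, 7, ... as documented
```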
+ """ + eval_res = self.dataloader.dataset.evaluate( + results, logger=runner.logger, **self.eval_kwargs) + + for name, val in eval_res.items(): + runner.log_buffer.output[name] = val + runner.log_buffer.ready = True + + if self.save_best is not None: + # If the performance of model is pool, the `eval_res` may be an + # empty dict and it will raise exception when `self.save_best` is + # not None. More details at + # https://github.com/open-mmlab/dbnet_detection/issues/6265. + if not eval_res: + warnings.warn( + 'Since `eval_res` is an empty dict, the behavior to save ' + 'the best checkpoint will be skipped in this evaluation.') + return None + + if self.key_indicator == 'auto': + # infer from eval_results + self._init_rule(self.rule, list(eval_res.keys())[0]) + return eval_res[self.key_indicator] + + return None + + +class DistEvalHook(EvalHook): + """Distributed evaluation hook. + + This hook will regularly perform evaluation in a given interval when + performing in distributed environment. + + Args: + dataloader (DataLoader): A PyTorch dataloader, whose dataset has + implemented ``evaluate`` function. + start (int | None, optional): Evaluation starting epoch. It enables + evaluation before the training starts if ``start`` <= the resuming + epoch. If None, whether to evaluate is merely decided by + ``interval``. Default: None. + interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + default: True. + save_best (str, optional): If a metric is specified, it would measure + the best checkpoint during evaluation. The information about best + checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep + best score value and best checkpoint path, which will be also + loaded when resume checkpoint. Options are the evaluation metrics + on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox + detection and instance segmentation. ``AR@100`` for proposal + recall. If ``save_best`` is ``auto``, the first key of the returned + ``OrderedDict`` result will be used. Default: None. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader in a multi-gpu manner, and return the test results. If + ``None``, the default test function ``dbnet_cv.engine.multi_gpu_test`` + will be used. (default: ``None``) + tmpdir (str | None): Temporary directory to save the results of all + processes. Default: None. + gpu_collect (bool): Whether to use gpu or cpu to collect results. + Default: False. + broadcast_bn_buffer (bool): Whether to broadcast the + buffer(running_mean and running_var) of rank 0 to other rank + before evaluation. Default: True. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`dbnet_cv.fileio.FileClient` for details. Default: None. + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. 
+ """ + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=None, + less_keys=None, + broadcast_bn_buffer=True, + tmpdir=None, + gpu_collect=False, + out_dir=None, + file_client_args=None, + **eval_kwargs): + + if test_fn is None: + from dbnet_cv.engine import multi_gpu_test + test_fn = multi_gpu_test + + super().__init__( + dataloader, + start=start, + interval=interval, + by_epoch=by_epoch, + save_best=save_best, + rule=rule, + test_fn=test_fn, + greater_keys=greater_keys, + less_keys=less_keys, + out_dir=out_dir, + file_client_args=file_client_args, + **eval_kwargs) + + self.broadcast_bn_buffer = broadcast_bn_buffer + self.tmpdir = tmpdir + self.gpu_collect = gpu_collect + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + # Synchronization of BatchNorm's buffer (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause the inconsistent performance of models in + # different ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this. + if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, + _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + + results = self.test_fn( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + # the key_score may be `None` so it needs to skip the action to + # save the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/hook.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/hook.py new file mode 100755 index 00000000..3989f6ef --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/hook.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from dbnet_cv.utils import Registry, is_method_overridden + +HOOKS = Registry('hook') + + +class Hook: + stages = ('before_run', 'before_train_epoch', 'before_train_iter', + 'after_train_iter', 'after_train_epoch', 'before_val_epoch', + 'before_val_iter', 'after_val_iter', 'after_val_epoch', + 'after_run') + + def before_run(self, runner): + pass + + def after_run(self, runner): + pass + + def before_epoch(self, runner): + pass + + def after_epoch(self, runner): + pass + + def before_iter(self, runner): + pass + + def after_iter(self, runner): + pass + + def before_train_epoch(self, runner): + self.before_epoch(runner) + + def before_val_epoch(self, runner): + self.before_epoch(runner) + + def after_train_epoch(self, runner): + self.after_epoch(runner) + + def after_val_epoch(self, runner): + self.after_epoch(runner) + + def before_train_iter(self, runner): + self.before_iter(runner) + + def before_val_iter(self, runner): + self.before_iter(runner) + + def after_train_iter(self, runner): + self.after_iter(runner) + + def after_val_iter(self, runner): + self.after_iter(runner) + + def every_n_epochs(self, runner, n): + return (runner.epoch + 1) % n == 0 if n > 0 else False + + def every_n_inner_iters(self, runner, n): + return (runner.inner_iter + 1) % n == 0 if n > 0 else False + + def every_n_iters(self, runner, n): + return (runner.iter + 1) % n == 0 if n > 0 else False + + def end_of_epoch(self, runner): + return runner.inner_iter + 1 == len(runner.data_loader) + + def is_last_epoch(self, runner): + return runner.epoch + 1 == runner._max_epochs + + def is_last_iter(self, runner): + return runner.iter + 1 == runner._max_iters + + def get_triggered_stages(self): + trigger_stages = set() + for stage in Hook.stages: + if is_method_overridden(stage, Hook, self): + trigger_stages.add(stage) + + # some methods will be triggered in multi stages + # use this dict to map method to stages. + method_stages_map = { + 'before_epoch': ['before_train_epoch', 'before_val_epoch'], + 'after_epoch': ['after_train_epoch', 'after_val_epoch'], + 'before_iter': ['before_train_iter', 'before_val_iter'], + 'after_iter': ['after_train_iter', 'after_val_iter'], + } + + for method, map_stages in method_stages_map.items(): + if is_method_overridden(method, Hook, self): + trigger_stages.update(map_stages) + + return [stage for stage in Hook.stages if stage in trigger_stages] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/iter_timer.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/iter_timer.py new file mode 100755 index 00000000..cfd5002f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/iter_timer.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time + +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class IterTimerHook(Hook): + + def before_epoch(self, runner): + self.t = time.time() + + def before_iter(self, runner): + runner.log_buffer.update({'data_time': time.time() - self.t}) + + def after_iter(self, runner): + runner.log_buffer.update({'time': time.time() - self.t}) + self.t = time.time() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/__init__.py new file mode 100755 index 00000000..062709e7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
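For completeness, a hedged example of extending the `Hook` base class and registering it with the `HOOKS` registry defined above; the import path and the `custom_hooks` config key are assumptions based on common downstream usage.

```python
from dbnet_cv.runner.hooks import HOOKS, Hook  # assumed import path

@HOOKS.register_module()
class EpochEndLogHook(Hook):
    """Toy hook that logs a message after every training epoch."""

    def after_train_epoch(self, runner):
        runner.logger.info('epoch %d finished', runner.epoch + 1)

# In a config this would typically be enabled via:
# custom_hooks = [dict(type='EpochEndLogHook')]
```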
+from .base import LoggerHook +from .clearml import ClearMLLoggerHook +from .dvclive import DvcliveLoggerHook +from .mlflow import MlflowLoggerHook +from .neptune import NeptuneLoggerHook +from .pavi import PaviLoggerHook +from .segmind import SegmindLoggerHook +from .tensorboard import TensorboardLoggerHook +from .text import TextLoggerHook +from .wandb import WandbLoggerHook + +__all__ = [ + 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', + 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook', + 'NeptuneLoggerHook', 'DvcliveLoggerHook', 'SegmindLoggerHook', + 'ClearMLLoggerHook' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/base.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/base.py new file mode 100755 index 00000000..416a1b75 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/base.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers +from abc import ABCMeta, abstractmethod +from typing import Dict + +import numpy as np +import torch + +from ..hook import Hook + + +class LoggerHook(Hook): + """Base class for logger hooks. + + Args: + interval (int): Logging interval (every k iterations). Default 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default False. + by_epoch (bool): Whether EpochBasedRunner is used. Default True. + """ + + __metaclass__ = ABCMeta + + def __init__(self, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + by_epoch: bool = True): + self.interval = interval + self.ignore_last = ignore_last + self.reset_flag = reset_flag + self.by_epoch = by_epoch + + @abstractmethod + def log(self, runner): + pass + + @staticmethod + def is_scalar(val, + include_np: bool = True, + include_torch: bool = True) -> bool: + """Tell the input variable is a scalar or not. + + Args: + val: Input variable. + include_np (bool): Whether include 0-d np.ndarray as a scalar. + include_torch (bool): Whether include 0-d torch.Tensor as a scalar. + + Returns: + bool: True or False. 
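A quick editorial check of the `is_scalar()` helper documented above (its body follows); the import path is assumed.

```python
import numpy as np
import torch
from dbnet_cv.runner.hooks.logger import LoggerHook  # assumed import path

assert LoggerHook.is_scalar(3.14)                 # plain Python numbers
assert LoggerHook.is_scalar(np.array(1.0))        # 0-d ndarray
assert LoggerHook.is_scalar(torch.tensor([2.0]))  # single-element tensor
assert not LoggerHook.is_scalar('loss/total')     # strings are not scalars
```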
+ """ + if isinstance(val, numbers.Number): + return True + elif include_np and isinstance(val, np.ndarray) and val.ndim == 0: + return True + elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1: + return True + else: + return False + + def get_mode(self, runner) -> str: + if runner.mode == 'train': + if 'time' in runner.log_buffer.output: + mode = 'train' + else: + mode = 'val' + elif runner.mode == 'val': + mode = 'val' + else: + raise ValueError(f"runner mode should be 'train' or 'val', " + f'but got {runner.mode}') + return mode + + def get_epoch(self, runner) -> int: + if runner.mode == 'train': + epoch = runner.epoch + 1 + elif runner.mode == 'val': + # normal val mode + # runner.epoch += 1 has been done before val workflow + epoch = runner.epoch + else: + raise ValueError(f"runner mode should be 'train' or 'val', " + f'but got {runner.mode}') + return epoch + + def get_iter(self, runner, inner_iter: bool = False) -> int: + """Get the current training iteration step.""" + if self.by_epoch and inner_iter: + current_iter = runner.inner_iter + 1 + else: + current_iter = runner.iter + 1 + return current_iter + + def get_lr_tags(self, runner) -> Dict[str, float]: + tags = {} + lrs = runner.current_lr() + if isinstance(lrs, dict): + for name, value in lrs.items(): + tags[f'learning_rate/{name}'] = value[0] + else: + tags['learning_rate'] = lrs[0] + return tags + + def get_momentum_tags(self, runner) -> Dict[str, float]: + tags = {} + momentums = runner.current_momentum() + if isinstance(momentums, dict): + for name, value in momentums.items(): + tags[f'momentum/{name}'] = value[0] + else: + tags['momentum'] = momentums[0] + return tags + + def get_loggable_tags( + self, + runner, + allow_scalar: bool = True, + allow_text: bool = False, + add_mode: bool = True, + tags_to_skip: tuple = ('time', 'data_time') + ) -> Dict: + tags = {} + for var, val in runner.log_buffer.output.items(): + if var in tags_to_skip: + continue + if self.is_scalar(val) and not allow_scalar: + continue + if isinstance(val, str) and not allow_text: + continue + if add_mode: + var = f'{self.get_mode(runner)}/{var}' + tags[var] = val + tags.update(self.get_lr_tags(runner)) + tags.update(self.get_momentum_tags(runner)) + return tags + + def before_run(self, runner) -> None: + for hook in runner.hooks[::-1]: + if isinstance(hook, LoggerHook): + hook.reset_flag = True + break + + def before_epoch(self, runner) -> None: + runner.log_buffer.clear() # clear logs of last epoch + + def after_train_iter(self, runner) -> None: + if self.by_epoch and self.every_n_inner_iters(runner, self.interval): + runner.log_buffer.average(self.interval) + elif not self.by_epoch and self.every_n_iters(runner, self.interval): + runner.log_buffer.average(self.interval) + elif self.end_of_epoch(runner) and not self.ignore_last: + # not precise but more stable + runner.log_buffer.average(self.interval) + + if runner.log_buffer.ready: + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() + + def after_train_epoch(self, runner) -> None: + if runner.log_buffer.ready: + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() + + def after_val_epoch(self, runner) -> None: + runner.log_buffer.average() + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/clearml.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/clearml.py new file mode 100755 index 00000000..7db651f0 --- /dev/null +++ 
b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/clearml.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Dict, Optional + +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class ClearMLLoggerHook(LoggerHook): + """Class to log metrics with clearml. + + It requires `clearml`_ to be installed. + + + Args: + init_kwargs (dict): A dict contains the `clearml.Task.init` + initialization keys. See `taskinit`_ for more details. + interval (int): Logging interval (every k iterations). Default 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + + .. _clearml: + https://clear.ml/docs/latest/docs/ + .. _taskinit: + https://clear.ml/docs/latest/docs/references/sdk/task/#taskinit + """ + + def __init__(self, + init_kwargs: Optional[Dict] = None, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + by_epoch: bool = True): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.import_clearml() + self.init_kwargs = init_kwargs + + def import_clearml(self): + try: + import clearml + except ImportError: + raise ImportError( + 'Please run "pip install clearml" to install clearml') + self.clearml = clearml + + @master_only + def before_run(self, runner) -> None: + super().before_run(runner) + task_kwargs = self.init_kwargs if self.init_kwargs else {} + self.task = self.clearml.Task.init(**task_kwargs) + self.task_logger = self.task.get_logger() + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner) + for tag, val in tags.items(): + self.task_logger.report_scalar(tag, tag, val, + self.get_iter(runner)) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/dvclive.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/dvclive.py new file mode 100755 index 00000000..fc0a58c4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/dvclive.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Optional + +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class DvcliveLoggerHook(LoggerHook): + """Class to log metrics with dvclive. + + It requires `dvclive`_ to be installed. + + Args: + model_file (str): Default None. If not None, after each epoch the + model will be saved to {model_file}. + interval (int): Logging interval (every k iterations). Default 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + kwargs: Arguments for instantiating `Live`_. + + .. _dvclive: + https://dvc.org/doc/dvclive + + .. 
_Live: + https://dvc.org/doc/dvclive/api-reference/live#parameters + """ + + def __init__(self, + model_file: Optional[str] = None, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + by_epoch: bool = True, + **kwargs): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.model_file = model_file + self.import_dvclive(**kwargs) + + def import_dvclive(self, **kwargs) -> None: + try: + from dvclive import Live + except ImportError: + raise ImportError( + 'Please run "pip install dvclive" to install dvclive') + self.dvclive = Live(**kwargs) + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner) + if tags: + self.dvclive.set_step(self.get_iter(runner)) + for k, v in tags.items(): + self.dvclive.log(k, v) + + @master_only + def after_train_epoch(self, runner) -> None: + super().after_train_epoch(runner) + if self.model_file is not None: + runner.save_checkpoint( + Path(self.model_file).parent, + filename_tmpl=Path(self.model_file).name, + create_symlink=False, + ) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/mlflow.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/mlflow.py new file mode 100755 index 00000000..4ae63fe0 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/mlflow.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + +from dbnet_cv.utils import TORCH_VERSION +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class MlflowLoggerHook(LoggerHook): + """Class to log metrics and (optionally) a trained model to MLflow. + + It requires `MLflow`_ to be installed. + + Args: + exp_name (str, optional): Name of the experiment to be used. + Default None. If not None, set the active experiment. + If experiment does not exist, an experiment with provided name + will be created. + tags (Dict[str], optional): Tags for the current run. + Default None. If not None, set tags for the current run. + log_model (bool, optional): Whether to log an MLflow artifact. + Default True. If True, log runner.model as an MLflow artifact + for the current run. + interval (int): Logging interval (every k iterations). Default: 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + + .. 
_MLflow: + https://www.mlflow.org/docs/latest/index.html + """ + + def __init__(self, + exp_name: Optional[str] = None, + tags: Optional[Dict] = None, + log_model: bool = True, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + by_epoch: bool = True): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.import_mlflow() + self.exp_name = exp_name + self.tags = tags + self.log_model = log_model + + def import_mlflow(self) -> None: + try: + import mlflow + import mlflow.pytorch as mlflow_pytorch + except ImportError: + raise ImportError( + 'Please run "pip install mlflow" to install mlflow') + self.mlflow = mlflow + self.mlflow_pytorch = mlflow_pytorch + + @master_only + def before_run(self, runner) -> None: + super().before_run(runner) + if self.exp_name is not None: + self.mlflow.set_experiment(self.exp_name) + if self.tags is not None: + self.mlflow.set_tags(self.tags) + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner) + if tags: + self.mlflow.log_metrics(tags, step=self.get_iter(runner)) + + @master_only + def after_run(self, runner) -> None: + if self.log_model: + self.mlflow_pytorch.log_model( + runner.model, + 'models', + pip_requirements=[f'torch=={TORCH_VERSION}']) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/neptune.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/neptune.py new file mode 100755 index 00000000..e398fe1e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/neptune.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class NeptuneLoggerHook(LoggerHook): + """Class to log metrics to NeptuneAI. + + It requires `Neptune`_ to be installed. + + Args: + init_kwargs (dict): a dict contains the initialization keys as below: + + - project (str): Name of a project in a form of + namespace/project_name. If None, the value of NEPTUNE_PROJECT + environment variable will be taken. + - api_token (str): User’s API token. If None, the value of + NEPTUNE_API_TOKEN environment variable will be taken. Note: It is + strongly recommended to use NEPTUNE_API_TOKEN environment + variable rather than placing your API token in plain text in your + source code. + - name (str, optional, default is 'Untitled'): Editable name of the + run. Name is displayed in the run's Details and in Runs table as + a column. + + Check https://docs.neptune.ai/api-reference/neptune#init for more + init arguments. + interval (int): Logging interval (every k iterations). Default: 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than ``interval``. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: True. + with_step (bool): If True, the step will be logged from + ``self.get_iters``. Otherwise, step will not be logged. + Default: True. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + + .. 
_Neptune: + https://docs.neptune.ai + """ + + def __init__(self, + init_kwargs: Optional[Dict] = None, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = True, + with_step: bool = True, + by_epoch: bool = True): + + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.import_neptune() + self.init_kwargs = init_kwargs + self.with_step = with_step + + def import_neptune(self) -> None: + try: + import neptune.new as neptune + except ImportError: + raise ImportError( + 'Please run "pip install neptune-client" to install neptune') + self.neptune = neptune + self.run = None + + @master_only + def before_run(self, runner) -> None: + if self.init_kwargs: + self.run = self.neptune.init(**self.init_kwargs) + else: + self.run = self.neptune.init() + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner) + if tags: + for tag_name, tag_value in tags.items(): + if self.with_step: + self.run[tag_name].log( # type: ignore + tag_value, step=self.get_iter(runner)) + else: + tags['global_step'] = self.get_iter(runner) + self.run[tag_name].log(tags) # type: ignore + + @master_only + def after_run(self, runner) -> None: + self.run.stop() # type: ignore diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/pavi.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/pavi.py new file mode 100755 index 00000000..0fdf2911 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/pavi.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import os.path as osp +from typing import Dict, Optional + +import torch +import yaml + +import dbnet_cv +from ....parallel.utils import is_module_wrapper +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class PaviLoggerHook(LoggerHook): + """Class to visualize the model and log metrics (for internal use). + + Args: + init_kwargs (dict): A dict containing the initialization keys. + add_graph (bool): Whether to visualize the model graph. Default: False. + add_last_ckpt (bool): Whether to save the checkpoint after run. + Default: False. + interval (int): Logging interval (every k iterations). Default: 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + img_key (str): Key used to get image data from the dataset. + Default: 'img_info'.
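+
+    Example:
+        A minimal ``log_config`` sketch (the values here are illustrative)
+        that enables this hook through the runner config:
+
+        >>> log_config = dict(
+        ...     interval=10,
+        ...     hooks=[dict(type='PaviLoggerHook', add_last_ckpt=True)])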
+ """ + + def __init__(self, + init_kwargs: Optional[Dict] = None, + add_graph: bool = False, + add_last_ckpt: bool = False, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + by_epoch: bool = True, + img_key: str = 'img_info'): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.init_kwargs = init_kwargs + self.add_graph = add_graph + self.add_last_ckpt = add_last_ckpt + self.img_key = img_key + + @master_only + def before_run(self, runner) -> None: + super().before_run(runner) + try: + from pavi import SummaryWriter + except ImportError: + raise ImportError('Please run "pip install pavi" to install pavi.') + + self.run_name = runner.work_dir.split('/')[-1] + + if not self.init_kwargs: + self.init_kwargs = dict() + self.init_kwargs['name'] = self.run_name + self.init_kwargs['model'] = runner._model_name + if runner.meta is not None: + if 'config_dict' in runner.meta: + config_dict = runner.meta['config_dict'] + assert isinstance( + config_dict, + dict), ('meta["config_dict"] has to be of a dict, ' + f'but got {type(config_dict)}') + elif 'config_file' in runner.meta: + config_file = runner.meta['config_file'] + config_dict = dict(dbnet_cv.Config.fromfile(config_file)) + else: + config_dict = None + if config_dict is not None: + # 'max_.*iter' is parsed in pavi sdk as the maximum iterations + # to properly set up the progress bar. + config_dict = config_dict.copy() + config_dict.setdefault('max_iter', runner.max_iters) + # non-serializable values are first converted in + # dbnet_cv.dump to json + config_dict = json.loads( + dbnet_cv.dump(config_dict, file_format='json')) + session_text = yaml.dump(config_dict) + self.init_kwargs['session_text'] = session_text + self.writer = SummaryWriter(**self.init_kwargs) + + def get_step(self, runner) -> int: + """Get the total training step/epoch.""" + if self.get_mode(runner) == 'val' and self.by_epoch: + return self.get_epoch(runner) + else: + return self.get_iter(runner) + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner, add_mode=False) + if tags: + self.writer.add_scalars( + self.get_mode(runner), tags, self.get_step(runner)) + + @master_only + def after_run(self, runner) -> None: + if self.add_last_ckpt: + ckpt_path = osp.join(runner.work_dir, 'latest.pth') + if osp.islink(ckpt_path): + ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path)) + + if osp.isfile(ckpt_path): + # runner.epoch += 1 has been done before `after_run`. + iteration = runner.epoch if self.by_epoch else runner.iter + return self.writer.add_snapshot_file( + tag=self.run_name, + snapshot_file_path=ckpt_path, + iteration=iteration) + + # flush the buffer and send a task ending signal to Pavi + self.writer.close() + + @master_only + def before_epoch(self, runner) -> None: + if runner.epoch == 0 and self.add_graph: + if is_module_wrapper(runner.model): + _model = runner.model.module + else: + _model = runner.model + device = next(_model.parameters()).device + data = next(iter(runner.data_loader)) + image = data[self.img_key][0:1].to(device) + with torch.no_grad(): + self.writer.add_graph(_model, image) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/segmind.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/segmind.py new file mode 100755 index 00000000..ecb3751e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/segmind.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class SegmindLoggerHook(LoggerHook): + """Class to log metrics to Segmind. + + It requires `Segmind`_ to be installed. + + Args: + interval (int): Logging interval (every k iterations). Default: 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + + .. _Segmind: + https://docs.segmind.com/python-library + """ + + def __init__(self, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + by_epoch: bool = True): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.import_segmind() + + def import_segmind(self) -> None: + try: + import segmind + except ImportError: + raise ImportError( + "Please run 'pip install segmind' to install segmind") + self.log_metrics = segmind.tracking.fluent.log_metrics + self.mlflow_log = segmind.utils.logging_utils.try_mlflow_log + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner) + if tags: + # logging metrics to segmind + self.mlflow_log( + self.log_metrics, tags, step=runner.epoch, epoch=runner.epoch) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/tensorboard.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/tensorboard.py new file mode 100755 index 00000000..a4786b8e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/tensorboard.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +from dbnet_cv.utils import TORCH_VERSION, digit_version +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class TensorboardLoggerHook(LoggerHook): + """Class to log metrics to Tensorboard. + + Args: + log_dir (str): Save directory location. Default: None. If default + values are used, directory location is ``runner.work_dir``/tf_logs. + interval (int): Logging interval (every k iterations). Default: 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
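+
+    Example:
+        A minimal ``log_config`` sketch (values are illustrative) that
+        enables text and TensorBoard logging side by side:
+
+        >>> log_config = dict(
+        ...     interval=10,
+        ...     hooks=[dict(type='TextLoggerHook'),
+        ...            dict(type='TensorboardLoggerHook')])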
+ """ + + def __init__(self, + log_dir: Optional[str] = None, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + by_epoch: bool = True): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.log_dir = log_dir + + @master_only + def before_run(self, runner) -> None: + super().before_run(runner) + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.1')): + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError('Please install tensorboardX to use ' + 'TensorboardLoggerHook.') + else: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + raise ImportError( + 'Please run "pip install future tensorboard" to install ' + 'the dependencies to use torch.utils.tensorboard ' + '(applicable to PyTorch 1.1 or higher)') + + if self.log_dir is None: + self.log_dir = osp.join(runner.work_dir, 'tf_logs') + self.writer = SummaryWriter(self.log_dir) + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner, allow_text=True) + for tag, val in tags.items(): + if isinstance(val, str): + self.writer.add_text(tag, val, self.get_iter(runner)) + else: + self.writer.add_scalar(tag, val, self.get_iter(runner)) + + @master_only + def after_run(self, runner) -> None: + self.writer.close() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/text.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/text.py new file mode 100755 index 00000000..be30c7b9 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/text.py @@ -0,0 +1,256 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import os +import os.path as osp +from collections import OrderedDict +from typing import Dict, Optional, Union + +import torch +import torch.distributed as dist + +import dbnet_cv +from dbnet_cv.fileio.file_client import FileClient +from dbnet_cv.utils import is_tuple_of, scandir +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class TextLoggerHook(LoggerHook): + """Logger hook in text. + + In this logger hook, the information will be printed on terminal and + saved in json file. + + Args: + by_epoch (bool, optional): Whether EpochBasedRunner is used. + Default: True. + interval (int, optional): Logging interval (every k iterations). + Default: 10. + ignore_last (bool, optional): Ignore the log of last iterations in each + epoch if less than :attr:`interval`. Default: True. + reset_flag (bool, optional): Whether to clear the output buffer after + logging. Default: False. + interval_exp_name (int, optional): Logging interval for experiment + name. This feature is to help users conveniently get the experiment + information from screen or log file. Default: 1000. + out_dir (str, optional): Logs are saved in ``runner.work_dir`` default. + If ``out_dir`` is specified, logs will be copied to a new directory + which is the concatenation of ``out_dir`` and the last level + directory of ``runner.work_dir``. Default: None. + `New in version 1.3.16.` + out_suffix (str or tuple[str], optional): Those filenames ending with + ``out_suffix`` will be copied to ``out_dir``. + Default: ('.log.json', '.log', '.py'). + `New in version 1.3.16.` + keep_local (bool, optional): Whether to keep local log when + :attr:`out_dir` is specified. If False, the local log will be + removed. Default: True. + `New in version 1.3.16.` + file_client_args (dict, optional): Arguments to instantiate a + FileClient. 
See :class:`dbnet_cv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + """ + + def __init__(self, + by_epoch: bool = True, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + interval_exp_name: int = 1000, + out_dir: Optional[str] = None, + out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'), + keep_local: bool = True, + file_client_args: Optional[Dict] = None): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.by_epoch = by_epoch + self.time_sec_tot = 0 + self.interval_exp_name = interval_exp_name + + if out_dir is None and file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" when `out_dir` is not' + 'specified.') + self.out_dir = out_dir + + if not (out_dir is None or isinstance(out_dir, str) + or is_tuple_of(out_dir, str)): + raise TypeError('out_dir should be "None" or string or tuple of ' + 'string, but got {out_dir}') + self.out_suffix = out_suffix + + self.keep_local = keep_local + self.file_client_args = file_client_args + if self.out_dir is not None: + self.file_client = FileClient.infer_client(file_client_args, + self.out_dir) + + def before_run(self, runner) -> None: + super().before_run(runner) + + if self.out_dir is not None: + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + # The final `self.out_dir` is the concatenation of `self.out_dir` + # and the last level directory of `runner.work_dir` + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + f'Text logs will be saved to {self.out_dir} by ' + f'{self.file_client.name} after the training process.') + + self.start_iter = runner.iter + self.json_log_path = osp.join(runner.work_dir, + f'{runner.timestamp}.log.json') + if runner.meta is not None: + self._dump_log(runner.meta, runner) + + def _get_max_memory(self, runner) -> int: + device = getattr(runner.model, 'output_device', None) + mem = torch.cuda.max_memory_allocated(device=device) + mem_mb = torch.tensor([int(mem) // (1024 * 1024)], + dtype=torch.int, + device=device) + if runner.world_size > 1: + dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) + return mem_mb.item() + + def _log_info(self, log_dict: Dict, runner) -> None: + # print exp name for users to distinguish experiments + # at every ``interval_exp_name`` iterations and the end of each epoch + if runner.meta is not None and 'exp_name' in runner.meta: + if (self.every_n_iters(runner, self.interval_exp_name)) or ( + self.by_epoch and self.end_of_epoch(runner)): + exp_info = f'Exp name: {runner.meta["exp_name"]}' + runner.logger.info(exp_info) + + if log_dict['mode'] == 'train': + if isinstance(log_dict['lr'], dict): + lr_str = [] + for k, val in log_dict['lr'].items(): + lr_str.append(f'lr_{k}: {val:.3e}') + lr_str = ' '.join(lr_str) # type: ignore + else: + lr_str = f'lr: {log_dict["lr"]:.3e}' # type: ignore + + # by epoch: Epoch [4][100/1000] + # by iter: Iter [100/100000] + if self.by_epoch: + log_str = f'Epoch [{log_dict["epoch"]}]' \ + f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t' + else: + log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t' + log_str += f'{lr_str}, ' + + if 'time' in log_dict.keys(): + self.time_sec_tot += (log_dict['time'] * self.interval) + time_sec_avg = self.time_sec_tot / ( + runner.iter - self.start_iter + 1) + eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 
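+                # append the smoothed ETA, per-iteration time and data
+                # loading time to the log line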
+ log_str += f'eta: {eta_str}, ' + log_str += f'time: {log_dict["time"]:.3f}, ' \ + f'data_time: {log_dict["data_time"]:.3f}, ' + # statistic memory + if torch.cuda.is_available(): + log_str += f'memory: {log_dict["memory"]}, ' + else: + # val/test time + # here 1000 is the length of the val dataloader + # by epoch: Epoch[val] [4][1000] + # by iter: Iter[val] [1000] + if self.by_epoch: + log_str = f'Epoch({log_dict["mode"]}) ' \ + f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t' + else: + log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t' + + log_items = [] + for name, val in log_dict.items(): + # TODO: resolve this hack + # these items have been in log_str + if name in [ + 'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time', + 'memory', 'epoch' + ]: + continue + if isinstance(val, float): + val = f'{val:.4f}' + log_items.append(f'{name}: {val}') + log_str += ', '.join(log_items) + + runner.logger.info(log_str) + + def _dump_log(self, log_dict: Dict, runner) -> None: + # dump log in json format + json_log = OrderedDict() + for k, v in log_dict.items(): + json_log[k] = self._round_float(v) + # only append log at last line + if runner.rank == 0: + with open(self.json_log_path, 'a+') as f: + dbnet_cv.dump(json_log, f, file_format='json') + f.write('\n') + + def _round_float(self, items): + if isinstance(items, list): + return [self._round_float(item) for item in items] + elif isinstance(items, float): + return round(items, 5) + else: + return items + + def log(self, runner) -> OrderedDict: + if 'eval_iter_num' in runner.log_buffer.output: + # this doesn't modify runner.iter and is regardless of by_epoch + cur_iter = runner.log_buffer.output.pop('eval_iter_num') + else: + cur_iter = self.get_iter(runner, inner_iter=True) + + log_dict = OrderedDict( + mode=self.get_mode(runner), + epoch=self.get_epoch(runner), + iter=cur_iter) + + # only record lr of the first param group + cur_lr = runner.current_lr() + if isinstance(cur_lr, list): + log_dict['lr'] = cur_lr[0] + else: + assert isinstance(cur_lr, dict) + log_dict['lr'] = {} + for k, lr_ in cur_lr.items(): + assert isinstance(lr_, list) + log_dict['lr'].update({k: lr_[0]}) + + if 'time' in runner.log_buffer.output: + # statistic memory + if torch.cuda.is_available(): + log_dict['memory'] = self._get_max_memory(runner) + + log_dict = dict(log_dict, **runner.log_buffer.output) # type: ignore + + self._log_info(log_dict, runner) + self._dump_log(log_dict, runner) + return log_dict + + def after_run(self, runner) -> None: + # copy or upload logs to self.out_dir + if self.out_dir is not None: + for filename in scandir(runner.work_dir, self.out_suffix, True): + local_filepath = osp.join(runner.work_dir, filename) + out_filepath = self.file_client.join_path( + self.out_dir, filename) + with open(local_filepath) as f: + self.file_client.put_text(f.read(), out_filepath) + + runner.logger.info( + f'The file {local_filepath} has been uploaded to ' + f'{out_filepath}.') + + if not self.keep_local: + os.remove(local_filepath) + runner.logger.info( + f'{local_filepath} was removed due to the ' + '`self.keep_local=False`') diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/wandb.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/wandb.py new file mode 100755 index 00000000..4d4f40ea --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/logger/wandb.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
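+# A minimal usage sketch (the project and run names are illustrative),
+# assuming wandb is installed and the hook is enabled through the runner's
+# ``log_config``:
+#
+#     log_config = dict(
+#         interval=10,
+#         hooks=[dict(type='WandbLoggerHook',
+#                     init_kwargs=dict(project='dbnet',
+#                                      name='r50-icdar2015'))])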
+import os.path as osp +from typing import Dict, Optional, Union + +from dbnet_cv.utils import scandir +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class WandbLoggerHook(LoggerHook): + """Class to log metrics with wandb. + + It requires `wandb`_ to be installed. + + + Args: + init_kwargs (dict): A dict contains the initialization keys. Check + https://docs.wandb.ai/ref/python/init for more init arguments. + interval (int): Logging interval (every k iterations). + Default 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + commit (bool): Save the metrics dict to the wandb server and increment + the step. If false ``wandb.log`` just updates the current metrics + dict with the row argument and metrics won't be saved until + ``wandb.log`` is called with ``commit=True``. + Default: True. + by_epoch (bool): Whether EpochBasedRunner is used. + Default: True. + with_step (bool): If True, the step will be logged from + ``self.get_iters``. Otherwise, step will not be logged. + Default: True. + log_artifact (bool): If True, artifacts in {work_dir} will be uploaded + to wandb after training ends. + Default: True + `New in version 1.4.3.` + out_suffix (str or tuple[str], optional): Those filenames ending with + ``out_suffix`` will be uploaded to wandb. + Default: ('.log.json', '.log', '.py'). + `New in version 1.4.3.` + + .. _wandb: + https://docs.wandb.ai + """ + + def __init__(self, + init_kwargs: Optional[Dict] = None, + interval: int = 10, + ignore_last: bool = True, + reset_flag: bool = False, + commit: bool = True, + by_epoch: bool = True, + with_step: bool = True, + log_artifact: bool = True, + out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py')): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.import_wandb() + self.init_kwargs = init_kwargs + self.commit = commit + self.with_step = with_step + self.log_artifact = log_artifact + self.out_suffix = out_suffix + + def import_wandb(self) -> None: + try: + import wandb + except ImportError: + raise ImportError( + 'Please run "pip install wandb" to install wandb') + self.wandb = wandb + + @master_only + def before_run(self, runner) -> None: + super().before_run(runner) + if self.wandb is None: + self.import_wandb() + if self.init_kwargs: + self.wandb.init(**self.init_kwargs) # type: ignore + else: + self.wandb.init() # type: ignore + + @master_only + def log(self, runner) -> None: + tags = self.get_loggable_tags(runner) + if tags: + if self.with_step: + self.wandb.log( + tags, step=self.get_iter(runner), commit=self.commit) + else: + tags['global_step'] = self.get_iter(runner) + self.wandb.log(tags, commit=self.commit) + + @master_only + def after_run(self, runner) -> None: + if self.log_artifact: + wandb_artifact = self.wandb.Artifact( + name='artifacts', type='model') + for filename in scandir(runner.work_dir, self.out_suffix, True): + local_filepath = osp.join(runner.work_dir, filename) + wandb_artifact.add_file(local_filepath) + self.wandb.log_artifact(wandb_artifact) + self.wandb.join() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/lr_updater.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/lr_updater.py new file mode 100755 index 00000000..26ae4d89 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/lr_updater.py @@ -0,0 +1,751 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import numbers +from math import cos, pi +from typing import Callable, List, Optional, Union + +import dbnet_cv +from dbnet_cv import runner +from .hook import HOOKS, Hook + + +class LrUpdaterHook(Hook): + """LR Scheduler in DBNET_CV. + + Args: + by_epoch (bool): LR changes epoch by epoch + warmup (string): Type of warmup used. It can be None(use no warmup), + 'constant', 'linear' or 'exp' + warmup_iters (int): The number of iterations or epochs that warmup + lasts + warmup_ratio (float): LR used at the beginning of warmup equals to + warmup_ratio * initial_lr + warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters + means the number of epochs that warmup lasts, otherwise means the + number of iteration that warmup lasts + """ + + def __init__(self, + by_epoch: bool = True, + warmup: Optional[str] = None, + warmup_iters: int = 0, + warmup_ratio: float = 0.1, + warmup_by_epoch: bool = False) -> None: + # validate the "warmup" argument + if warmup is not None: + if warmup not in ['constant', 'linear', 'exp']: + raise ValueError( + f'"{warmup}" is not a supported type for warming up, valid' + ' types are "constant", "linear" and "exp"') + if warmup is not None: + assert warmup_iters > 0, \ + '"warmup_iters" must be a positive integer' + assert 0 < warmup_ratio <= 1.0, \ + '"warmup_ratio" must be in range (0,1]' + + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_iters: Optional[int] = warmup_iters + self.warmup_ratio = warmup_ratio + self.warmup_by_epoch = warmup_by_epoch + + if self.warmup_by_epoch: + self.warmup_epochs: Optional[int] = self.warmup_iters + self.warmup_iters = None + else: + self.warmup_epochs = None + + self.base_lr: Union[list, dict] = [] # initial lr for all param groups + self.regular_lr: list = [] # expected lr if no warming up is performed + + def _set_lr(self, runner, lr_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, lr in zip(optim.param_groups, lr_groups[k]): + param_group['lr'] = lr + else: + for param_group, lr in zip(runner.optimizer.param_groups, + lr_groups): + param_group['lr'] = lr + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + raise NotImplementedError + + def get_regular_lr(self, runner: 'runner.BaseRunner'): + if isinstance(runner.optimizer, dict): + lr_groups = {} + for k in runner.optimizer.keys(): + _lr_group = [ + self.get_lr(runner, _base_lr) + for _base_lr in self.base_lr[k] + ] + lr_groups.update({k: _lr_group}) + + return lr_groups + else: + return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr] + + def get_warmup_lr(self, cur_iters: int): + + def _get_warmup_lr(cur_iters, regular_lr): + if self.warmup == 'constant': + warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr] + elif self.warmup == 'linear': + k = (1 - cur_iters / self.warmup_iters) * (1 - + self.warmup_ratio) + warmup_lr = [_lr * (1 - k) for _lr in regular_lr] + elif self.warmup == 'exp': + k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) + warmup_lr = [_lr * k for _lr in regular_lr] + return warmup_lr + + if isinstance(self.regular_lr, dict): + lr_groups = {} + for key, regular_lr in self.regular_lr.items(): + lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr) + return lr_groups + else: + return _get_warmup_lr(cur_iters, self.regular_lr) + + def before_run(self, runner: 'runner.BaseRunner'): + # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved, + # it will be set according to the optimizer params + 
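+        # runner.optimizer may be a single optimizer or a dict of optimizers
+        # (e.g. one per sub-model); record the initial lr of every param group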
if isinstance(runner.optimizer, dict): + self.base_lr = {} + for k, optim in runner.optimizer.items(): + for group in optim.param_groups: + group.setdefault('initial_lr', group['lr']) + _base_lr = [ + group['initial_lr'] for group in optim.param_groups + ] + self.base_lr.update({k: _base_lr}) + else: + for group in runner.optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + self.base_lr = [ + group['initial_lr'] for group in runner.optimizer.param_groups + ] + + def before_train_epoch(self, runner: 'runner.BaseRunner'): + if self.warmup_iters is None: + epoch_len = len(runner.data_loader) # type: ignore + self.warmup_iters = self.warmup_epochs * epoch_len # type: ignore + + if not self.by_epoch: + return + + self.regular_lr = self.get_regular_lr(runner) + self._set_lr(runner, self.regular_lr) + + def before_train_iter(self, runner: 'runner.BaseRunner'): + cur_iter = runner.iter + if not self.by_epoch: + self.regular_lr = self.get_regular_lr(runner) + if self.warmup is None or cur_iter >= self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + elif self.by_epoch: + if self.warmup is None or cur_iter > self.warmup_iters: + return + elif cur_iter == self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + + +@HOOKS.register_module() +class FixedLrUpdaterHook(LrUpdaterHook): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_lr(self, runner, base_lr): + return base_lr + + +@HOOKS.register_module() +class StepLrUpdaterHook(LrUpdaterHook): + """Step LR scheduler with min_lr clipping. + + Args: + step (int | list[int]): Step to decay the LR. If an int value is given, + regard it as the decay interval. If a list is given, decay LR at + these steps. + gamma (float): Decay LR ratio. Defaults to 0.1. + min_lr (float, optional): Minimum LR value to keep. If LR after decay + is lower than `min_lr`, it will be clipped to this value. If None + is given, we don't perform lr clipping. Default: None. 
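+
+    Example:
+        Assuming the usual ``lr_config`` registration and a base lr of 0.01,
+        a config such as the following decays the lr by ``gamma`` at epochs
+        8 and 11 (values are illustrative):
+
+        >>> lr_config = dict(policy='step', step=[8, 11], gamma=0.1)
+        >>> # epochs 0-7 -> 0.01, epochs 8-10 -> 0.001, epochs 11+ -> 0.0001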
+ """ + + def __init__(self, + step: Union[int, List[int]], + gamma: float = 0.1, + min_lr: Optional[float] = None, + **kwargs) -> None: + if isinstance(step, list): + assert dbnet_cv.is_list_of(step, int) + assert all([s > 0 for s in step]) + elif isinstance(step, int): + assert step > 0 + else: + raise TypeError('"step" must be a list or integer') + self.step = step + self.gamma = gamma + self.min_lr = min_lr + super().__init__(**kwargs) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + progress = runner.epoch if self.by_epoch else runner.iter + + # calculate exponential term + if isinstance(self.step, int): + exp = progress // self.step + else: + exp = len(self.step) + for i, s in enumerate(self.step): + if progress < s: + exp = i + break + + lr = base_lr * (self.gamma**exp) + if self.min_lr is not None: + # clip to a minimum value + lr = max(lr, self.min_lr) + return lr + + +@HOOKS.register_module() +class ExpLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma: float, **kwargs) -> None: + self.gamma = gamma + super().__init__(**kwargs) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * self.gamma**progress + + +@HOOKS.register_module() +class PolyLrUpdaterHook(LrUpdaterHook): + + def __init__(self, + power: float = 1., + min_lr: float = 0., + **kwargs) -> None: + self.power = power + self.min_lr = min_lr + super().__init__(**kwargs) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + coeff = (1 - progress / max_progress)**self.power + return (base_lr - self.min_lr) * coeff + self.min_lr + + +@HOOKS.register_module() +class InvLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma: float, power: float = 1., **kwargs) -> None: + self.gamma = gamma + self.power = power + super().__init__(**kwargs) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * (1 + self.gamma * progress)**(-self.power) + + +@HOOKS.register_module() +class CosineAnnealingLrUpdaterHook(LrUpdaterHook): + """CosineAnnealing LR scheduler. + + Args: + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + min_lr: Optional[float] = None, + min_lr_ratio: Optional[float] = None, + **kwargs) -> None: + assert (min_lr is None) ^ (min_lr_ratio is None) + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + super().__init__(**kwargs) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr # type:ignore + return annealing_cos(base_lr, target_lr, progress / max_progress) + + +@HOOKS.register_module() +class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook): + """Flat + Cosine lr schedule. 
+ + Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501 + + Args: + start_percent (float): When to start annealing the learning rate + after the percentage of the total training steps. + The value should be in range [0, 1). + Default: 0.75 + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + start_percent: float = 0.75, + min_lr: Optional[float] = None, + min_lr_ratio: Optional[float] = None, + **kwargs) -> None: + assert (min_lr is None) ^ (min_lr_ratio is None) + if start_percent < 0 or start_percent > 1 or not isinstance( + start_percent, float): + raise ValueError( + 'expected float between 0 and 1 start_percent, but ' + f'got {start_percent}') + self.start_percent = start_percent + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + super().__init__(**kwargs) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + if self.by_epoch: + start = round(runner.max_epochs * self.start_percent) + progress = runner.epoch - start + max_progress = runner.max_epochs - start + else: + start = round(runner.max_iters * self.start_percent) + progress = runner.iter - start + max_progress = runner.max_iters - start + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr # type:ignore + + if progress < 0: + return base_lr + else: + return annealing_cos(base_lr, target_lr, progress / max_progress) + + +@HOOKS.register_module() +class CosineRestartLrUpdaterHook(LrUpdaterHook): + """Cosine annealing with restarts learning rate scheme. + + Args: + periods (list[int]): Periods for each cosine anneling cycle. + restart_weights (list[float]): Restart weights at each + restart iteration. Defaults to [1]. + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + periods: List[int], + restart_weights: List[float] = [1], + min_lr: Optional[float] = None, + min_lr_ratio: Optional[float] = None, + **kwargs) -> None: + assert (min_lr is None) ^ (min_lr_ratio is None) + self.periods = periods + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + self.restart_weights = restart_weights + assert (len(self.periods) == len(self.restart_weights) + ), 'periods and restart_weights should have the same length.' + super().__init__(**kwargs) + + self.cumulative_periods = [ + sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) + ] + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + if self.by_epoch: + progress = runner.epoch + else: + progress = runner.iter + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr # type:ignore + + idx = get_position_from_periods(progress, self.cumulative_periods) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1] + current_periods = self.periods[idx] + + alpha = min((progress - nearest_restart) / current_periods, 1) + return annealing_cos(base_lr, target_lr, alpha, current_weight) + + +def get_position_from_periods(iteration: int, cumulative_periods: List[int]): + """Get the position from a period list. 
+ + It will return the index of the right-closest number in the period list. + For example, the cumulative_periods = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 3. + + Args: + iteration (int): Current iteration. + cumulative_periods (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_periods): + if iteration < period: + return i + raise ValueError(f'Current iteration {iteration} exceeds ' + f'cumulative_periods {cumulative_periods}') + + +@HOOKS.register_module() +class CyclicLrUpdaterHook(LrUpdaterHook): + """Cyclic LR Scheduler. + + Implement the cyclical learning rate policy (CLR) described in + https://arxiv.org/pdf/1506.01186.pdf + + Different from the original paper, we use cosine annealing rather than + triangular policy inside a cycle. This improves the performance in the + 3D detection area. + + Args: + by_epoch (bool, optional): Whether to update LR by epoch. + target_ratio (tuple[float], optional): Relative ratio of the highest LR + and the lowest LR to the initial LR. + cyclic_times (int, optional): Number of cycles during training + step_ratio_up (float, optional): The ratio of the increasing process of + LR in the total cycle. + anneal_strategy (str, optional): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. Default: 'cos'. + gamma (float, optional): Cycle decay ratio. Default: 1. + It takes values in the range (0, 1]. The difference between the + maximum learning rate and the minimum learning rate decreases + periodically when it is less than 1. `New in version 1.4.4.` + """ + + def __init__(self, + by_epoch: bool = False, + target_ratio: Union[float, tuple] = (10, 1e-4), + cyclic_times: int = 1, + step_ratio_up: float = 0.4, + anneal_strategy: str = 'cos', + gamma: float = 1, + **kwargs) -> None: + if isinstance(target_ratio, float): + target_ratio = (target_ratio, target_ratio / 1e5) + elif isinstance(target_ratio, tuple): + target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \ + if len(target_ratio) == 1 else target_ratio + else: + raise ValueError('target_ratio should be either float ' + f'or tuple, got {type(target_ratio)}') + + assert len(target_ratio) == 2, \ + '"target_ratio" must be list or tuple of two floats' + assert 0 <= step_ratio_up < 1.0, \ + '"step_ratio_up" must be in range [0,1)' + assert 0 < gamma <= 1, \ + '"gamma" must be in range (0, 1]' + + self.target_ratio = target_ratio + self.cyclic_times = cyclic_times + self.step_ratio_up = step_ratio_up + self.gamma = gamma + self.max_iter_per_phase = None + self.lr_phases: list = [] # init lr_phases + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must be one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func: Callable[[float, float, float], + float] = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + + assert not by_epoch, \ + 'currently only support "by_epoch" = False' + super().__init__(by_epoch, **kwargs) + + def before_run(self, runner: 'runner.BaseRunner'): + super().before_run(runner) + # initiate lr_phases + # total lr_phases are separated as up and down + self.max_iter_per_phase = runner.max_iters // self.cyclic_times + iter_up_phase = int(self.step_ratio_up * + self.max_iter_per_phase) # 
type:ignore + self.lr_phases.append([0, iter_up_phase, 1, self.target_ratio[0]]) + self.lr_phases.append([ + iter_up_phase, self.max_iter_per_phase, self.target_ratio[0], + self.target_ratio[1] + ]) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + curr_iter = runner.iter % self.max_iter_per_phase + curr_cycle = runner.iter // self.max_iter_per_phase + # Update weight decay + scale = self.gamma**curr_cycle + + for (start_iter, end_iter, start_ratio, end_ratio) in self.lr_phases: + if start_iter <= curr_iter < end_iter: + # Apply cycle scaling to gradually reduce the difference + # between max_lr and base lr. The target end_ratio can be + # expressed as: + # end_ratio = (base_lr + scale * (max_lr - base_lr)) / base_lr + # iteration: 0-iter_up_phase: + if start_iter == 0: + end_ratio = 1 - scale + end_ratio * scale + # iteration: iter_up_phase-self.max_iter_per_phase + else: + start_ratio = 1 - scale + start_ratio * scale + progress = curr_iter - start_iter + return self.anneal_func(base_lr * start_ratio, + base_lr * end_ratio, + progress / (end_iter - start_iter)) + + +@HOOKS.register_module() +class OneCycleLrUpdaterHook(LrUpdaterHook): + """One Cycle LR Scheduler. + + The 1cycle learning rate policy changes the learning rate after every + batch. The one cycle learning rate policy is described in + https://arxiv.org/pdf/1708.07120.pdf + + Args: + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int, optional): The total number of steps in the cycle. + Note that if a value is not provided here, it will be the max_iter + of runner. Default: None. + pct_start (float): The percentage of the cycle (in number of steps) + spent increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. + Default: 'cos' + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If three_phase is True, use a third phase of the + schedule to annihilate the learning rate according to + final_div_factor instead of modifying the second phase (the first + two phases will be symmetrical about the step indicated by + pct_start). 
+ Default: False + """ + + def __init__(self, + max_lr: Union[float, List], + total_steps: Optional[int] = None, + pct_start: float = 0.3, + anneal_strategy: str = 'cos', + div_factor: float = 25, + final_div_factor: float = 1e4, + three_phase: bool = False, + **kwargs) -> None: + # validate by_epoch, currently only support by_epoch = False + if 'by_epoch' not in kwargs: + kwargs['by_epoch'] = False + else: + assert not kwargs['by_epoch'], \ + 'currently only support "by_epoch" = False' + if not isinstance(max_lr, (numbers.Number, list, dict)): + raise ValueError('the type of max_lr must be the one of list or ' + f'dict, but got {type(max_lr)}') + self._max_lr = max_lr + if total_steps is not None: + if not isinstance(total_steps, int): + raise ValueError('the type of total_steps must be int, but' + f'got {type(total_steps)}') + self.total_steps = total_steps + # validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError('expected float between 0 and 1 pct_start, but ' + f'got {pct_start}') + self.pct_start = pct_start + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must be one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func: Callable[[float, float, float], + float] = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + self.div_factor = div_factor + self.final_div_factor = final_div_factor + self.three_phase = three_phase + self.lr_phases: list = [] # init lr_phases + super().__init__(**kwargs) + + def before_run(self, runner: 'runner.BaseRunner'): + if hasattr(self, 'total_steps'): + total_steps = self.total_steps + else: + total_steps = runner.max_iters + if total_steps < runner.max_iters: + raise ValueError( + 'The total steps must be greater than or equal to max ' + f'iterations {runner.max_iters} of runner, but total steps ' + f'is {total_steps}.') + + if isinstance(runner.optimizer, dict): + self.base_lr = {} + for k, optim in runner.optimizer.items(): + _max_lr = format_param(k, optim, self._max_lr) + self.base_lr[k] = [lr / self.div_factor for lr in _max_lr] + for group, lr in zip(optim.param_groups, self.base_lr[k]): + group.setdefault('initial_lr', lr) + else: + k = type(runner.optimizer).__name__ + _max_lr = format_param(k, runner.optimizer, self._max_lr) + self.base_lr = [lr / self.div_factor for lr in _max_lr] + for group, lr in zip(runner.optimizer.param_groups, self.base_lr): + group.setdefault('initial_lr', lr) + + if self.three_phase: + self.lr_phases.append( + [float(self.pct_start * total_steps) - 1, 1, self.div_factor]) + self.lr_phases.append([ + float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1 + ]) + self.lr_phases.append( + [total_steps - 1, 1, 1 / self.final_div_factor]) + else: + self.lr_phases.append( + [float(self.pct_start * total_steps) - 1, 1, self.div_factor]) + self.lr_phases.append( + [total_steps - 1, self.div_factor, 1 / self.final_div_factor]) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + curr_iter = runner.iter + start_iter = 0 + for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases): + if curr_iter <= end_iter: + pct = (curr_iter - start_iter) / (end_iter - start_iter) + lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr, + pct) + break + start_iter = end_iter + return lr + + +@HOOKS.register_module() +class LinearAnnealingLrUpdaterHook(LrUpdaterHook): + """Linear annealing LR 
Scheduler decays the learning rate of each parameter + group linearly. + + Args: + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + min_lr: Optional[float] = None, + min_lr_ratio: Optional[float] = None, + **kwargs): + assert (min_lr is None) ^ (min_lr_ratio is None) + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + super().__init__(**kwargs) + + def get_lr(self, runner: 'runner.BaseRunner', base_lr: float): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr # type:ignore + return annealing_linear(base_lr, target_lr, progress / max_progress) + + +def annealing_cos(start: float, + end: float, + factor: float, + weight: float = 1) -> float: + """Calculate annealing cos learning rate. + + Cosine anneal from `weight * start + (1 - weight) * end` to `end` as + percentage goes from 0.0 to 1.0. + + Args: + start (float): The starting learning rate of the cosine annealing. + end (float): The ending learning rate of the cosine annealing. + factor (float): The coefficient of `pi` when calculating the current + percentage. Range from 0.0 to 1.0. + weight (float, optional): The combination factor of `start` and `end` + when calculating the actual starting learning rate. Default to 1. + """ + cos_out = cos(pi * factor) + 1 + return end + 0.5 * weight * (start - end) * cos_out + + +def annealing_linear(start: float, end: float, factor: float) -> float: + """Calculate annealing linear learning rate. + + Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0. + + Args: + start (float): The starting learning rate of the linear annealing. + end (float): The ending learning rate of the linear annealing. + factor (float): The current annealing percentage used to interpolate + between `start` and `end`. Range from 0.0 to 1.0. + """ + return start + (end - start) * factor + + +def format_param(name, optim, param): + if isinstance(param, numbers.Number): + return [param] * len(optim.param_groups) + elif isinstance(param, (list, tuple)): # multi param groups + if len(param) != len(optim.param_groups): + raise ValueError(f'expected {len(optim.param_groups)} ' + f'values for {name}, got {len(param)}') + return param + else: # multi optimizers + if name not in param: + raise KeyError(f'{name} is not found in {param.keys()}') + return param[name] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/memory.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/memory.py new file mode 100755 index 00000000..70cf9a83 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/memory.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved.
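+# A minimal usage sketch (illustrative), assuming the hook is registered via
+# ``custom_hooks`` in the runner config:
+#
+#     custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]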
+import torch + +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class EmptyCacheHook(Hook): + + def __init__(self, before_epoch=False, after_epoch=True, after_iter=False): + self._before_epoch = before_epoch + self._after_epoch = after_epoch + self._after_iter = after_iter + + def after_iter(self, runner): + if self._after_iter: + torch.cuda.empty_cache() + + def before_epoch(self, runner): + if self._before_epoch: + torch.cuda.empty_cache() + + def after_epoch(self, runner): + if self._after_epoch: + torch.cuda.empty_cache() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/momentum_updater.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/momentum_updater.py new file mode 100755 index 00000000..aa3d5332 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/momentum_updater.py @@ -0,0 +1,566 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dbnet_cv +from .hook import HOOKS, Hook +from .lr_updater import annealing_cos, annealing_linear, format_param + + +class MomentumUpdaterHook(Hook): + + def __init__(self, + by_epoch=True, + warmup=None, + warmup_iters=0, + warmup_ratio=0.9): + # validate the "warmup" argument + if warmup is not None: + if warmup not in ['constant', 'linear', 'exp']: + raise ValueError( + f'"{warmup}" is not a supported type for warming up, valid' + ' types are "constant" and "linear"') + if warmup is not None: + assert warmup_iters > 0, \ + '"warmup_iters" must be a positive integer' + assert 0 < warmup_ratio <= 1.0, \ + '"warmup_momentum" must be in range (0,1]' + + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + + self.base_momentum = [] # initial momentum for all param groups + self.regular_momentum = [ + ] # expected momentum if no warming up is performed + + def _set_momentum(self, runner, momentum_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, mom in zip(optim.param_groups, + momentum_groups[k]): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + else: + for param_group, mom in zip(runner.optimizer.param_groups, + momentum_groups): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + + def get_momentum(self, runner, base_momentum): + raise NotImplementedError + + def get_regular_momentum(self, runner): + if isinstance(runner.optimizer, dict): + momentum_groups = {} + for k in runner.optimizer.keys(): + _momentum_group = [ + self.get_momentum(runner, _base_momentum) + for _base_momentum in self.base_momentum[k] + ] + momentum_groups.update({k: _momentum_group}) + return momentum_groups + else: + return [ + self.get_momentum(runner, _base_momentum) + for _base_momentum in self.base_momentum + ] + + def get_warmup_momentum(self, cur_iters): + + def _get_warmup_momentum(cur_iters, regular_momentum): + if self.warmup == 'constant': + warmup_momentum = [ + _momentum / self.warmup_ratio + for _momentum in regular_momentum + ] + elif self.warmup == 'linear': + k = (1 - cur_iters / self.warmup_iters) * (1 - + self.warmup_ratio) + warmup_momentum = [ + _momentum / (1 - k) for _momentum in regular_momentum + ] + elif self.warmup == 'exp': + k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) + warmup_momentum = [ + _momentum / k for _momentum in regular_momentum + ] + 
return warmup_momentum + + if isinstance(self.regular_momentum, dict): + momentum_groups = {} + for key, regular_momentum in self.regular_momentum.items(): + momentum_groups[key] = _get_warmup_momentum( + cur_iters, regular_momentum) + return momentum_groups + else: + return _get_warmup_momentum(cur_iters, self.regular_momentum) + + def before_run(self, runner): + # NOTE: when resuming from a checkpoint, + # if 'initial_momentum' is not saved, + # it will be set according to the optimizer params + if isinstance(runner.optimizer, dict): + self.base_momentum = {} + for k, optim in runner.optimizer.items(): + for group in optim.param_groups: + if 'momentum' in group.keys(): + group.setdefault('initial_momentum', group['momentum']) + else: + group.setdefault('initial_momentum', group['betas'][0]) + _base_momentum = [ + group['initial_momentum'] for group in optim.param_groups + ] + self.base_momentum.update({k: _base_momentum}) + else: + for group in runner.optimizer.param_groups: + if 'momentum' in group.keys(): + group.setdefault('initial_momentum', group['momentum']) + else: + group.setdefault('initial_momentum', group['betas'][0]) + self.base_momentum = [ + group['initial_momentum'] + for group in runner.optimizer.param_groups + ] + + def before_train_epoch(self, runner): + if not self.by_epoch: + return + self.regular_momentum = self.get_regular_momentum(runner) + self._set_momentum(runner, self.regular_momentum) + + def before_train_iter(self, runner): + cur_iter = runner.iter + if not self.by_epoch: + self.regular_momentum = self.get_regular_momentum(runner) + if self.warmup is None or cur_iter >= self.warmup_iters: + self._set_momentum(runner, self.regular_momentum) + else: + warmup_momentum = self.get_warmup_momentum(cur_iter) + self._set_momentum(runner, warmup_momentum) + elif self.by_epoch: + if self.warmup is None or cur_iter > self.warmup_iters: + return + elif cur_iter == self.warmup_iters: + self._set_momentum(runner, self.regular_momentum) + else: + warmup_momentum = self.get_warmup_momentum(cur_iter) + self._set_momentum(runner, warmup_momentum) + + +@HOOKS.register_module() +class StepMomentumUpdaterHook(MomentumUpdaterHook): + """Step momentum scheduler with min value clipping. + + Args: + step (int | list[int]): Step to decay the momentum. If an int value is + given, regard it as the decay interval. If a list is given, decay + momentum at these steps. + gamma (float, optional): Decay momentum ratio. Default: 0.5. + min_momentum (float, optional): Minimum momentum value to keep. If + momentum after decay is lower than this value, it will be clipped + accordingly. If None is given, we don't perform lr clipping. + Default: None. 
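+
+    Example:
+        Assuming the usual ``momentum_config`` registration and a base
+        momentum of 0.9, the following halves the momentum at epochs 8 and
+        11 (values are illustrative):
+
+        >>> momentum_config = dict(policy='step', step=[8, 11], gamma=0.5)
+        >>> # epochs 0-7 -> 0.9, epochs 8-10 -> 0.45, epochs 11+ -> 0.225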
+ """ + + def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs): + if isinstance(step, list): + assert dbnet_cv.is_list_of(step, int) + assert all([s > 0 for s in step]) + elif isinstance(step, int): + assert step > 0 + else: + raise TypeError('"step" must be a list or integer') + self.step = step + self.gamma = gamma + self.min_momentum = min_momentum + super().__init__(**kwargs) + + def get_momentum(self, runner, base_momentum): + progress = runner.epoch if self.by_epoch else runner.iter + + # calculate exponential term + if isinstance(self.step, int): + exp = progress // self.step + else: + exp = len(self.step) + for i, s in enumerate(self.step): + if progress < s: + exp = i + break + + momentum = base_momentum * (self.gamma**exp) + if self.min_momentum is not None: + # clip to a minimum value + momentum = max(momentum, self.min_momentum) + return momentum + + +@HOOKS.register_module() +class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook): + """Cosine annealing LR Momentum decays the Momentum of each parameter group + linearly. + + Args: + min_momentum (float, optional): The minimum momentum. Default: None. + min_momentum_ratio (float, optional): The ratio of minimum momentum to + the base momentum. Either `min_momentum` or `min_momentum_ratio` + should be specified. Default: None. + """ + + def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs): + assert (min_momentum is None) ^ (min_momentum_ratio is None) + self.min_momentum = min_momentum + self.min_momentum_ratio = min_momentum_ratio + super().__init__(**kwargs) + + def get_momentum(self, runner, base_momentum): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + if self.min_momentum_ratio is not None: + target_momentum = base_momentum * self.min_momentum_ratio + else: + target_momentum = self.min_momentum + return annealing_cos(base_momentum, target_momentum, + progress / max_progress) + + +@HOOKS.register_module() +class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook): + """Linear annealing LR Momentum decays the Momentum of each parameter group + linearly. + + Args: + min_momentum (float, optional): The minimum momentum. Default: None. + min_momentum_ratio (float, optional): The ratio of minimum momentum to + the base momentum. Either `min_momentum` or `min_momentum_ratio` + should be specified. Default: None. + """ + + def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs): + assert (min_momentum is None) ^ (min_momentum_ratio is None) + self.min_momentum = min_momentum + self.min_momentum_ratio = min_momentum_ratio + super().__init__(**kwargs) + + def get_momentum(self, runner, base_momentum): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + if self.min_momentum_ratio is not None: + target_momentum = base_momentum * self.min_momentum_ratio + else: + target_momentum = self.min_momentum + return annealing_linear(base_momentum, target_momentum, + progress / max_progress) + + +@HOOKS.register_module() +class CyclicMomentumUpdaterHook(MomentumUpdaterHook): + """Cyclic momentum Scheduler. + + Implement the cyclical momentum scheduler policy described in + https://arxiv.org/pdf/1708.07120.pdf + + This momentum scheduler usually used together with the CyclicLRUpdater + to improve the performance in the 3D detection area. 
+ + Args: + target_ratio (tuple[float]): Relative ratio of the lowest momentum and + the highest momentum to the initial momentum. + cyclic_times (int): Number of cycles during training + step_ratio_up (float): The ratio of the increasing process of momentum + in the total cycle. + by_epoch (bool): Whether to update momentum by epoch. + anneal_strategy (str, optional): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. Default: 'cos'. + gamma (float, optional): Cycle decay ratio. Default: 1. + It takes values in the range (0, 1]. The difference between the + maximum learning rate and the minimum learning rate decreases + periodically when it is less than 1. `New in version 1.4.4.` + """ + + def __init__(self, + by_epoch=False, + target_ratio=(0.85 / 0.95, 1), + cyclic_times=1, + step_ratio_up=0.4, + anneal_strategy='cos', + gamma=1, + **kwargs): + if isinstance(target_ratio, float): + target_ratio = (target_ratio, target_ratio / 1e5) + elif isinstance(target_ratio, tuple): + target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \ + if len(target_ratio) == 1 else target_ratio + else: + raise ValueError('target_ratio should be either float ' + f'or tuple, got {type(target_ratio)}') + + assert len(target_ratio) == 2, \ + '"target_ratio" must be list or tuple of two floats' + assert 0 <= step_ratio_up < 1.0, \ + '"step_ratio_up" must be in range [0,1)' + + self.target_ratio = target_ratio + self.cyclic_times = cyclic_times + self.step_ratio_up = step_ratio_up + self.gamma = gamma + self.momentum_phases = [] # init momentum_phases + + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must be one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + # currently only support by_epoch=False + assert not by_epoch, \ + 'currently only support "by_epoch" = False' + super().__init__(by_epoch, **kwargs) + + def before_run(self, runner): + super().before_run(runner) + # initiate momentum_phases + # total momentum_phases are separated as up and down + max_iter_per_phase = runner.max_iters // self.cyclic_times + iter_up_phase = int(self.step_ratio_up * max_iter_per_phase) + self.max_iter_per_phase = max_iter_per_phase + self.momentum_phases.append( + [0, iter_up_phase, 1, self.target_ratio[0]]) + self.momentum_phases.append([ + iter_up_phase, max_iter_per_phase, self.target_ratio[0], + self.target_ratio[1] + ]) + + def get_momentum(self, runner, base_momentum): + curr_iter = runner.iter % self.max_iter_per_phase + curr_cycle = runner.iter // self.max_iter_per_phase + scale = self.gamma**curr_cycle + for (start_iter, end_iter, start_ratio, end_ratio) \ + in self.momentum_phases: + if start_iter <= curr_iter < end_iter: + # Apply cycle scaling to gradually reduce the difference + # between max_momentum and base momentum. 
The target end_ratio + # can be expressed as: + # end_ratio = (base_momentum + scale * \ + # (max_momentum - base_momentum)) / base_momentum + # iteration: 0-iter_up_phase: + if start_iter == 0: + end_ratio = 1 - scale + end_ratio * scale + # iteration: iter_up_phase-self.max_iter_per_phase + else: + start_ratio = 1 - scale + start_ratio * scale + progress = curr_iter - start_iter + return self.anneal_func(base_momentum * start_ratio, + base_momentum * end_ratio, + progress / (end_iter - start_iter)) + + +@HOOKS.register_module() +class OneCycleMomentumUpdaterHook(MomentumUpdaterHook): + """OneCycle momentum Scheduler. + + This momentum scheduler usually used together with the OneCycleLrUpdater + to improve the performance. + + Args: + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is + 'max_momentum' and learning rate is 'base_lr' + Default: 0.95 + pct_start (float): The percentage of the cycle (in number of steps) + spent increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. + Default: 'cos' + three_phase (bool): If three_phase is True, use a third phase of the + schedule to annihilate the learning rate according to + final_div_factor instead of modifying the second phase (the first + two phases will be symmetrical about the step indicated by + pct_start). 
+ Default: False + """ + + def __init__(self, + base_momentum=0.85, + max_momentum=0.95, + pct_start=0.3, + anneal_strategy='cos', + three_phase=False, + **kwargs): + # validate by_epoch, currently only support by_epoch=False + if 'by_epoch' not in kwargs: + kwargs['by_epoch'] = False + else: + assert not kwargs['by_epoch'], \ + 'currently only support "by_epoch" = False' + if not isinstance(base_momentum, (float, list, dict)): + raise ValueError('base_momentum must be the type among of float,' + 'list or dict.') + self._base_momentum = base_momentum + if not isinstance(max_momentum, (float, list, dict)): + raise ValueError('max_momentum must be the type among of float,' + 'list or dict.') + self._max_momentum = max_momentum + # validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError('Expected float between 0 and 1 pct_start, but ' + f'got {pct_start}') + self.pct_start = pct_start + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must by one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + self.three_phase = three_phase + self.momentum_phases = [] # init momentum_phases + super().__init__(**kwargs) + + def before_run(self, runner): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + if ('momentum' not in optim.defaults + and 'betas' not in optim.defaults): + raise ValueError('optimizer must support momentum with' + 'option enabled') + self.use_beta1 = 'betas' in optim.defaults + _base_momentum = format_param(k, optim, self._base_momentum) + _max_momentum = format_param(k, optim, self._max_momentum) + for group, b_momentum, m_momentum in zip( + optim.param_groups, _base_momentum, _max_momentum): + if self.use_beta1: + _, beta2 = group['betas'] + group['betas'] = (m_momentum, beta2) + else: + group['momentum'] = m_momentum + group['base_momentum'] = b_momentum + group['max_momentum'] = m_momentum + else: + optim = runner.optimizer + if ('momentum' not in optim.defaults + and 'betas' not in optim.defaults): + raise ValueError('optimizer must support momentum with' + 'option enabled') + self.use_beta1 = 'betas' in optim.defaults + k = type(optim).__name__ + _base_momentum = format_param(k, optim, self._base_momentum) + _max_momentum = format_param(k, optim, self._max_momentum) + for group, b_momentum, m_momentum in zip(optim.param_groups, + _base_momentum, + _max_momentum): + if self.use_beta1: + _, beta2 = group['betas'] + group['betas'] = (m_momentum, beta2) + else: + group['momentum'] = m_momentum + group['base_momentum'] = b_momentum + group['max_momentum'] = m_momentum + + if self.three_phase: + self.momentum_phases.append({ + 'end_iter': + float(self.pct_start * runner.max_iters) - 1, + 'start_momentum': + 'max_momentum', + 'end_momentum': + 'base_momentum' + }) + self.momentum_phases.append({ + 'end_iter': + float(2 * self.pct_start * runner.max_iters) - 2, + 'start_momentum': + 'base_momentum', + 'end_momentum': + 'max_momentum' + }) + self.momentum_phases.append({ + 'end_iter': runner.max_iters - 1, + 'start_momentum': 'max_momentum', + 'end_momentum': 'max_momentum' + }) + else: + self.momentum_phases.append({ + 'end_iter': + float(self.pct_start * runner.max_iters) - 1, + 'start_momentum': + 'max_momentum', + 'end_momentum': + 'base_momentum' + }) + self.momentum_phases.append({ + 
'end_iter': runner.max_iters - 1, + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum' + }) + + def _set_momentum(self, runner, momentum_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, mom in zip(optim.param_groups, + momentum_groups[k]): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + else: + for param_group, mom in zip(runner.optimizer.param_groups, + momentum_groups): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + + def get_momentum(self, runner, param_group): + curr_iter = runner.iter + start_iter = 0 + for i, phase in enumerate(self.momentum_phases): + end_iter = phase['end_iter'] + if curr_iter <= end_iter or i == len(self.momentum_phases) - 1: + pct = (curr_iter - start_iter) / (end_iter - start_iter) + momentum = self.anneal_func( + param_group[phase['start_momentum']], + param_group[phase['end_momentum']], pct) + break + start_iter = end_iter + return momentum + + def get_regular_momentum(self, runner): + if isinstance(runner.optimizer, dict): + momentum_groups = {} + for k, optim in runner.optimizer.items(): + _momentum_group = [ + self.get_momentum(runner, param_group) + for param_group in optim.param_groups + ] + momentum_groups.update({k: _momentum_group}) + return momentum_groups + else: + momentum_groups = [] + for param_group in runner.optimizer.param_groups: + momentum_groups.append(self.get_momentum(runner, param_group)) + return momentum_groups diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/optimizer.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/optimizer.py new file mode 100755 index 00000000..b0e10118 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/optimizer.py @@ -0,0 +1,554 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging +from collections import defaultdict +from itertools import chain + +from torch.nn.utils import clip_grad + +from dbnet_cv.utils import TORCH_VERSION, _BatchNorm, digit_version +from ..dist_utils import allreduce_grads +from ..fp16_utils import LossScaler, wrap_fp16_model +from .hook import HOOKS, Hook + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported + # and used; otherwise, auto fp16 will adopt dbnet_cv's implementation. + from torch.cuda.amp import GradScaler +except ImportError: + pass + + +@HOOKS.register_module() +class OptimizerHook(Hook): + """A hook contains custom operations for the optimizer. + + Args: + grad_clip (dict, optional): A config dict to control the clip_grad. + Default: None. + detect_anomalous_params (bool): This option is only used for + debugging which will slow down the training speed. + Detect anomalous parameters that are not included in + the computational graph with `loss` as the root. + There are two cases + + - Parameters were not used during + forward pass. + - Parameters were not used to produce + loss. + Default: False. 
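+
+    Example:
+        An illustrative configuration; ``max_norm=35`` is an assumption,
+        not a recommended value. The ``grad_clip`` dict is forwarded as
+        keyword arguments to ``torch.nn.utils.clip_grad_norm_``.
+
+        >>> optimizer_hook = OptimizerHook(
+        ...     grad_clip=dict(max_norm=35, norm_type=2))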
+ """ + + def __init__(self, grad_clip=None, detect_anomalous_params=False): + self.grad_clip = grad_clip + self.detect_anomalous_params = detect_anomalous_params + + def clip_grads(self, params): + params = list( + filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + return clip_grad.clip_grad_norm_(params, **self.grad_clip) + + def after_train_iter(self, runner): + runner.optimizer.zero_grad() + if self.detect_anomalous_params: + self.detect_anomalous_parameters(runner.outputs['loss'], runner) + runner.outputs['loss'].backward() + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + + def detect_anomalous_parameters(self, loss, runner): + logger = runner.logger + parameters_in_graph = set() + visited = set() + + def traverse(grad_fn): + if grad_fn is None: + return + if grad_fn not in visited: + visited.add(grad_fn) + if hasattr(grad_fn, 'variable'): + parameters_in_graph.add(grad_fn.variable) + parents = grad_fn.next_functions + if parents is not None: + for parent in parents: + grad_fn = parent[0] + traverse(grad_fn) + + traverse(loss.grad_fn) + for n, p in runner.model.named_parameters(): + if p not in parameters_in_graph and p.requires_grad: + logger.log( + level=logging.ERROR, + msg=f'{n} with shape {p.size()} is not ' + f'in the computational graph \n') + + +@HOOKS.register_module() +class GradientCumulativeOptimizerHook(OptimizerHook): + """Optimizer Hook implements multi-iters gradient cumulating. + + Args: + cumulative_iters (int, optional): Num of gradient cumulative iters. + The optimizer will step every `cumulative_iters` iters. + Defaults to 1. + + Examples: + >>> # Use cumulative_iters to simulate a large batch size + >>> # It is helpful when the hardware cannot handle a large batch size. + >>> loader = DataLoader(data, batch_size=64) + >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4) + >>> # almost equals to + >>> loader = DataLoader(data, batch_size=256) + >>> optim_hook = OptimizerHook() + """ + + def __init__(self, cumulative_iters=1, **kwargs): + super().__init__(**kwargs) + + assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \ + f'cumulative_iters only accepts positive int, but got ' \ + f'{type(cumulative_iters)} instead.' + + self.cumulative_iters = cumulative_iters + self.divisible_iters = 0 + self.remainder_iters = 0 + self.initialized = False + + def has_batch_norm(self, module): + if isinstance(module, _BatchNorm): + return True + for m in module.children(): + if self.has_batch_norm(m): + return True + return False + + def _init(self, runner): + if runner.iter % self.cumulative_iters != 0: + runner.logger.warning( + 'Resume iter number is not divisible by cumulative_iters in ' + 'GradientCumulativeOptimizerHook, which means the gradient of ' + 'some iters is lost and the result may be influenced slightly.' 
+ ) + + if self.has_batch_norm(runner.model) and self.cumulative_iters > 1: + runner.logger.warning( + 'GradientCumulativeOptimizerHook may slightly decrease ' + 'performance if the model has BatchNorm layers.') + + residual_iters = runner.max_iters - runner.iter + + self.divisible_iters = ( + residual_iters // self.cumulative_iters * self.cumulative_iters) + self.remainder_iters = residual_iters - self.divisible_iters + + self.initialized = True + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + runner.optimizer.zero_grad() + + +if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (using PyTorch's implementation). + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of GradScalar. + Defaults to 512. For Pytorch >= 1.6, dbnet_cv uses official + implementation of GradScaler. If you use a dict version of + loss_scale to create GradScaler, please refer to: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler + for the parameters. + + Examples: + >>> loss_scale = dict( + ... init_scale=65536.0, + ... growth_factor=2.0, + ... backoff_factor=0.5, + ... growth_interval=2000 + ... 
) + >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale) + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + self._scale_update_param = None + if loss_scale == 'dynamic': + self.loss_scaler = GradScaler() + elif isinstance(loss_scale, float): + self._scale_update_param = loss_scale + self.loss_scaler = GradScaler(init_scale=loss_scale) + elif isinstance(loss_scale, dict): + self.loss_scaler = GradScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training.""" + # wrap model mode to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer to + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler. + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients. + 3. Unscale the optimizer’s gradient tensors. + 4. Call optimizer.step() and update scale factor. + 5. Save loss_scaler state_dict for resume purpose. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + + self.loss_scaler.scale(runner.outputs['loss']).backward() + self.loss_scaler.unscale_(runner.optimizer) + # grad clip + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using PyTorch's implementation) implements + multi-iters gradient cumulating. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. 
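+
+        Example:
+            An illustrative combination; ``cumulative_iters=4`` and the
+            static ``loss_scale`` of 512 are assumptions, not required
+            settings.
+
+            >>> optim_hook = GradientCumulativeFp16OptimizerHook(
+            ...     cumulative_iters=4, loss_scale=512.)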
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + + self.loss_scaler.scale(loss).backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + self.loss_scaler.unscale_(runner.optimizer) + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() + +else: + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): # type: ignore + """FP16 optimizer hook (dbnet_cv's implementation). + + The steps of fp16 optimizer is as follows. + 1. Scale the loss value. + 2. BP in the fp16 model. + 2. Copy gradients from fp16 model to fp32 weights. + 3. Update fp32 weights. + 4. Copy updated parameters from fp32 weights to fp16 model. + + Refer to https://arxiv.org/abs/1710.03740 for more details. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of LossScaler. + Defaults to 512. + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + if loss_scale == 'dynamic': + self.loss_scaler = LossScaler(mode='dynamic') + elif isinstance(loss_scale, float): + self.loss_scaler = LossScaler( + init_scale=loss_scale, mode='static') + elif isinstance(loss_scale, dict): + self.loss_scaler = LossScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training. + + 1. Make a master copy of fp32 weights for optimization. + 2. Convert the main model from fp32 to fp16. 
+ """ + # keep a copy of fp32 weights + old_groups = runner.optimizer.param_groups + runner.optimizer.param_groups = copy.deepcopy( + runner.optimizer.param_groups) + state = defaultdict(dict) + p_map = { + old_p: p + for old_p, p in zip( + chain(*(g['params'] for g in old_groups)), + chain(*(g['params'] + for g in runner.optimizer.param_groups))) + } + for k, v in runner.optimizer.state.items(): + state[p_map[k]] = v + runner.optimizer.state = state + # convert model to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer `loss_scalar.py` + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients (fp16). + 3. Copy gradients from the model to the fp32 weight copy. + 4. Scale the gradients back and update the fp32 weight copy. + 5. Copy back the params from fp32 weight copy to the fp16 model. + 6. Save loss_scaler state_dict for resume purpose. 
+ """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + # scale the loss value + scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale + scaled_loss.backward() + # copy fp16 grads in the model to fp32 params in the optimizer + + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + self.loss_scaler.update_scale(has_overflow) + if has_overflow: + runner.logger.warning('Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook( # type: ignore + GradientCumulativeOptimizerHook, Fp16OptimizerHook): + """Fp16 optimizer Hook (using dbnet_cv implementation) implements multi- + iters gradient cumulating.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + + loss = runner.outputs['loss'] + loss = loss / loss_factor + + # scale the loss value + scaled_loss = loss * self.loss_scaler.loss_scale + scaled_loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + else: + runner.logger.warning( + 'Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + self.loss_scaler.update_scale(has_overflow) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = 
self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/profiler.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/profiler.py new file mode 100755 index 00000000..fef9adc1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/profiler.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Callable, List, Optional, Union + +import torch + +from ..dist_utils import master_only +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class ProfilerHook(Hook): + """Profiler to analyze performance during training. + + PyTorch Profiler is a tool that allows the collection of the performance + metrics during the training. More details on Profiler can be found at + https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile + + Args: + by_epoch (bool): Profile performance by epoch or by iteration. + Default: True. + profile_iters (int): Number of iterations for profiling. + If ``by_epoch=True``, profile_iters indicates that they are the + first profile_iters epochs at the beginning of the + training, otherwise it indicates the first profile_iters + iterations. Default: 1. + activities (list[str]): List of activity groups (CPU, CUDA) to use in + profiling. Default: ['cpu', 'cuda']. + schedule (dict, optional): Config of generating the callable schedule. + if schedule is None, profiler will not add step markers into the + trace and table view. Default: None. + on_trace_ready (callable, dict): Either a handler or a dict of generate + handler. Default: None. + record_shapes (bool): Save information about operator's input shapes. + Default: False. + profile_memory (bool): Track tensor memory allocation/deallocation. + Default: False. + with_stack (bool): Record source information (file and line number) + for the ops. Default: False. + with_flops (bool): Use formula to estimate the FLOPS of specific + operators (matrix multiplication and 2D convolution). + Default: False. + json_trace_path (str, optional): Exports the collected trace in Chrome + JSON format. Default: None. + + Example: + >>> runner = ... # instantiate a Runner + >>> # tensorboard trace + >>> trace_config = dict(type='tb_trace', dir_name='work_dir') + >>> profiler_config = dict(on_trace_ready=trace_config) + >>> runner.register_profiler_hook(profiler_config) + >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)]) + """ + + def __init__(self, + by_epoch: bool = True, + profile_iters: int = 1, + activities: List[str] = ['cpu', 'cuda'], + schedule: Optional[dict] = None, + on_trace_ready: Optional[Union[Callable, dict]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + json_trace_path: Optional[str] = None) -> None: + try: + from torch import profiler # torch version >= 1.8.1 + except ImportError: + raise ImportError('profiler is the new feature of torch1.8.1, ' + f'but your version is {torch.__version__}') + + assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.' 
+ self.by_epoch = by_epoch + + if profile_iters < 1: + raise ValueError('profile_iters should be greater than 0, but got ' + f'{profile_iters}') + self.profile_iters = profile_iters + + if not isinstance(activities, list): + raise ValueError( + f'activities should be list, but got {type(activities)}') + self.activities = [] + for activity in activities: + activity = activity.lower() + if activity == 'cpu': + self.activities.append(profiler.ProfilerActivity.CPU) + elif activity == 'cuda': + self.activities.append(profiler.ProfilerActivity.CUDA) + else: + raise ValueError( + f'activity should be "cpu" or "cuda", but got {activity}') + + if schedule is not None: + self.schedule = profiler.schedule(**schedule) + else: + self.schedule = None + + self.on_trace_ready = on_trace_ready + self.record_shapes = record_shapes + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_flops = with_flops + self.json_trace_path = json_trace_path + + @master_only + def before_run(self, runner): + if self.by_epoch and runner.max_epochs < self.profile_iters: + raise ValueError('self.profile_iters should not be greater than ' + f'{runner.max_epochs}') + + if not self.by_epoch and runner.max_iters < self.profile_iters: + raise ValueError('self.profile_iters should not be greater than ' + f'{runner.max_iters}') + + if callable(self.on_trace_ready): # handler + _on_trace_ready = self.on_trace_ready + elif isinstance(self.on_trace_ready, dict): # config of handler + trace_cfg = self.on_trace_ready.copy() + trace_type = trace_cfg.pop('type') # log_trace handler + if trace_type == 'log_trace': + + def _log_handler(prof): + print(prof.key_averages().table(**trace_cfg)) + + _on_trace_ready = _log_handler + elif trace_type == 'tb_trace': # tensorboard_trace handler + try: + import torch_tb_profiler # noqa: F401 + except ImportError: + raise ImportError('please run "pip install ' + 'torch-tb-profiler" to install ' + 'torch_tb_profiler') + _on_trace_ready = torch.profiler.tensorboard_trace_handler( + **trace_cfg) + else: + raise ValueError('trace_type should be "log_trace" or ' + f'"tb_trace", but got {trace_type}') + elif self.on_trace_ready is None: + _on_trace_ready = None # type: ignore + else: + raise ValueError('on_trace_ready should be handler, dict or None, ' + f'but got {type(self.on_trace_ready)}') + + if self.by_epoch and runner.max_epochs > 1: + warnings.warn(f'profiler will profile {runner.max_epochs} epochs ' + 'instead of 1 epoch. Since profiler will slow down ' + 'the training, it is recommended to train 1 epoch ' + 'with ProfilerHook and adjust your setting according' + ' to the profiler summary. 
During normal training ' + '(epoch > 1), you may disable the ProfilerHook.') + + self.profiler = torch.profiler.profile( + activities=self.activities, + schedule=self.schedule, + on_trace_ready=_on_trace_ready, + record_shapes=self.record_shapes, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + with_flops=self.with_flops) + + self.profiler.__enter__() + runner.logger.info('profiler is profiling...') + + @master_only + def after_train_epoch(self, runner): + if self.by_epoch and runner.epoch == self.profile_iters - 1: + runner.logger.info('profiler may take a few minutes...') + self.profiler.__exit__(None, None, None) + if self.json_trace_path is not None: + self.profiler.export_chrome_trace(self.json_trace_path) + + @master_only + def after_train_iter(self, runner): + self.profiler.step() + if not self.by_epoch and runner.iter == self.profile_iters - 1: + runner.logger.info('profiler may take a few minutes...') + self.profiler.__exit__(None, None, None) + if self.json_trace_path is not None: + self.profiler.export_chrome_trace(self.json_trace_path) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sampler_seed.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sampler_seed.py new file mode 100755 index 00000000..ee0dc6bd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sampler_seed.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class DistSamplerSeedHook(Hook): + """Data-loading sampler for distributed training. + + When distributed training, it is only useful in conjunction with + :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same + purpose with :obj:`IterLoader`. + """ + + def before_epoch(self, runner): + if hasattr(runner.data_loader.sampler, 'set_epoch'): + # in case the data loader uses `SequentialSampler` in Pytorch + runner.data_loader.sampler.set_epoch(runner.epoch) + elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'): + # batch sampler in pytorch warps the sampler as its attributes. + runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sync_buffer.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sync_buffer.py new file mode 100755 index 00000000..6376b7ff --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/hooks/sync_buffer.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..dist_utils import allreduce_params +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class SyncBuffersHook(Hook): + """Synchronize model buffers such as running_mean and running_var in BN at + the end of each epoch. + + Args: + distributed (bool): Whether distributed training is used. It is + effective only for distributed training. Defaults to True. + """ + + def __init__(self, distributed=True): + self.distributed = distributed + + def after_epoch(self, runner): + """All-reduce model buffers at the end of each epoch.""" + if self.distributed: + allreduce_params(runner.model.buffers()) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/iter_based_runner.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/iter_based_runner.py new file mode 100755 index 00000000..851ad37e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/iter_based_runner.py @@ -0,0 +1,277 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +import platform +import shutil +import time +import warnings + +import torch +from torch.optim import Optimizer + +import dbnet_cv +from .base_runner import BaseRunner +from .builder import RUNNERS +from .checkpoint import save_checkpoint +from .hooks import IterTimerHook +from .utils import get_host_info + + +class IterLoader: + + def __init__(self, dataloader): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._epoch = 0 + + @property + def epoch(self): + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, 'set_epoch'): + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __len__(self): + return len(self._dataloader) + + +@RUNNERS.register_module() +class IterBasedRunner(BaseRunner): + """Iteration-based Runner. + + This runner train models iteration by iteration. + """ + + def train(self, data_loader, **kwargs): + self.model.train() + self.mode = 'train' + self.data_loader = data_loader + self._epoch = data_loader.epoch + data_batch = next(data_loader) + self.data_batch = data_batch + self.call_hook('before_train_iter') + outputs = self.model.train_step(data_batch, self.optimizer, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('model.train_step() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_train_iter') + del self.data_batch + self._inner_iter += 1 + self._iter += 1 + + @torch.no_grad() + def val(self, data_loader, **kwargs): + self.model.eval() + self.mode = 'val' + self.data_loader = data_loader + data_batch = next(data_loader) + self.data_batch = data_batch + self.call_hook('before_val_iter') + outputs = self.model.val_step(data_batch, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('model.val_step() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_val_iter') + del self.data_batch + self._inner_iter += 1 + + def run(self, data_loaders, workflow, max_iters=None, **kwargs): + """Start running. + + Args: + data_loaders (list[:obj:`DataLoader`]): Dataloaders for training + and validation. + workflow (list[tuple]): A list of (phase, iters) to specify the + running order and iterations. E.g, [('train', 10000), + ('val', 1000)] means running 10000 iterations for training and + 1000 iterations for validation, iteratively. 
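+
+        Example:
+            A minimal sketch; ``runner``, ``train_loader`` and
+            ``val_loader`` are assumed to already exist.
+
+            >>> runner.run([train_loader, val_loader],
+            ...            [('train', 1000), ('val', 50)])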
+ """ + assert isinstance(data_loaders, list) + assert dbnet_cv.is_list_of(workflow, tuple) + assert len(data_loaders) == len(workflow) + if max_iters is not None: + warnings.warn( + 'setting max_iters in run is deprecated, ' + 'please set max_iters in runner_config', DeprecationWarning) + self._max_iters = max_iters + assert self._max_iters is not None, ( + 'max_iters must be specified during instantiation') + + work_dir = self.work_dir if self.work_dir is not None else 'NONE' + self.logger.info('Start running, host: %s, work_dir: %s', + get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) + self.logger.info('workflow: %s, max: %d iters', workflow, + self._max_iters) + self.call_hook('before_run') + + iter_loaders = [IterLoader(x) for x in data_loaders] + + self.call_hook('before_epoch') + + while self.iter < self._max_iters: + for i, flow in enumerate(workflow): + self._inner_iter = 0 + mode, iters = flow + if not isinstance(mode, str) or not hasattr(self, mode): + raise ValueError( + 'runner has no method named "{}" to run a workflow'. + format(mode)) + iter_runner = getattr(self, mode) + for _ in range(iters): + if mode == 'train' and self.iter >= self._max_iters: + break + iter_runner(iter_loaders[i], **kwargs) + + time.sleep(1) # wait for some hooks like loggers to finish + self.call_hook('after_epoch') + self.call_hook('after_run') + + def resume(self, + checkpoint, + resume_optimizer=True, + map_location='default'): + """Resume model from checkpoint. + + Args: + checkpoint (str): Checkpoint to resume from. + resume_optimizer (bool, optional): Whether resume the optimizer(s) + if the checkpoint file includes optimizer(s). Default to True. + map_location (str, optional): Same as :func:`torch.load`. + Default to 'default'. + """ + if map_location == 'default': + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + self._inner_iter = checkpoint['meta']['iter'] + if 'optimizer' in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict( + checkpoint['optimizer'][k]) + else: + raise TypeError( + 'Optimizer should be dict or torch.optim.Optimizer ' + f'but got {type(self.optimizer)}') + + self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') + + def save_checkpoint(self, + out_dir, + filename_tmpl='iter_{}.pth', + meta=None, + save_optimizer=True, + create_symlink=True): + """Save checkpoint to file. + + Args: + out_dir (str): Directory to save checkpoint files. + filename_tmpl (str, optional): Checkpoint file template. + Defaults to 'iter_{}.pth'. + meta (dict, optional): Metadata to be saved in checkpoint. + Defaults to None. + save_optimizer (bool, optional): Whether save optimizer. + Defaults to True. + create_symlink (bool, optional): Whether create symlink to the + latest checkpoint file. Defaults to True. 
+ """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + if self.meta is not None: + meta.update(self.meta) + # Note: meta.update(self.meta) should be done before + # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise + # there will be problems with resumed checkpoints. + # More details in https://github.com/open-mmlab/dbnet_cv/pull/1108 + meta.update(epoch=self.epoch + 1, iter=self.iter) + + filename = filename_tmpl.format(self.iter + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, 'latest.pth') + if platform.system() != 'Windows': + dbnet_cv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) + + def register_training_hooks(self, + lr_config, + optimizer_config=None, + checkpoint_config=None, + log_config=None, + momentum_config=None, + custom_hooks_config=None): + """Register default hooks for iter-based training. + + Checkpoint hook, optimizer stepper hook and logger hooks will be set to + `by_epoch=False` by default. + + Default hooks include: + + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | LrUpdaterHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | MomentumUpdaterHook | HIGH (30) | + +----------------------+-------------------------+ + | OptimizerStepperHook | ABOVE_NORMAL (40) | + +----------------------+-------------------------+ + | CheckpointSaverHook | NORMAL (50) | + +----------------------+-------------------------+ + | IterTimerHook | LOW (70) | + +----------------------+-------------------------+ + | LoggerHook(s) | VERY_LOW (90) | + +----------------------+-------------------------+ + | CustomHook(s) | defaults to NORMAL (50) | + +----------------------+-------------------------+ + + If custom hooks have same priority with default hooks, custom hooks + will be triggered after default hooks. + """ + if checkpoint_config is not None: + checkpoint_config.setdefault('by_epoch', False) + if lr_config is not None: + lr_config.setdefault('by_epoch', False) + if log_config is not None: + for info in log_config['hooks']: + info.setdefault('by_epoch', False) + super().register_training_hooks( + lr_config=lr_config, + momentum_config=momentum_config, + optimizer_config=optimizer_config, + checkpoint_config=checkpoint_config, + log_config=log_config, + timer_config=IterTimerHook(), + custom_hooks_config=custom_hooks_config) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/log_buffer.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/log_buffer.py new file mode 100755 index 00000000..3c9f3796 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/log_buffer.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from collections import OrderedDict + +import numpy as np + + +class LogBuffer: + + def __init__(self): + self.val_history = OrderedDict() + self.n_history = OrderedDict() + self.output = OrderedDict() + self.ready = False + + def clear(self) -> None: + self.val_history.clear() + self.n_history.clear() + self.clear_output() + + def clear_output(self) -> None: + self.output.clear() + self.ready = False + + def update(self, vars: dict, count: int = 1) -> None: + assert isinstance(vars, dict) + for key, var in vars.items(): + if key not in self.val_history: + self.val_history[key] = [] + self.n_history[key] = [] + self.val_history[key].append(var) + self.n_history[key].append(count) + + def average(self, n: int = 0) -> None: + """Average latest n values or all values.""" + assert n >= 0 + for key in self.val_history: + values = np.array(self.val_history[key][-n:]) + nums = np.array(self.n_history[key][-n:]) + avg = np.sum(values * nums) / np.sum(nums) + self.output[key] = avg + self.ready = True diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/__init__.py new file mode 100755 index 00000000..53c34d04 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer, + build_optimizer_constructor) +from .default_constructor import DefaultOptimizerConstructor + +__all__ = [ + 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor', + 'build_optimizer', 'build_optimizer_constructor' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/builder.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/builder.py new file mode 100755 index 00000000..49d8f05a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/builder.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect +from typing import Dict, List + +import torch + +from ...utils import Registry, build_from_cfg + +OPTIMIZERS = Registry('optimizer') +OPTIMIZER_BUILDERS = Registry('optimizer builder') + + +def register_torch_optimizers() -> List: + torch_optimizers = [] + for module_name in dir(torch.optim): + if module_name.startswith('__'): + continue + _optim = getattr(torch.optim, module_name) + if inspect.isclass(_optim) and issubclass(_optim, + torch.optim.Optimizer): + OPTIMIZERS.register_module()(_optim) + torch_optimizers.append(module_name) + return torch_optimizers + + +TORCH_OPTIMIZERS = register_torch_optimizers() + + +def build_optimizer_constructor(cfg: Dict): + return build_from_cfg(cfg, OPTIMIZER_BUILDERS) + + +def build_optimizer(model, cfg: Dict): + optimizer_cfg = copy.deepcopy(cfg) + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/default_constructor.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/default_constructor.py new file mode 100755 index 00000000..93fd0bde --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/optimizer/default_constructor.py @@ -0,0 +1,258 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.nn import GroupNorm, LayerNorm + +from dbnet_cv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of +from dbnet_cv.utils.ext_loader import check_ops_exist +from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS + + +@OPTIMIZER_BUILDERS.register_module() +class DefaultOptimizerConstructor: + """Default constructor for optimizers. + + By default each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain the following fields: + + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will + be ignored. It should be noted that the aforementioned ``key`` is the + longest key that is a substring of the name of the parameter. If there + are multiple matched keys with the same length, then the key with lower + alphabet order will be chosen. + ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` + and ``decay_mult``. See Example 2 below. + - ``bias_lr_mult`` (float): It will be multiplied to the learning + rate for all bias parameters (except for those in normalization + layers and offset layers of DCN). + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in + normalization layers, depthwise conv layers, offset layers of DCN). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization + layers. + - ``dwconv_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of depthwise conv + layers. + - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning + rate for parameters of offset layer in the deformable convs + of a model. + - ``bypass_duplicate`` (bool): If true, the duplicate parameters + would not be added into optimizer. Default: False. + + Note: + + 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will + override the effect of ``bias_lr_mult`` in the bias of offset layer. + So be careful when using both ``bias_lr_mult`` and + ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset + layer in deformable convs, set ``dcn_offset_lr_mult`` to the original + ``dcn_offset_lr_mult`` * ``bias_lr_mult``. + + 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will + apply it to all the DCN layers in the model. So be careful when the + model contains multiple DCN layers in places other than backbone. + + Args: + model (:obj:`nn.Module`): The model with parameters to be optimized. + optimizer_cfg (dict): The config dict of the optimizer. + Positional fields are + + - `type`: class name of the optimizer. + + Optional fields are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + paramwise_cfg (dict, optional): Parameter-wise options. + + Example 1: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, + >>> weight_decay=0.0001) + >>> paramwise_cfg = dict(norm_decay_mult=0.) 
+ >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + + Example 2: + >>> # assume model have attribute model.backbone and model.cls_head + >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95) + >>> paramwise_cfg = dict(custom_keys={ + 'backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + >>> # Then the `lr` and `weight_decay` for model.backbone is + >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for + >>> # model.cls_head is (0.01, 0.95). + """ + + def __init__(self, + optimizer_cfg: Dict, + paramwise_cfg: Optional[Dict] = None): + if not isinstance(optimizer_cfg, dict): + raise TypeError('optimizer_cfg should be a dict', + f'but got {type(optimizer_cfg)}') + self.optimizer_cfg = optimizer_cfg + self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg + self.base_lr = optimizer_cfg.get('lr', None) + self.base_wd = optimizer_cfg.get('weight_decay', None) + self._validate_cfg() + + def _validate_cfg(self) -> None: + if not isinstance(self.paramwise_cfg, dict): + raise TypeError('paramwise_cfg should be None or a dict, ' + f'but got {type(self.paramwise_cfg)}') + + if 'custom_keys' in self.paramwise_cfg: + if not isinstance(self.paramwise_cfg['custom_keys'], dict): + raise TypeError( + 'If specified, custom_keys must be a dict, ' + f'but got {type(self.paramwise_cfg["custom_keys"])}') + if self.base_wd is None: + for key in self.paramwise_cfg['custom_keys']: + if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]: + raise ValueError('base_wd should not be None') + + # get base lr and weight decay + # weight_decay must be explicitly specified if mult is specified + if ('bias_decay_mult' in self.paramwise_cfg + or 'norm_decay_mult' in self.paramwise_cfg + or 'dwconv_decay_mult' in self.paramwise_cfg): + if self.base_wd is None: + raise ValueError('base_wd should not be None') + + def _is_in(self, param_group: Dict, param_group_list: List) -> bool: + assert is_list_of(param_group_list, dict) + param = set(param_group['params']) + param_set = set() + for group in param_group_list: + param_set.update(set(group['params'])) + + return not param.isdisjoint(param_set) + + def add_params(self, + params: List[Dict], + module: nn.Module, + prefix: str = '', + is_dcn_module: Union[int, float, None] = None) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.) + dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.) 
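+        # Illustrative effect of these multipliers (values assumed, not taken
+        # from the configs in this patch): with lr=0.01, weight_decay=1e-4 and
+        # paramwise_cfg=dict(bias_lr_mult=2., norm_decay_mult=0.), a plain
+        # conv bias ends up with lr=0.02, while normalization parameters keep
+        # lr=0.01 but get weight_decay=0.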
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + if not param.requires_grad: + params.append(param_group) + continue + if bypass_duplicate and self._is_in(param_group, params): + warnings.warn(f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}') + continue + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f'{prefix}.{name}': + is_custom = True + lr_mult = custom_keys[key].get('lr_mult', 1.) + param_group['lr'] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get('decay_mult', 1.) + param_group['weight_decay'] = self.base_wd * decay_mult + break + + if not is_custom: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == 'bias' and not (is_norm or is_dcn_module): + param_group['lr'] = self.base_lr * bias_lr_mult + + if (prefix.find('conv_offset') != -1 and is_dcn_module + and isinstance(module, torch.nn.Conv2d)): + # deal with both dcn_offset's bias & weight + param_group['lr'] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm: + param_group[ + 'weight_decay'] = self.base_wd * norm_decay_mult + # depth-wise conv + elif is_dwconv: + param_group[ + 'weight_decay'] = self.base_wd * dwconv_decay_mult + # bias lr and decay + elif name == 'bias' and not is_dcn_module: + # TODO: current bias_decay_mult will have affect on DCN + param_group[ + 'weight_decay'] = self.base_wd * bias_decay_mult + params.append(param_group) + + if check_ops_exist(): + from dbnet_cv.ops import DeformConv2d, ModulatedDeformConv2d + is_dcn_module = isinstance(module, + (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) + + def __call__(self, model: nn.Module): + if hasattr(model, 'module'): + model = model.module + + optimizer_cfg = self.optimizer_cfg.copy() + # if no paramwise option is specified, just use the global setting + if not self.paramwise_cfg: + optimizer_cfg['params'] = model.parameters() + return build_from_cfg(optimizer_cfg, OPTIMIZERS) + + # set param-wise lr and weight decay recursively + params: List[Dict] = [] + self.add_params(params, model) + optimizer_cfg['params'] = params + + return build_from_cfg(optimizer_cfg, OPTIMIZERS) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/priority.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/priority.py new file mode 100755 index 00000000..ff644043 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/priority.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from enum import Enum +from typing import Union + + +class Priority(Enum): + """Hook priority levels. 
+ + +--------------+------------+ + | Level | Value | + +==============+============+ + | HIGHEST | 0 | + +--------------+------------+ + | VERY_HIGH | 10 | + +--------------+------------+ + | HIGH | 30 | + +--------------+------------+ + | ABOVE_NORMAL | 40 | + +--------------+------------+ + | NORMAL | 50 | + +--------------+------------+ + | BELOW_NORMAL | 60 | + +--------------+------------+ + | LOW | 70 | + +--------------+------------+ + | VERY_LOW | 90 | + +--------------+------------+ + | LOWEST | 100 | + +--------------+------------+ + """ + + HIGHEST = 0 + VERY_HIGH = 10 + HIGH = 30 + ABOVE_NORMAL = 40 + NORMAL = 50 + BELOW_NORMAL = 60 + LOW = 70 + VERY_LOW = 90 + LOWEST = 100 + + +def get_priority(priority: Union[int, str, Priority]) -> int: + """Get priority value. + + Args: + priority (int or str or :obj:`Priority`): Priority. + + Returns: + int: The priority value. + """ + if isinstance(priority, int): + if priority < 0 or priority > 100: + raise ValueError('priority must be between 0 and 100') + return priority + elif isinstance(priority, Priority): + return priority.value + elif isinstance(priority, str): + return Priority[priority.upper()].value + else: + raise TypeError('priority must be an integer or Priority enum value') diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/runner/utils.py b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/utils.py new file mode 100755 index 00000000..3aab1f5a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/runner/utils.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import random +import sys +import time +import warnings +from getpass import getuser +from socket import gethostname +from types import ModuleType +from typing import Optional + +import numpy as np +import torch + +import dbnet_cv + + +def get_host_info(): + """Get hostname and username. + + Return empty string if exception raised, e.g. ``getpass.getuser()`` will + lead to error in docker container + """ + host = '' + try: + host = f'{getuser()}@{gethostname()}' + except Exception as e: + warnings.warn(f'Host or user not found: {str(e)}') + finally: + return host + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def obj_from_dict(info: dict, + parent: Optional[ModuleType] = None, + default_args: Optional[dict] = None): + """Initialize an object from dict. + + The dict must contain the key "type", which indicates the object type, it + can be either a string or type, such as "list" or ``list``. Remaining + fields are treated as the arguments for constructing the object. + + Args: + info (dict): Object types and arguments. + parent (:class:`module`): Module which may containing expected object + classes. + default_args (dict, optional): Default arguments for initializing the + object. + + Returns: + any type: Object built from the dict. + """ + assert isinstance(info, dict) and 'type' in info + assert isinstance(default_args, dict) or default_args is None + args = info.copy() + obj_type = args.pop('type') + if dbnet_cv.is_str(obj_type): + if parent is not None: + obj_type = getattr(parent, obj_type) + else: + obj_type = sys.modules[obj_type] + elif not isinstance(obj_type, type): + raise TypeError('type must be a str or valid type, but ' + f'got {type(obj_type)}') + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + return obj_type(**args) + + +def set_random_seed(seed: int, + deterministic: bool = False, + use_rank_shift: bool = False) -> None: + """Set random seed. 
+ + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + rank_shift (bool): Whether to add rank number to the random seed to + have different random seed in different threads. Default: False. + """ + if use_rank_shift: + rank, _ = dbnet_cv.runner.get_dist_info() + seed += rank + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/__init__.py new file mode 100755 index 00000000..e0825ed6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/__init__.py @@ -0,0 +1,78 @@ +# flake8: noqa +# Copyright (c) OpenMMLab. All rights reserved. +from .config import Config, ConfigDict, DictAction +from .misc import (check_prerequisites, concat_list, deprecated_api_warning, + has_method, import_modules_from_strings, is_list_of, + is_method_overridden, is_seq_of, is_str, is_tuple_of, + iter_cast, list_cast, requires_executable, requires_package, + slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, + to_ntuple, tuple_cast) +from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, + scandir, symlink) +from .progressbar import (ProgressBar, track_iter_progress, + track_parallel_progress, track_progress) +from .testing import (assert_attrs_equal, assert_dict_contains_subset, + assert_dict_has_keys, assert_is_norm_layer, + assert_keys_equal, assert_params_all_zeros, + check_python_script) +from .timer import Timer, TimerError, check_time +from .version_utils import digit_version, get_git_hash + +try: + import torch +except ImportError: + __all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast', + 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of', + 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package', + 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist', + 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', + 'track_progress', 'track_iter_progress', 'track_parallel_progress', + 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', + 'digit_version', 'get_git_hash', 'import_modules_from_strings', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', + 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', + 'is_method_overridden', 'has_method' + ] +else: + from .device_type import IS_IPU_AVAILABLE, IS_MLU_AVAILABLE + from .env import collect_env + from .hub import load_url + from .logging import get_logger, print_log + from .parrots_jit import jit, skip_no_elena + # yapf: disable + from .parrots_wrapper import (IS_CUDA_AVAILABLE, TORCH_VERSION, + BuildExtension, CppExtension, CUDAExtension, + DataLoader, PoolDataLoader, SyncBatchNorm, + _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, + _AvgPoolNd, _BatchNorm, _ConvNd, + _ConvTransposeMixin, _get_cuda_home, + _InstanceNorm, _MaxPoolNd, get_build_config, + is_rocm_pytorch) + # yapf: enable + from .registry import Registry, build_from_cfg + from .seed import worker_init_fn + from .trace import is_jit_tracing + __all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'collect_env', 
'get_logger', + 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', + 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list', + 'check_prerequisites', 'requires_package', 'requires_executable', + 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', + 'symlink', 'scandir', 'ProgressBar', 'track_progress', + 'track_iter_progress', 'track_parallel_progress', 'Registry', + 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm', + '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm', + '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd', + 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension', + 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION', + 'deprecated_api_warning', 'digit_version', 'get_git_hash', + 'import_modules_from_strings', 'jit', 'skip_no_elena', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', + 'assert_params_all_zeros', 'check_python_script', + 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', + '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE', + 'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE' + ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/config.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/config.py new file mode 100755 index 00000000..1f018a95 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/config.py @@ -0,0 +1,741 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import types +import uuid +import warnings +from argparse import Action, ArgumentParser +from collections import abc +from importlib import import_module +from pathlib import Path + +from addict import Dict +from yapf.yapflib.yapf_api import FormatCode + +from .misc import import_modules_from_strings +from .path import check_file_exist + +if platform.system() == 'Windows': + import regex as re # type: ignore +else: + import re # type: ignore + +BASE_KEY = '_base_' +DELETE_KEY = '_delete_' +DEPRECATION_KEY = '_deprecation_' +RESERVED_KEYS = ['filename', 'text', 'pretty_text'] + + +class ConfigDict(Dict): + + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super().__getattr__(name) + except KeyError: + ex = AttributeError(f"'{self.__class__.__name__}' object has no " + f"attribute '{name}'") + except Exception as e: + ex = e + else: + return value + raise ex + + +def add_args(parser, cfg, prefix=''): + for k, v in cfg.items(): + if isinstance(v, str): + parser.add_argument('--' + prefix + k) + elif isinstance(v, int): + parser.add_argument('--' + prefix + k, type=int) + elif isinstance(v, float): + parser.add_argument('--' + prefix + k, type=float) + elif isinstance(v, bool): + parser.add_argument('--' + prefix + k, action='store_true') + elif isinstance(v, dict): + add_args(parser, v, prefix + k + '.') + elif isinstance(v, abc.Iterable): + parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') + else: + print(f'cannot parse key {prefix + k} of type {type(v)}') + return parser + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. 
+ + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/dbnet_cv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/dbnet_cv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ + + @staticmethod + def _validate_py_syntax(filename): + with open(filename, encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError('There are syntax errors in config ' + f'file {filename}: {e}') + + @staticmethod + def _substitute_predefined_vars(filename, temp_config_name): + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname) + with open(filename, encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + for key, value in support_templates.items(): + regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' + value = value.replace('\\', '/') + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' + base_var_dict[randstr] = base_var + regexp = r'\{\{\s*' + BASE_KEY + r'\.' 
+ base_var + r'\s*\}\}' + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars(cfg, base_var_dict, base_cfg): + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split('.'): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars( + v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg) + elif isinstance(cfg, list): + cfg = [ + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg + ] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split('.'): + new_v = new_v[new_k] + cfg = new_v + + return cfg + + @staticmethod + def _file2dict(filename, use_predefined_variables=True): + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in ['.py', '.json', '.yaml', '.yml']: + raise OSError('Only py/yml/yaml/json type are supported now!') + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname) + if platform.system() == 'Windows': + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, + temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name) + + if filename.endswith('.py'): + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + Config._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value + for name, value in mod.__dict__.items() + if not name.startswith('__') + and not isinstance(value, types.ModuleType) + and not isinstance(value, types.FunctionType) + } + # delete imported module + del sys.modules[temp_module_name] + elif filename.endswith(('.yml', '.yaml', '.json')): + import dbnet_cv + cfg_dict = dbnet_cv.load(temp_config_file.name) + # close temp file + temp_config_file.close() + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = f'The config file {filename} will be deprecated ' \ + 'in the future.' + if 'expected' in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' \ + 'instead.' 
+ if 'reference' in deprecation_info: + warning_msg += ' More information can be found at ' \ + f'{deprecation_info["reference"]}' + warnings.warn(warning_msg, DeprecationWarning) + + cfg_text = filename + '\n' + with open(filename, encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = base_filename if isinstance( + base_filename, list) else [base_filename] + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError('Duplicate key is not allowed among bases. ' + f'Duplicate keys: {duplicate_keys}') + base_cfg_dict.update(c) + + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, + base_cfg_dict) + + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = '\n'.join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _merge_a_into_b(a, b, allow_list_keys=False): + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Default: False. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f'Index {k} exceeds the length of list {b}') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, dict): + if k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f'{k}={v} in child config cannot inherit from ' + f'base because {k} is a dict in the child config ' + f'but is of type {type(b[k])} in base config. 
' + f'You may set `{DELETE_KEY}=True` to ignore the ' + f'base config.') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = ConfigDict(v) + else: + b[k] = v + return b + + @staticmethod + def fromfile(filename, + use_predefined_variables=True, + import_custom_modules=True): + if isinstance(filename, Path): + filename = str(filename) + cfg_dict, cfg_text = Config._file2dict(filename, + use_predefined_variables) + if import_custom_modules and cfg_dict.get('custom_imports', None): + import_modules_from_strings(**cfg_dict['custom_imports']) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def fromstring(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + :obj:`Config`: Config obj. + """ + if file_format not in ['.py', '.json', '.yaml', '.yml']: + raise OSError('Only py/yml/yaml/json type are supported now!') + if file_format != '.py' and 'dict(' in cfg_str: + # check if users specify a wrong suffix for python + warnings.warn( + 'Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix=file_format, + delete=False) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.fromfile(temp_file.name) + os.remove(temp_file.name) + return cfg + + @staticmethod + def auto_argparser(description=None): + """Generate argparser from config file automatically (experimental)""" + partial_parser = ArgumentParser(description=description) + partial_parser.add_argument('config', help='config file path') + cfg_file = partial_parser.parse_known_args()[0].config + cfg = Config.fromfile(cfg_file) + parser = ArgumentParser(description=description) + parser.add_argument('config', help='config file path') + add_args(parser, cfg) + return parser, cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError('cfg_dict must be a dict, but ' + f'got {type(cfg_dict)}') + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f'{key} is reserved for config file') + + if isinstance(filename, Path): + filename = str(filename) + + super().__setattr__('_cfg_dict', ConfigDict(cfg_dict)) + super().__setattr__('_filename', filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename) as f: + text = f.read() + else: + text = '' + super().__setattr__('_text', text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = '[\n' + v_str += 
'\n'.join( + f'dict({_indent(_format_dict(v_), indent)}),' + for v_ in v).rstrip(',') + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + ']' + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' + else: + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __copy__(self): + cls = self.__class__ + other = cls.__new__(cls) + other.__dict__.update(self.__dict__) + + return other + + def __deepcopy__(self, memo): + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + + for key, value in self.__dict__.items(): + super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) + + return other + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super().__setattr__('_cfg_dict', _cfg_dict) + super().__setattr__('_filename', _filename) + super().__setattr__('_text', _text) + + def dump(self, file=None): + """Dumps config into a file or returns a string representation of the + config. + + If a file argument is given, saves the config to that file using the + format defined by the file argument extension. + + Otherwise, returns a string representing the config. The formatting of + this returned string is defined by the extension of `self.filename`. If + `self.filename` is not defined, returns a string representation of a + dict (lowercased and using ' for strings). 
+ + Examples: + >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0), + ... item3=True, item4='test') + >>> cfg = Config(cfg_dict=cfg_dict) + >>> dump_file = "a.py" + >>> cfg.dump(dump_file) + + Args: + file (str, optional): Path of the output file where the config + will be dumped. Defaults to None. + """ + import dbnet_cv + cfg_dict = super().__getattribute__('_cfg_dict').to_dict() + if file is None: + if self.filename is None or self.filename.endswith('.py'): + return self.pretty_text + else: + file_format = self.filename.split('.')[-1] + return dbnet_cv.dump(cfg_dict, file_format=file_format) + elif file.endswith('.py'): + with open(file, 'w', encoding='utf-8') as f: + f.write(self.pretty_text) + else: + file_format = file.split('.')[-1] + return dbnet_cv.dump(cfg_dict, file=file, file_format=file_format) + + def merge_from_dict(self, options, allow_list_keys=True): + """Merge list into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'model.backbone.depth': 50, + ... 'model.backbone.with_cp':True} + >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... model=dict(backbone=dict(depth=50, with_cp=True))) + + >>> # Merge list element + >>> cfg = Config(dict(pipeline=[ + ... dict(type='LoadImage'), dict(type='LoadAnnotations')])) + >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) + + Args: + options (dict): dict of configs to merge from. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in ``options`` and will replace the element of the + corresponding index in the config if the config is a list. + Default: True. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split('.') + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super().__getattribute__('_cfg_dict') + super().__setattr__( + '_cfg_dict', + Config._merge_a_into_b( + option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options can + be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit + brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build + list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ['true', 'false']: + return True if val.lower() == 'true' else False + if val == 'None': + return None + return val + + @staticmethod + def _parse_iterable(val): + """Parse iterable values in the string. + + All elements inside '()' or '[]' are treated as iterable values. + + Args: + val (str): Value string. + + Returns: + list | tuple: The expanded list or tuple from the string. 
+ + Examples: + >>> DictAction._parse_iterable('1,2,3') + [1, 2, 3] + >>> DictAction._parse_iterable('[a, b, c]') + ['a', 'b', 'c'] + >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') + [(1, 2, 3), ['a', 'b'], 'c'] + """ + + def find_next_comma(string): + """Find the position of next comma in the string. + + If no ',' is found in the string, return the string length. All + chars inside '()' and '[]' are treated as one element and thus ',' + inside these brackets are ignored. + """ + assert (string.count('(') == string.count(')')) and ( + string.count('[') == string.count(']')), \ + f'Imbalanced brackets exist in {string}' + end = len(string) + for idx, char in enumerate(string): + pre = string[:idx] + # The string before this ',' is balanced + if ((char == ',') and (pre.count('(') == pre.count(')')) + and (pre.count('[') == pre.count(']'))): + end = idx + break + return end + + # Strip ' and " characters and replace whitespace. + val = val.strip('\'\"').replace(' ', '') + is_tuple = False + if val.startswith('(') and val.endswith(')'): + is_tuple = True + val = val[1:-1] + elif val.startswith('[') and val.endswith(']'): + val = val[1:-1] + elif ',' not in val: + # val is a single value + return DictAction._parse_int_float_bool(val) + + values = [] + while len(val) > 0: + comma_idx = find_next_comma(val) + element = DictAction._parse_iterable(val[:comma_idx]) + values.append(element) + val = val[comma_idx + 1:] + if is_tuple: + values = tuple(values) + return values + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split('=', maxsplit=1) + options[key] = self._parse_iterable(val) + setattr(namespace, self.dest, options) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/device_type.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/device_type.py new file mode 100755 index 00000000..c66052c2 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/device_type.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. + + +def is_ipu_available() -> bool: + try: + import poptorch + return poptorch.ipuHardwareIsAvailable() + except ImportError: + return False + + +IS_IPU_AVAILABLE = is_ipu_available() + + +def is_mlu_available() -> bool: + try: + import torch + return (hasattr(torch, 'is_mlu_available') + and torch.is_mlu_available()) + except Exception: + return False + + +IS_MLU_AVAILABLE = is_mlu_available() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/env.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/env.py new file mode 100755 index 00000000..2d26da83 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/env.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file holding some environment constant for sharing by other files.""" + +import os.path as osp +import subprocess +import sys +from collections import defaultdict + +import cv2 +import torch + +import dbnet_cv +from .parrots_wrapper import get_build_config + + +def collect_env(): + """Collect the information of the running environments. + + Returns: + dict: The environment information. The following fields are contained. + + - sys.platform: The variable of ``sys.platform``. + - Python: Python version. + - CUDA available: Bool, indicating if CUDA is available. + - GPU devices: Device type of each GPU. + - CUDA_HOME (optional): The env var ``CUDA_HOME``. + - NVCC (optional): NVCC version. + - GCC: GCC version, "n/a" if GCC is not installed. + - MSVC: Microsoft Virtual C++ Compiler version, Windows only. + - PyTorch: PyTorch version. 
+ - PyTorch compiling details: The output of \ + ``torch.__config__.show()``. + - TorchVision (optional): TorchVision version. + - OpenCV: OpenCV version. + - DBNET_CV: DBNET_CV version. + - DBNET_CV Compiler: The GCC version for compiling DBNET_CV ops. + - DBNET_CV CUDA Compiler: The CUDA version for compiling DBNET_CV ops. + """ + env_info = {} + env_info['sys.platform'] = sys.platform + env_info['Python'] = sys.version.replace('\n', '') + + cuda_available = torch.cuda.is_available() + env_info['CUDA available'] = cuda_available + + if cuda_available: + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + for name, device_ids in devices.items(): + env_info['GPU ' + ','.join(device_ids)] = name + + from dbnet_cv.utils.parrots_wrapper import _get_cuda_home + CUDA_HOME = _get_cuda_home() + env_info['CUDA_HOME'] = CUDA_HOME + + if CUDA_HOME is not None and osp.isdir(CUDA_HOME): + try: + nvcc = osp.join(CUDA_HOME, 'bin/nvcc') + nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True) + nvcc = nvcc.decode('utf-8').strip() + release = nvcc.rfind('Cuda compilation tools') + build = nvcc.rfind('Build ') + nvcc = nvcc[release:build].strip() + except subprocess.SubprocessError: + nvcc = 'Not Available' + env_info['NVCC'] = nvcc + + try: + # Check C++ Compiler. + # For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...', + # indicating the compiler used, we use this to get the compiler name + import sysconfig + cc = sysconfig.get_config_var('CC') + if cc: + cc = osp.basename(cc.split()[0]) + cc_info = subprocess.check_output(f'{cc} --version', shell=True) + env_info['GCC'] = cc_info.decode('utf-8').partition( + '\n')[0].strip() + else: + # on Windows, cl.exe is not in PATH. We need to find the path. + # distutils.ccompiler.new_compiler() returns a msvccompiler + # object and after initialization, path to cl.exe is found. + import locale + import os + from distutils.ccompiler import new_compiler + ccompiler = new_compiler() + ccompiler.initialize() + cc = subprocess.check_output( + f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True) + encoding = os.device_encoding( + sys.stdout.fileno()) or locale.getpreferredencoding() + env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip() + env_info['GCC'] = 'n/a' + except subprocess.CalledProcessError: + env_info['GCC'] = 'n/a' + + env_info['PyTorch'] = torch.__version__ + env_info['PyTorch compiling details'] = get_build_config() + + try: + import torchvision + env_info['TorchVision'] = torchvision.__version__ + except ModuleNotFoundError: + pass + + env_info['OpenCV'] = cv2.__version__ + + env_info['DBNET_CV'] = dbnet_cv.__version__ + + try: + from dbnet_cv.ops import get_compiler_version, get_compiling_cuda_version + except ModuleNotFoundError: + env_info['DBNET_CV Compiler'] = 'n/a' + env_info['DBNET_CV CUDA Compiler'] = 'n/a' + else: + env_info['DBNET_CV Compiler'] = get_compiler_version() + env_info['DBNET_CV CUDA Compiler'] = get_compiling_cuda_version() + + return env_info diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/ext_loader.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/ext_loader.py new file mode 100755 index 00000000..58508798 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/ext_loader.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
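+# Usage sketch (the op names are taken from ``has_return_value_ops`` below and
+# are purely illustrative): compiled extensions are loaded lazily by callers
+# such as ``dbnet_cv.ops``, e.g.
+#
+#     ext_module = load_ext('_ext', ['nms', 'softnms'])
+#
+# after which each requested function is available as an attribute of
+# ``ext_module``.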
+import importlib +import os +import pkgutil +import warnings +from collections import namedtuple + +import torch + +if torch.__version__ != 'parrots': + + def load_ext(name, funcs): + ext = importlib.import_module('dbnet_cv.' + name) + for fun in funcs: + assert hasattr(ext, fun), f'{fun} miss in module {name}' + return ext +else: + from parrots import extension + from parrots.base import ParrotsException + + has_return_value_ops = [ + 'nms', + 'softnms', + 'nms_match', + 'nms_rotated', + 'top_pool_forward', + 'top_pool_backward', + 'bottom_pool_forward', + 'bottom_pool_backward', + 'left_pool_forward', + 'left_pool_backward', + 'right_pool_forward', + 'right_pool_backward', + 'fused_bias_leakyrelu', + 'upfirdn2d', + 'ms_deform_attn_forward', + 'pixel_group', + 'contour_expand', + 'diff_iou_rotated_sort_vertices_forward', + ] + + def get_fake_func(name, e): + + def fake_func(*args, **kwargs): + warnings.warn(f'{name} is not supported in parrots now') + raise e + + return fake_func + + def load_ext(name, funcs): + ExtModule = namedtuple('ExtModule', funcs) + ext_list = [] + lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + for fun in funcs: + try: + ext_fun = extension.load(fun, name, lib_dir=lib_root) + except ParrotsException as e: + if 'No element registered' not in e.message: + warnings.warn(e.message) + ext_fun = get_fake_func(fun, e) + ext_list.append(ext_fun) + else: + if fun in has_return_value_ops: + ext_list.append(ext_fun.op) + else: + ext_list.append(ext_fun.op_) + return ExtModule(*ext_list) + + +def check_ops_exist() -> bool: + ext_loader = pkgutil.find_loader('dbnet_cv._ext') + return ext_loader is not None diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/hub.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/hub.py new file mode 100755 index 00000000..a9cbbc95 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/hub.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# The 1.6 release of PyTorch switched torch.save to use a new zipfile-based +# file format. It will cause RuntimeError when a checkpoint was saved in +# torch >= 1.6.0 but loaded in torch < 1.7.0. +# More details at https://github.com/open-mmlab/mmpose/issues/904 +from .parrots_wrapper import TORCH_VERSION +from .path import mkdir_or_exist +from .version_utils import digit_version + +if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version( + '1.7.0'): + # Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py + import os + import sys + import warnings + import zipfile + from urllib.parse import urlparse + + import torch + from torch.hub import HASH_REGEX, _get_torch_home, download_url_to_file + + # Hub used to support automatically extracts from zipfile manually + # compressed by users. The legacy zip format expects only one file from + # torch.save() < 1.6 in the zip. We should remove this support since + # zipfile is now default zipfile format for torch.save(). + def _is_legacy_zip_format(filename): + if zipfile.is_zipfile(filename): + infolist = zipfile.ZipFile(filename).infolist() + return len(infolist) == 1 and not infolist[0].is_dir() + return False + + def _legacy_zip_load(filename, model_dir, map_location): + warnings.warn( + 'Falling back to the old format < 1.6. This support will' + ' be deprecated in favor of default zipfile format ' + 'introduced in 1.6. Please redo torch.save() to save it ' + 'in the new zipfile format.', DeprecationWarning) + # Note: extractall() defaults to overwrite file if exists. 
No need to + # clean up beforehand. We deliberately don't handle tarfile here + # since our legacy serialization format was in tar. + # E.g. resnet18-5c106cde.pth which is widely used. + with zipfile.ZipFile(filename) as f: + members = f.infolist() + if len(members) != 1: + raise RuntimeError( + 'Only one file(not dir) is allowed in the zipfile') + f.extractall(model_dir) + extraced_name = members[0].filename + extracted_file = os.path.join(model_dir, extraced_name) + return torch.load(extracted_file, map_location=map_location) + + def load_url(url, + model_dir=None, + map_location=None, + progress=True, + check_hash=False, + file_name=None): + r"""Loads the Torch serialized object at the given URL. + + If downloaded file is a zip file, it will be automatically decompressed + + If the object is already present in `model_dir`, it's deserialized and + returned. + The default value of ``model_dir`` is ``/checkpoints`` where + ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + url (str): URL of the object to download + model_dir (str, optional): directory in which to save the object + map_location (optional): a function or a dict specifying how to + remap storage locations (see torch.load) + progress (bool, optional): whether or not to display a progress bar + to stderr. Default: True + check_hash(bool, optional): If True, the filename part of the URL + should follow the naming convention ``filename-.ext`` + where ```` is the first eight or more digits of the + SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. + Default: False + file_name (str, optional): name for the downloaded file. Filename + from ``url`` will be used if not set. Default: None. + + Example: + >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106' + ... 'cde.pth') + >>> state_dict = torch.hub.load_state_dict_from_url(url) + """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + warnings.warn( + 'TORCH_MODEL_ZOO is deprecated, please use env ' + 'TORCH_HOME instead', DeprecationWarning) + + if model_dir is None: + torch_home = _get_torch_home() + model_dir = os.path.join(torch_home, 'checkpoints') + + mkdir_or_exist(model_dir) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format( + url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file( + url, cached_file, hash_prefix, progress=progress) + + if _is_legacy_zip_format(cached_file): + return _legacy_zip_load(cached_file, model_dir, map_location) + + try: + return torch.load(cached_file, map_location=map_location) + except RuntimeError as error: + if digit_version(TORCH_VERSION) < digit_version('1.5.0'): + warnings.warn( + f'If the error is the same as "{cached_file} is a zip ' + 'archive (did you mean to use torch.jit.load()?)", you can' + ' upgrade your torch to 1.5.0 or higher (current torch ' + f'version is {TORCH_VERSION}). 
The error was raised ' + ' because the checkpoint was saved in torch>=1.6.0 but ' + 'loaded in torch<1.5.') + raise error +else: + from torch.utils.model_zoo import load_url # type: ignore # noqa: F401 diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/logging.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/logging.py new file mode 100755 index 00000000..5a90aac8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/logging.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.distributed as dist + +logger_initialized: dict = {} + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) + # to the root logger. As logger.propagate is True by default, this root + # level handler causes logging messages from rank>0 processes to + # unexpectedly show up on the console, creating much unwanted clutter. + # To fix this issue, we set the root logger's StreamHandler, if any, to log + # at the ERROR level. + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + + - "silent": no message will be printed. + - other str: the logger obtained with `get_root_logger(logger)`. 
+ - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == 'silent': + pass + elif isinstance(logger, str): + _logger = get_logger(logger) + _logger.log(level, msg) + else: + raise TypeError( + 'logger should be either a logging.Logger object, str, ' + f'"silent" or None, but got {type(logger)}') diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/misc.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/misc.py new file mode 100755 index 00000000..7957ea89 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/misc.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections.abc +import functools +import itertools +import subprocess +import warnings +from collections import abc +from importlib import import_module +from inspect import getfullargspec +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def is_str(x): + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def import_modules_from_strings(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. + allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules_from_strings( + ... ['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError( + f'custom_imports must be a list but got type {type(imports)}') + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError( + f'{imp} is of type {type(imp)} and cannot be imported.') + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f'{imp} failed to import and is ignored.', + UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +def iter_cast(inputs, dst_type, return_type=None): + """Cast elements of an iterable object into some type. + + Args: + inputs (Iterable): The input object. + dst_type (type): Destination type. + return_type (type, optional): If specified, the output object will be + converted to this type, otherwise an iterator. + + Returns: + iterator or specified type: The converted object. 
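+
+    Examples:
+        >>> iter_cast(['1', '2', '3'], int, return_type=list)
+        [1, 2, 3]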
+ """ + if not isinstance(inputs, abc.Iterable): + raise TypeError('inputs must be an iterable object') + if not isinstance(dst_type, type): + raise TypeError('"dst_type" must be a valid type') + + out_iterable = map(dst_type, inputs) + + if return_type is None: + return out_iterable + else: + return return_type(out_iterable) + + +def list_cast(inputs, dst_type): + """Cast elements of an iterable object into a list of some type. + + A partial method of :func:`iter_cast`. + """ + return iter_cast(inputs, dst_type, return_type=list) + + +def tuple_cast(inputs, dst_type): + """Cast elements of an iterable object into a tuple of some type. + + A partial method of :func:`iter_cast`. + """ + return iter_cast(inputs, dst_type, return_type=tuple) + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_list_of(seq, expected_type): + """Check whether it is a list of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=list) + + +def is_tuple_of(seq, expected_type): + """Check whether it is a tuple of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=tuple) + + +def slice_list(in_list, lens): + """Slice a list into several sub lists by a list of given length. + + Args: + in_list (list): The list to be sliced. + lens(int or list): The expected length of each out list. + + Returns: + list: A list of sliced list. + """ + if isinstance(lens, int): + assert len(in_list) % lens == 0 + lens = [lens] * int(len(in_list) / lens) + if not isinstance(lens, list): + raise TypeError('"indices" must be an integer or a list of integers') + elif sum(lens) != len(in_list): + raise ValueError('sum of lens and list length does not ' + f'match: {sum(lens)} != {len(in_list)}') + out_list = [] + idx = 0 + for i in range(len(lens)): + out_list.append(in_list[idx:idx + lens[i]]) + idx += lens[i] + return out_list + + +def concat_list(in_list): + """Concatenate a list of list into a single list. + + Args: + in_list (list): The list of list to be merged. + + Returns: + list: The concatenated flat list. + """ + return list(itertools.chain(*in_list)) + + +def check_prerequisites( + prerequisites, + checker, + msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' + 'found, please install them first.'): # yapf: disable + """A decorator factory to check if prerequisites are satisfied. + + Args: + prerequisites (str of list[str]): Prerequisites to be checked. + checker (callable): The checker method that returns True if a + prerequisite is meet, False otherwise. + msg_tmpl (str): The message template with two variables. + + Returns: + decorator: A specific decorator. 
+ """ + + def wrap(func): + + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + requirements = [prerequisites] if isinstance( + prerequisites, str) else prerequisites + missing = [] + for item in requirements: + if not checker(item): + missing.append(item) + if missing: + print(msg_tmpl.format(', '.join(missing), func.__name__)) + raise RuntimeError('Prerequisites not meet.') + else: + return func(*args, **kwargs) + + return wrapped_func + + return wrap + + +def _check_py_package(package): + try: + import_module(package) + except ImportError: + return False + else: + return True + + +def _check_executable(cmd): + if subprocess.call(f'which {cmd}', shell=True) != 0: + return False + else: + return True + + +def requires_package(prerequisites): + """A decorator to check if some python packages are installed. + + Example: + >>> @requires_package('numpy') + >>> func(arg1, args): + >>> return numpy.zeros(1) + array([0.]) + >>> @requires_package(['numpy', 'non_package']) + >>> func(arg1, args): + >>> return numpy.zeros(1) + ImportError + """ + return check_prerequisites(prerequisites, checker=_check_py_package) + + +def requires_executable(prerequisites): + """A decorator to check if some executable files are installed. + + Example: + >>> @requires_executable('ffmpeg') + >>> func(arg1, args): + >>> print(1) + 1 + """ + return check_prerequisites(prerequisites, checker=_check_executable) + + +def deprecated_api_warning(name_dict, cls_name=None): + """A decorator to check if some arguments are deprecate and try to replace + deprecate src_arg_name to dst_arg_name. + + Args: + name_dict(dict): + key (str): Deprecate argument names. + val (str): Expected argument names. + + Returns: + func: New function. + """ + + def api_warning_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get name of the function + func_name = old_func.__name__ + if cls_name is not None: + func_name = f'{cls_name}.{func_name}' + if args: + arg_names = args_info.args[:len(args)] + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in arg_names: + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead', DeprecationWarning) + arg_names[arg_names.index(src_arg_name)] = dst_arg_name + if kwargs: + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in kwargs: + + assert dst_arg_name not in kwargs, ( + f'The expected behavior is to replace ' + f'the deprecated key `{src_arg_name}` to ' + f'new key `{dst_arg_name}`, but got them ' + f'in the arguments at the same time, which ' + f'is confusing. `{src_arg_name} will be ' + f'deprecated in the future, please ' + f'use `{dst_arg_name}` instead.') + + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead', DeprecationWarning) + kwargs[dst_arg_name] = kwargs.pop(src_arg_name) + + # apply converted arguments to the decorated method + output = old_func(*args, **kwargs) + return output + + return new_func + + return api_warning_wrapper + + +def is_method_overridden(method, base_class, derived_class): + """Check if a method of base class is overridden in derived class. + + Args: + method (str): the method name to check. + base_class (type): the class of the base class. + derived_class (type | Any): the class or instance of the derived class. 
+    """
+    assert isinstance(base_class, type), \
+        "base_class doesn't accept an instance, please pass a class instead."
+
+    if not isinstance(derived_class, type):
+        derived_class = derived_class.__class__
+
+    base_method = getattr(base_class, method)
+    derived_method = getattr(derived_class, method)
+    return derived_method != base_method
+
+
+def has_method(obj: object, method: str) -> bool:
+    """Check whether the object has a method.
+
+    Args:
+        obj (object): The object to check.
+        method (str): The method name to check.
+
+    Returns:
+        bool: True if the object has the method, else False.
+    """
+    return hasattr(obj, method) and callable(getattr(obj, method))
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_jit.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_jit.py
new file mode 100755
index 00000000..61873f6d
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_jit.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+from .parrots_wrapper import TORCH_VERSION
+
+parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
+
+if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
+    from parrots.jit import pat as jit
+else:
+
+    def jit(func=None,
+            check_input=None,
+            full_shape=True,
+            derivate=False,
+            coderize=False,
+            optimize=False):
+
+        def wrapper(func):
+
+            def wrapper_inner(*args, **kargs):
+                return func(*args, **kargs)
+
+            return wrapper_inner
+
+        if func is None:
+            return wrapper
+        else:
+            return func
+
+
+if TORCH_VERSION == 'parrots':
+    from parrots.utils.tester import skip_no_elena
+else:
+
+    def skip_no_elena(func):
+
+        def wrapper(*args, **kargs):
+            return func(*args, **kargs)
+
+        return wrapper
diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_wrapper.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_wrapper.py
new file mode 100755
index 00000000..cf2c7e5c
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/parrots_wrapper.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
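A minimal usage sketch for the misc.py helpers above (illustrative only; it assumes the dbnet_cv package from this patch is importable, and the Base/Child classes are placeholders, not part of the patch):

from dbnet_cv.utils.misc import has_method, is_method_overridden

class Base:
    def forward(self):
        return 'base'

class Child(Base):
    def forward(self):
        return 'child'

# `base_class` must be a class; `derived_class` may be a class or an instance.
assert is_method_overridden('forward', Base, Child)
assert is_method_overridden('forward', Base, Child())
# `has_method` only checks that the attribute exists and is callable.
assert has_method(Child(), 'forward')
assert not has_method(object(), 'forward')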
+from functools import partial + +import torch + +TORCH_VERSION = torch.__version__ + + +def is_cuda_available() -> bool: + return torch.cuda.is_available() + + +IS_CUDA_AVAILABLE = is_cuda_available() + + +def is_rocm_pytorch() -> bool: + is_rocm = False + if TORCH_VERSION != 'parrots': + try: + from torch.utils.cpp_extension import ROCM_HOME + is_rocm = True if ((torch.version.hip is not None) and + (ROCM_HOME is not None)) else False + except ImportError: + pass + return is_rocm + + +def _get_cuda_home(): + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import CUDA_HOME + else: + if is_rocm_pytorch(): + from torch.utils.cpp_extension import ROCM_HOME + CUDA_HOME = ROCM_HOME + else: + from torch.utils.cpp_extension import CUDA_HOME + return CUDA_HOME + + +def get_build_config(): + if TORCH_VERSION == 'parrots': + from parrots.config import get_build_info + return get_build_info() + else: + return torch.__config__.show() + + +def _get_conv(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin + else: + from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin + return _ConvNd, _ConvTransposeMixin + + +def _get_dataloader(): + if TORCH_VERSION == 'parrots': + from torch.utils.data import DataLoader, PoolDataLoader + else: + from torch.utils.data import DataLoader + PoolDataLoader = DataLoader + return DataLoader, PoolDataLoader + + +def _get_extension(): + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import BuildExtension, Extension + CppExtension = partial(Extension, cuda=False) + CUDAExtension = partial(Extension, cuda=True) + else: + from torch.utils.cpp_extension import (BuildExtension, CppExtension, + CUDAExtension) + return BuildExtension, CppExtension, CUDAExtension + + +def _get_pool(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + else: + from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd + + +def _get_norm(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm2d + else: + from torch.nn.modules.batchnorm import _BatchNorm + from torch.nn.modules.instancenorm import _InstanceNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm + return _BatchNorm, _InstanceNorm, SyncBatchNorm_ + + +_ConvNd, _ConvTransposeMixin = _get_conv() +DataLoader, PoolDataLoader = _get_dataloader() +BuildExtension, CppExtension, CUDAExtension = _get_extension() +_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() +_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() + + +class SyncBatchNorm(SyncBatchNorm_): # type: ignore + + def _check_input_dim(self, input): + if TORCH_VERSION == 'parrots': + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input (got {input.dim()}D input)') + else: + super()._check_input_dim(input) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/path.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/path.py new file mode 100755 index 00000000..56808183 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/path.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
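A short sketch of how the wrapper aliases above are typically consumed on a plain PyTorch (non-parrots) install; the import path simply follows the file layout in this patch:

import torch.nn as nn
from dbnet_cv.utils.parrots_wrapper import _BatchNorm, _InstanceNorm

bn = nn.BatchNorm2d(8)
inorm = nn.InstanceNorm2d(8)
# The private base classes let one isinstance check cover the 1d/2d/3d and
# sync variants, regardless of which backend provided them.
assert isinstance(bn, _BatchNorm) and not isinstance(bn, _InstanceNorm)
assert isinstance(inorm, _InstanceNorm) and not isinstance(inorm, _BatchNorm)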
+import os +import os.path as osp +from pathlib import Path + +from .misc import is_str + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError('`filepath` should be a string or a Path') + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == '': + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str | :obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = suffix.lower() if isinstance(suffix, str) else tuple( + item.lower() for item in suffix) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, + case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +def find_vcs_root(path, markers=('.git', )): + """Finds the root directory (including itself) of specified markers. + + Args: + path (str): Path of directory or file. + markers (list[str], optional): List of file or directory names. + + Returns: + The directory contained one of the markers or None if not found. + """ + if osp.isfile(path): + path = osp.dirname(path) + + prev, cur = None, osp.abspath(osp.expanduser(path)) + while cur != prev: + if any(osp.exists(osp.join(cur, marker)) for marker in markers): + return cur + prev, cur = cur, osp.split(cur)[0] + return None diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/progressbar.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/progressbar.py new file mode 100755 index 00000000..0062f670 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/progressbar.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
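A small usage sketch for the path helpers above; the work_dirs/demo directory is a placeholder, not something shipped by this patch:

from dbnet_cv.utils.path import mkdir_or_exist, scandir

mkdir_or_exist('./work_dirs/demo')  # silently succeeds if it already exists

# List every Python source file below the current directory (relative paths).
py_files = list(scandir('.', suffix='.py', recursive=True))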
+import sys +from collections.abc import Iterable +from multiprocessing import Pool +from shutil import get_terminal_size + +from .timer import Timer + + +class ProgressBar: + """A progress bar which can print the progress.""" + + def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): + self.task_num = task_num + self.bar_width = bar_width + self.completed = 0 + self.file = file + if start: + self.start() + + @property + def terminal_width(self): + width, _ = get_terminal_size() + return width + + def start(self): + if self.task_num > 0: + self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, ' + 'elapsed: 0s, ETA:') + else: + self.file.write('completed: 0, elapsed: 0s') + self.file.flush() + self.timer = Timer() + + def update(self, num_tasks=1): + assert num_tasks > 0 + self.completed += num_tasks + elapsed = self.timer.since_start() + if elapsed > 0: + fps = self.completed / elapsed + else: + fps = float('inf') + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \ + f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \ + f'ETA: {eta:5}s' + + bar_width = min(self.bar_width, + int(self.terminal_width - len(msg)) + 2, + int(self.terminal_width * 0.6)) + bar_width = max(2, bar_width) + mark_width = int(bar_width * percentage) + bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width) + self.file.write(msg.format(bar_chars)) + else: + self.file.write( + f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' + f' {fps:.1f} tasks/s') + self.file.flush() + + +def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): + """Track the progress of tasks execution with a progress bar. + + Tasks are done with a simple for-loop. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + results = [] + for task in tasks: + results.append(func(task, **kwargs)) + prog_bar.update() + prog_bar.file.write('\n') + return results + + +def init_pool(process_num, initializer=None, initargs=None): + if initializer is None: + return Pool(process_num) + elif initargs is None: + return Pool(process_num, initializer) + else: + if not isinstance(initargs, tuple): + raise TypeError('"initargs" must be a tuple') + return Pool(process_num, initializer, initargs) + + +def track_parallel_progress(func, + tasks, + nproc, + initializer=None, + initargs=None, + bar_width=50, + chunksize=1, + skip_first=False, + keep_order=True, + file=sys.stdout): + """Track the progress of parallel task execution with a progress bar. + + The built-in :mod:`multiprocessing` module is used for process pools and + tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + nproc (int): Process (worker) number. 
+ initializer (None or callable): Refer to :class:`multiprocessing.Pool` + for details. + initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for + details. + chunksize (int): Refer to :class:`multiprocessing.Pool` for details. + bar_width (int): Width of progress bar. + skip_first (bool): Whether to skip the first sample for each worker + when estimating fps, since the initialization step may takes + longer. + keep_order (bool): If True, :func:`Pool.imap` is used, otherwise + :func:`Pool.imap_unordered` is used. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + pool = init_pool(nproc, initializer, initargs) + start = not skip_first + task_num -= nproc * chunksize * int(skip_first) + prog_bar = ProgressBar(task_num, bar_width, start, file=file) + results = [] + if keep_order: + gen = pool.imap(func, tasks, chunksize) + else: + gen = pool.imap_unordered(func, tasks, chunksize) + for result in gen: + results.append(result) + if skip_first: + if len(results) < nproc * chunksize: + continue + elif len(results) == nproc * chunksize: + prog_bar.start() + continue + prog_bar.update() + prog_bar.file.write('\n') + pool.close() + pool.join() + return results + + +def track_iter_progress(tasks, bar_width=50, file=sys.stdout): + """Track the progress of tasks iteration or enumeration with a progress + bar. + + Tasks are yielded with a simple for-loop. + + Args: + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Yields: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + for task in tasks: + yield task + prog_bar.update() + prog_bar.file.write('\n') diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/registry.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/registry.py new file mode 100755 index 00000000..a283583d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/registry.py @@ -0,0 +1,340 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import warnings +from functools import partial +from typing import Any, Dict, Optional + +from .misc import deprecated_api_warning, is_seq_of + + +def build_from_cfg(cfg: Dict, + registry: 'Registry', + default_args: Optional[Dict] = None) -> Any: + """Build a module from config dict when it is a class configuration, or + call a function from config dict when it is a function configuration. + + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS) + >>> # Returns an instantiated object + >>> @MODELS.register_module() + >>> def resnet50(): + >>> pass + >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS) + >>> # Return a result of the calling function + + Args: + cfg (dict): Config dict. It should at least contain the key "type". 
+ registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + if default_args is None or 'type' not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f'but got {cfg}\n{default_args}') + if not isinstance(registry, Registry): + raise TypeError('registry must be an dbnet_cv.Registry object, ' + f'but got {type(registry)}') + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError('default_args must be a dict or None, ' + f'but got {type(default_args)}') + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop('type') + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry') + elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f'{obj_cls.__name__}: {e}') + + +class Registry: + """A registry to map strings to classes or functions. + + Registered object could be built from registry. Meanwhile, registered + functions could be called from registry. + + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + >>> @MODELS.register_module() + >>> def resnet50(): + >>> pass + >>> resnet = MODELS.build(dict(type='resnet50')) + + Please refer to + https://dbnet_cv.readthedocs.io/en/latest/understand_dbnet_cv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. dbnet_det, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. 
build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = self.__class__.__name__ + \ + f'(name={self._name}, ' \ + f'items={self._module_dict})' + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. + + Example: + >>> # in dbnet_det/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``dbnet_det``. + + Returns: + str: The inferred scope name. + """ + # We access the caller using inspect.currentframe() instead of + # inspect.stack() for performance reasons. See details in PR #1844 + frame = inspect.currentframe() + # get the frame where `infer_scope()` is called + infer_scope_caller = frame.f_back.f_back + filename = inspect.getmodule(infer_scope_caller).__name__ + split_filename = filename.split('.') + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('dbnet_det.ResNet') + 'dbnet_det', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + tuple[str | None, str]: The former element is the first scope of + the key, which can be ``None``. The latter is the remaining key. + """ + split_index = key.find('.') + if split_index != -1: + return key[:split_index], key[split_index + 1:] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. 
+ + Example: + >>> models = Registry('models') + >>> dbnet_det_models = Registry('models', parent=models) + >>> @dbnet_det_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='dbnet_det.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert registry.scope not in self.children, \ + f'scope {registry.scope} exists in {self.name} registry' + self.children[registry.scope] = registry + + @deprecated_api_warning(name_dict=dict(module_class='module')) + def _register_module(self, module, module_name=None, force=False): + if not inspect.isclass(module) and not inspect.isfunction(module): + raise TypeError('module must be a class or a function, ' + f'but got {type(module)}') + + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f'{name} is already registered ' + f'in {self.name}') + self._module_dict[name] = module + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + 'The old API of register_module(module, force=False) ' + 'is deprecated and will be removed, please use the new API ' + 'register_module(name=None, force=False, module=None) instead.', + DeprecationWarning) + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class or function to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. 
+ if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + 'name must be either of None, an instance of str or a sequence' + f' of str, but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + self._register_module(module=module, module_name=name, force=force) + return module + + return _register diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/seed.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/seed.py new file mode 100755 index 00000000..003f9236 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/seed.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +import numpy as np +import torch + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): + """Function to initialize each worker. + + The seed of each worker equals to + ``num_worker * rank + worker_id + user_seed``. + + Args: + worker_id (int): Id for each worker. + num_workers (int): Number of workers. + rank (int): Rank in distributed training. + seed (int): Random seed. + """ + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/testing.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/testing.py new file mode 100755 index 00000000..7b64e8fa --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/testing.py @@ -0,0 +1,141 @@ +# Copyright (c) Open-MMLab. +import sys +from collections.abc import Iterable +from runpy import run_path +from shlex import split +from typing import Any, Dict, List +from unittest.mock import patch + + +def check_python_script(cmd): + """Run the python cmd script with `__main__`. The difference between + `os.system` is that, this function exectues code in the current process, so + that it can be tracked by coverage tools. Currently it supports two forms: + + - ./tests/data/scripts/hello.py zz + - python tests/data/scripts/hello.py zz + """ + args = split(cmd) + if args[0] == 'python': + args = args[1:] + with patch.object(sys, 'argv', args): + run_path(args[0], run_name='__main__') + + +def _any(judge_result): + """Since built-in ``any`` works only when the element of iterable is not + iterable, implement the function.""" + if not isinstance(judge_result, Iterable): + return judge_result + + try: + for element in judge_result: + if _any(element): + return True + except TypeError: + # Maybe encounter the case: torch.tensor(True) | torch.tensor(False) + if judge_result: + return True + return False + + +def assert_dict_contains_subset(dict_obj: Dict[Any, Any], + expected_subset: Dict[Any, Any]) -> bool: + """Check if the dict_obj contains the expected_subset. + + Args: + dict_obj (Dict[Any, Any]): Dict object to be checked. + expected_subset (Dict[Any, Any]): Subset expected to be contained in + dict_obj. + + Returns: + bool: Whether the dict_obj contains the expected_subset. + """ + + for key, value in expected_subset.items(): + if key not in dict_obj.keys() or _any(dict_obj[key] != value): + return False + return True + + +def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool: + """Check if attribute of class object is correct. 
+ + Args: + obj (object): Class object to be checked. + expected_attrs (Dict[str, Any]): Dict of the expected attrs. + + Returns: + bool: Whether the attribute of class object is correct. + """ + for attr, value in expected_attrs.items(): + if not hasattr(obj, attr) or _any(getattr(obj, attr) != value): + return False + return True + + +def assert_dict_has_keys(obj: Dict[str, Any], + expected_keys: List[str]) -> bool: + """Check if the obj has all the expected_keys. + + Args: + obj (Dict[str, Any]): Object to be checked. + expected_keys (List[str]): Keys expected to contained in the keys of + the obj. + + Returns: + bool: Whether the obj has the expected keys. + """ + return set(expected_keys).issubset(set(obj.keys())) + + +def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool: + """Check if target_keys is equal to result_keys. + + Args: + result_keys (List[str]): Result keys to be checked. + target_keys (List[str]): Target keys to be checked. + + Returns: + bool: Whether target_keys is equal to result_keys. + """ + return set(result_keys) == set(target_keys) + + +def assert_is_norm_layer(module) -> bool: + """Check if the module is a norm layer. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the module is a norm layer. + """ + from torch.nn import GroupNorm, LayerNorm + + from .parrots_wrapper import _BatchNorm, _InstanceNorm + norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) + return isinstance(module, norm_layer_candidates) + + +def assert_params_all_zeros(module) -> bool: + """Check if the parameters of the module is all zeros. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the parameters of the module is all zeros. + """ + weight_data = module.weight.data + is_weight_zero = weight_data.allclose( + weight_data.new_zeros(weight_data.size())) + + if hasattr(module, 'bias') and module.bias is not None: + bias_data = module.bias.data + is_bias_zero = bias_data.allclose( + bias_data.new_zeros(bias_data.size())) + else: + is_bias_zero = True + + return is_weight_zero and is_bias_zero diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/timer.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/timer.py new file mode 100755 index 00000000..c236356c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/timer.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from time import time + + +class TimerError(Exception): + + def __init__(self, message): + self.message = message + super().__init__(message) + + +class Timer: + """A flexible Timer class. 
+ + Examples: + >>> import time + >>> import dbnet_cv + >>> with dbnet_cv.Timer(): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + 1.000 + >>> with dbnet_cv.Timer(print_tmpl='it takes {:.1f} seconds'): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + it takes 1.0 seconds + >>> timer = dbnet_cv.Timer() + >>> time.sleep(0.5) + >>> print(timer.since_start()) + 0.500 + >>> time.sleep(0.5) + >>> print(timer.since_last_check()) + 0.500 + >>> print(timer.since_start()) + 1.000 + """ + + def __init__(self, start=True, print_tmpl=None): + self._is_running = False + self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' + if start: + self.start() + + @property + def is_running(self): + """bool: indicate whether the timer is running""" + return self._is_running + + def __enter__(self): + self.start() + return self + + def __exit__(self, type, value, traceback): + print(self.print_tmpl.format(self.since_last_check())) + self._is_running = False + + def start(self): + """Start the timer.""" + if not self._is_running: + self._t_start = time() + self._is_running = True + self._t_last = time() + + def since_start(self): + """Total time since the timer is started. + + Returns: + float: Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + self._t_last = time() + return self._t_last - self._t_start + + def since_last_check(self): + """Time since the last checking. + + Either :func:`since_start` or :func:`since_last_check` is a checking + operation. + + Returns: + float: Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + dur = time() - self._t_last + self._t_last = time() + return dur + + +_g_timers = {} # global timers + + +def check_time(timer_id): + """Add check points in a single line. + + This method is suitable for running a task on a list of items. A timer will + be registered when the method is called for the first time. + + Examples: + >>> import time + >>> import dbnet_cv + >>> for i in range(1, 6): + >>> # simulate a code block + >>> time.sleep(i) + >>> dbnet_cv.check_time('task1') + 2.000 + 3.000 + 4.000 + 5.000 + + Args: + str: Timer identifier. + """ + if timer_id not in _g_timers: + _g_timers[timer_id] = Timer() + return 0 + else: + return _g_timers[timer_id].since_last_check() diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/trace.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/trace.py new file mode 100755 index 00000000..6b4b7599 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/trace.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch + +from dbnet_cv.utils import digit_version + + +def is_jit_tracing() -> bool: + if (torch.__version__ != 'parrots' + and digit_version(torch.__version__) >= digit_version('1.6.0')): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. + # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. 
Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/utils/version_utils.py b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/version_utils.py new file mode 100755 index 00000000..77c41f60 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/utils/version_utils.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import subprocess +import warnings + +from packaging.version import parse + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + assert 'parrots' not in version_str + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) # type: ignore + else: + release.extend([0, 0]) + return tuple(release) + + +def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen( + cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + +def get_git_hash(fallback='unknown', digits=None): + """Get the git hash of the current repo. + + Args: + fallback (str, optional): The fallback string when git hash is + unavailable. Defaults to 'unknown'. + digits (int, optional): kept digits of the hash. Defaults to None, + meaning all digits are kept. + + Returns: + str: Git commit hash. + """ + + if digits is not None and not isinstance(digits, int): + raise TypeError('digits must be None or an integer') + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + if digits is not None: + sha = sha[:digits] + except OSError: + sha = fallback + + return sha diff --git a/cv/ocr/dbnet/pytorch/dbnet_cv/version.py b/cv/ocr/dbnet/pytorch/dbnet_cv/version.py new file mode 100755 index 00000000..a65e3278 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_cv/version.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +__version__ = '1.5.3' + + +def parse_version_info(version_str: str, length: int = 4) -> tuple: + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into + (2, 0, 0, 0, 'rc', 1) (when length is set to 4). 
+ """ + from packaging.version import parse + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + release.extend(list(version.pre)) # type: ignore + elif version.is_postrelease: + release.extend(list(version.post)) # type: ignore + else: + release.extend([0, 0]) + return tuple(release) + + +version_info = tuple(int(x) for x in __version__.split('.')[:3]) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/cv/ocr/dbnet/pytorch/dist_train.sh b/cv/ocr/dbnet/pytorch/dist_train.sh new file mode 100755 index 00000000..4ef0ea66 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dist_train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python3 -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} diff --git a/cv/ocr/dbnet/pytorch/requirements.txt b/cv/ocr/dbnet/pytorch/requirements.txt new file mode 100755 index 00000000..9b52871d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/requirements.txt @@ -0,0 +1,12 @@ + + +addict +yapf +terminaltables +pycocotools +lmdb +shapely +pyclipper +imgaug + + diff --git a/cv/ocr/dbnet/pytorch/setup.py b/cv/ocr/dbnet/pytorch/setup.py new file mode 100755 index 00000000..1ade11c3 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/setup.py @@ -0,0 +1,327 @@ +import glob +import os +import platform +import re +import warnings +from pkg_resources import DistributionNotFound, get_distribution +from setuptools import find_packages, setup + +# EXT_TYPE = '' +# try: +# import torch +# if torch.__version__ == 'parrots': +# from parrots.utils.build_extension import BuildExtension +# EXT_TYPE = 'parrots' +# elif (hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()) or \ +# os.getenv('FORCE_MLU', '0') == '1': +# from torch_mlu.utils.cpp_extension import BuildExtension +# EXT_TYPE = 'pytorch' +# else: +# from torch.utils.cpp_extension import BuildExtension +# EXT_TYPE = 'pytorch' +# cmd_class = {'build_ext': BuildExtension} +# except ModuleNotFoundError: +# cmd_class = {} +# print('Skip building ext ops due to the absence of torch.') + +import torch +from torch.utils.cpp_extension import BuildExtension +cmd_class = {'build_ext': BuildExtension} +EXT_TYPE = 'pytorch' +# def choose_requirement(primary, secondary): +# """If some version of primary requirement installed, return primary, else +# return secondary.""" +# try: +# name = re.split(r'[!<>=]', primary)[0] +# get_distribution(name) +# except DistributionNotFound: +# return secondary + +# return str(primary) + + +# def get_version(): +# version_file = 'dbnet_cv/version.py' +# with open(version_file, 'r', encoding='utf-8') as f: +# exec(compile(f.read(), version_file, 'exec')) +# version = locals()['__version__'] +# local_version_identifier = os.environ.get('DBNET_CV_LOCAL_VERSION_IDENTIFIER', '') +# if local_version_identifier != '': +# version += '+' + local_version_identifier +# return version + + +# def parse_requirements(fname='requirements/runtime.txt', with_version=True): +# """Parse the package dependencies listed in a requirements file but strips 
+# specific versioning information. + +# Args: +# fname (str): path to requirements file +# with_version (bool, default=False): if True include version specs + +# Returns: +# List[str]: list of requirements items + +# CommandLine: +# python -c "import setup; print(setup.parse_requirements())" +# """ +# import sys +# from os.path import exists +# require_fpath = fname + +# def parse_line(line): +# """Parse information from a line in a requirements text file.""" +# if line.startswith('-r '): +# # Allow specifying requirements in other files +# target = line.split(' ')[1] +# for info in parse_require_file(target): +# yield info +# else: +# info = {'line': line} +# if line.startswith('-e '): +# info['package'] = line.split('#egg=')[1] +# else: +# # Remove versioning from the package +# pat = '(' + '|'.join(['>=', '==', '>']) + ')' +# parts = re.split(pat, line, maxsplit=1) +# parts = [p.strip() for p in parts] + +# info['package'] = parts[0] +# if len(parts) > 1: +# op, rest = parts[1:] +# if ';' in rest: +# # Handle platform specific dependencies +# # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies +# version, platform_deps = map(str.strip, +# rest.split(';')) +# info['platform_deps'] = platform_deps +# else: +# version = rest # NOQA +# info['version'] = (op, version) +# yield info + + # def parse_require_file(fpath): + # with open(fpath) as f: + # for line in f.readlines(): + # line = line.strip() + # if line and not line.startswith('#'): + # yield from parse_line(line) + + # def gen_packages_items(): + # if exists(require_fpath): + # for info in parse_require_file(require_fpath): + # parts = [info['package']] + # if with_version and 'version' in info: + # parts.extend(info['version']) + # if not sys.version.startswith('3.4'): + # # apparently package_deps are broken in 3.4 + # platform_deps = info.get('platform_deps') + # if platform_deps is not None: + # parts.append(';' + platform_deps) + # item = ''.join(parts) + # yield item + + # packages = list(gen_packages_items()) + # return packages + + +# install_requires = parse_requirements() + + + + +def get_extensions(): + extensions = [] + + + + + + + if EXT_TYPE == 'pytorch': + ext_name = 'dbnet_cv._ext' + from torch.utils.cpp_extension import CppExtension, CUDAExtension + + # prevent ninja from using too many resources + try: + import psutil + num_cpu = len(psutil.Process().cpu_affinity()) + cpu_use = max(4, num_cpu - 1) + except (ModuleNotFoundError, AttributeError): + cpu_use = 4 + + os.environ.setdefault('MAX_JOBS', str(cpu_use)) + define_macros = [] + + # Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a + # required key passed to PyTorch. Even if there is no flag passed + # to cxx, users also need to pass an empty list to PyTorch. + # Since PyTorch1.8.0, it has a default value so users do not need + # to pass an empty list anymore. + # More details at https://github.com/pytorch/pytorch/pull/45956 + extra_compile_args = {'cxx': []} + + # Since the PR (https://github.com/open-mmlab/dbnet_cv/pull/1463) uses + # c++14 features, the argument ['std=c++14'] must be added here. + # However, in the windows environment, some standard libraries + # will depend on c++17 or higher. 
In fact, for the windows + # environment, the compiler will choose the appropriate compiler + # to compile those cpp files, so there is no need to add the + # argument + if platform.system() != 'Windows': + extra_compile_args['cxx'] = ['-std=c++14'] + + include_dirs = [] + + is_rocm_pytorch = False + try: + from torch.utils.cpp_extension import ROCM_HOME + is_rocm_pytorch = True if ((torch.version.hip is not None) and + (ROCM_HOME is not None)) else False + except ImportError: + pass + + if is_rocm_pytorch or torch.cuda.is_available() or os.getenv( + 'FORCE_CUDA', '0') == '1': + if is_rocm_pytorch: + define_macros += [('HIP_DIFF', None)] + define_macros += [('DBNET_CV_WITH_CUDA', None)] + cuda_args = os.getenv('DBNET_CV_CUDA_ARGS') + extra_compile_args['nvcc'] = [cuda_args] if cuda_args else [] + op_files = glob.glob('./dbnet_cv/ops/csrc/pytorch/*.cpp') + \ + glob.glob('./dbnet_cv/ops/csrc/pytorch/cpu/*.cpp') + \ + glob.glob('./dbnet_cv/ops/csrc/pytorch/cuda/*.cu') + \ + glob.glob('./dbnet_cv/ops/csrc/pytorch/cuda/*.cpp') + extension = CUDAExtension + include_dirs.append(os.path.abspath('./dbnet_cv/ops/csrc/common')) + include_dirs.append(os.path.abspath('./dbnet_cv/ops/csrc/common/cuda')) + # elif (hasattr(torch, 'is_mlu_available') and + # torch.is_mlu_available()) or \ + # os.getenv('FORCE_MLU', '0') == '1': + # from torch_mlu.utils.cpp_extension import MLUExtension + # define_macros += [('DBNET_CV_WITH_MLU', None)] + # mlu_args = os.getenv('DBNET_CV_MLU_ARGS') + # extra_compile_args['cncc'] = [mlu_args] if mlu_args else [] + # op_files = glob.glob('./dbnet_cv/ops/csrc/pytorch/*.cpp') + \ + # glob.glob('./dbnet_cv/ops/csrc/pytorch/cpu/*.cpp') + \ + # glob.glob('./dbnet_cv/ops/csrc/pytorch/mlu/*.cpp') + \ + # glob.glob('./dbnet_cv/ops/csrc/common/mlu/*.mlu') + # extension = MLUExtension + # include_dirs.append(os.path.abspath('./dbnet_cv/ops/csrc/common')) + # include_dirs.append(os.path.abspath('./dbnet_cv/ops/csrc/common/mlu')) + # else: + # print(f'Compiling {ext_name} only with CPU') + # op_files = glob.glob('./dbnet_cv/ops/csrc/pytorch/*.cpp') + \ + # glob.glob('./dbnet_cv/ops/csrc/pytorch/cpu/*.cpp') + # extension = CppExtension + # include_dirs.append(os.path.abspath('./dbnet_cv/ops/csrc/common')) + + # Since the PR (https://github.com/open-mmlab/dbnet_cv/pull/1463) uses + # c++14 features, the argument ['std=c++14'] must be added here. + # However, in the windows environment, some standard libraries + # will depend on c++17 or higher. In fact, for the windows + # environment, the compiler will choose the appropriate compiler + # to compile those cpp files, so there is no need to add the + # argument + if 'nvcc' in extra_compile_args and platform.system() != 'Windows': + extra_compile_args['nvcc'] += ['-std=c++14'] + + ext_ops = extension( + name=ext_name, + sources=op_files, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args) + extensions.append(ext_ops) + + # if EXT_TYPE == 'pytorch' and os.getenv('DBNET_CV_WITH_ORT', '0') != '0': + + # # Following strings of text style are from colorama package + # bright_style, reset_style = '\x1b[1m', '\x1b[0m' + # red_text, blue_text = '\x1b[31m', '\x1b[34m' + # white_background = '\x1b[107m' + + # msg = white_background + bright_style + red_text + # msg += 'DeprecationWarning: ' + \ + # 'Custom ONNXRuntime Ops will be deprecated in future. 
' + # msg += blue_text + \ + # 'Welcome to use the unified model deployment toolbox ' + # msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy' + # msg += reset_style + # warnings.warn(msg) + # ext_name = 'dbnet_cv._ext_ort' + # import onnxruntime + # from torch.utils.cpp_extension import include_paths, library_paths + # library_dirs = [] + # libraries = [] + # include_dirs = [] + # ort_path = os.getenv('ONNXRUNTIME_DIR', '0') + # library_dirs += [os.path.join(ort_path, 'lib')] + # libraries.append('onnxruntime') + # define_macros = [] + # extra_compile_args = {'cxx': []} + + # include_path = os.path.abspath('./dbnet_cv/ops/csrc/onnxruntime') + # include_dirs.append(include_path) + # include_dirs.append(os.path.join(ort_path, 'include')) + + # op_files = glob.glob('./dbnet_cv/ops/csrc/onnxruntime/cpu/*') + # if onnxruntime.get_device() == 'GPU' or os.getenv('FORCE_CUDA', + # '0') == '1': + # define_macros += [('DBNET_CV_WITH_CUDA', None)] + # cuda_args = os.getenv('DBNET_CV_CUDA_ARGS') + # extra_compile_args['nvcc'] = [cuda_args] if cuda_args else [] + # op_files += glob.glob('./dbnet_cv/ops/csrc/onnxruntime/gpu/*') + # include_dirs += include_paths(cuda=True) + # library_dirs += library_paths(cuda=True) + # else: + # include_dirs += include_paths(cuda=False) + # library_dirs += library_paths(cuda=False) + + # from setuptools import Extension + # ext_ops = Extension( + # name=ext_name, + # sources=op_files, + # include_dirs=include_dirs, + # define_macros=define_macros, + # extra_compile_args=extra_compile_args, + # language='c++', + # library_dirs=library_dirs, + # libraries=libraries) + # extensions.append(ext_ops) + + return extensions + + +setup( + name='dbnet_cv' if os.getenv('DBNET_CV_WITH_OPS', '0') == '0' else 'dbnet_cv-full', + # version=get_version(), + # description='OpenMMLab Computer Vision Foundation', + # keywords='computer vision', + packages=find_packages(), + include_package_data=True, + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Topic :: Utilities', + ], + # url='https://github.com/open-mmlab/dbnet_cv', + # author='DBNET_CV Contributors', + # author_email='openmmlab@gmail.com', + # install_requires=install_requires, + # extras_require={ + # 'all': parse_requirements('requirements.txt'), + # 'tests': parse_requirements('requirements/test.txt'), + # 'build': parse_requirements('requirements/build.txt'), + # 'optional': parse_requirements('requirements/optional.txt'), + # }, + ext_modules=get_extensions(), + cmdclass=cmd_class, + zip_safe=False) diff --git a/cv/ocr/dbnet/pytorch/tools/analyze_logs.py b/cv/ocr/dbnet/pytorch/tools/analyze_logs.py new file mode 100755 index 00000000..349735ef --- /dev/null +++ b/cv/ocr/dbnet/pytorch/tools/analyze_logs.py @@ -0,0 +1,190 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
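The version helpers added earlier in dbnet_cv/utils/version_utils.py are what version gates such as the one in trace.py rely on; a brief sketch of the comparison semantics (assuming the package is importable, as trace.py itself does):

import torch
from dbnet_cv.utils import digit_version

# Pre-releases order below the final release: a < b < rc < release.
assert digit_version('1.6.0rc1') < digit_version('1.6.0')
assert digit_version('1.10.0') > digit_version('1.9.1')

# Typical use: feature-gating on the installed torch version.
if digit_version(torch.__version__) >= digit_version('1.6.0'):
    print('native AMP (torch.cuda.amp) is available')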
+"""Modified from https://github.com/open- +mmlab/dbnet_detection/blob/master/tools/analysis_tools/analyze_logs.py.""" +import argparse +import json +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns + + +def cal_train_time(log_dicts, args): + for i, log_dict in enumerate(log_dicts): + print(f'{"-" * 5}Analyze train time of {args.json_logs[i]}{"-" * 5}') + all_times = [] + for epoch in log_dict.keys(): + if args.include_outliers: + all_times.append(log_dict[epoch]['time']) + else: + all_times.append(log_dict[epoch]['time'][1:]) + all_times = np.array(all_times) + epoch_ave_time = all_times.mean(-1) + slowest_epoch = epoch_ave_time.argmax() + fastest_epoch = epoch_ave_time.argmin() + std_over_epoch = epoch_ave_time.std() + print(f'slowest epoch {slowest_epoch + 1}, ' + f'average time is {epoch_ave_time[slowest_epoch]:.4f}') + print(f'fastest epoch {fastest_epoch + 1}, ' + f'average time is {epoch_ave_time[fastest_epoch]:.4f}') + print(f'time std over epochs is {std_over_epoch:.4f}') + print(f'average iter time: {np.mean(all_times):.4f} s/iter') + print() + + +def plot_curve(log_dicts, args): + if args.backend is not None: + plt.switch_backend(args.backend) + sns.set_style(args.style) + # if legend is None, use {filename}_{key} as legend + legend = args.legend + if legend is None: + legend = [] + for json_log in args.json_logs: + for metric in args.keys: + legend.append(f'{json_log}_{metric}') + assert len(legend) == (len(args.json_logs) * len(args.keys)) + metrics = args.keys + + num_metrics = len(metrics) + for i, log_dict in enumerate(log_dicts): + epochs = list(log_dict.keys()) + for j, metric in enumerate(metrics): + print(f'Plot curve of {args.json_logs[i]}, metric is {metric}') + + epoch_based_metrics = [ + 'hmean', 'word_acc', 'word_acc_ignore_case', + 'word_acc_ignore_case_symbol', 'char_recall', 'char_precision', + '1-N.E.D', 'macro_f1' + ] + if any(metric in m for m in epoch_based_metrics): + # determine whether it is a epoch-plotted metric + # e.g. 
hmean-iou:hmean, 0_word_acc + xs = [] + ys = [] + for epoch in epochs: + ys += log_dict[epoch][metric] + if 'val' in log_dict[epoch]['mode']: + xs.append(epoch) + plt.xlabel('epoch') + plt.plot(xs, ys, label=legend[i * num_metrics + j], marker='o') + else: + xs = [] + ys = [] + if log_dict[epochs[0]]['mode'][-1] == 'val': + num_iters_per_epoch = log_dict[epochs[0]]['iter'][-2] + else: + num_iters_per_epoch = log_dict[epochs[0]]['iter'][-1] + for epoch in epochs: + iters = log_dict[epoch]['iter'] + if log_dict[epoch]['mode'][-1] == 'val': + iters = iters[:-1] + xs.append( + np.array(iters) + (epoch - 1) * num_iters_per_epoch) + ys.append(np.array(log_dict[epoch][metric][:len(iters)])) + xs = np.concatenate(xs) + ys = np.concatenate(ys) + plt.xlabel('iter') + plt.plot( + xs, ys, label=legend[i * num_metrics + j], linewidth=0.5) + plt.grid() + plt.legend() + if args.title is not None: + plt.title(args.title) + if args.out is None: + plt.show() + else: + print(f'Save curve to: {args.out}') + plt.savefig(args.out) + plt.cla() + + +def add_plot_parser(subparsers): + parser_plt = subparsers.add_parser( + 'plot_curve', help='Parser for plotting curves') + parser_plt.add_argument( + 'json_logs', + type=str, + nargs='+', + help='Path of train log in json format') + parser_plt.add_argument( + '--keys', + type=str, + nargs='+', + default=['loss'], + help='The metric that you want to plot') + parser_plt.add_argument('--title', type=str, help='Title of figure') + parser_plt.add_argument( + '--legend', + type=str, + nargs='+', + default=None, + help='Legend of each plot') + parser_plt.add_argument( + '--backend', type=str, default=None, help='Backend of plt') + parser_plt.add_argument( + '--style', type=str, default='dark', help='Style of plt') + parser_plt.add_argument('--out', type=str, default=None) + + +def add_time_parser(subparsers): + parser_time = subparsers.add_parser( + 'cal_train_time', + help='Parser for computing the average time per training iteration') + parser_time.add_argument( + 'json_logs', + type=str, + nargs='+', + help='Path of train log in json format') + parser_time.add_argument( + '--include-outliers', + action='store_true', + help='Include the first value of every epoch when computing ' + 'the average time') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Analyze Json Log') + # currently only support plot curve and calculate average train time + subparsers = parser.add_subparsers(dest='task', help='task parser') + add_plot_parser(subparsers) + add_time_parser(subparsers) + args = parser.parse_args() + return args + + +def load_json_logs(json_logs): + # load and convert json_logs to log_dict, key is epoch, value is a sub dict + # keys of sub dict is different metrics, e.g. 
memory, loss
+    # value of sub dict is a list of corresponding values of all iterations
+    log_dicts = [dict() for _ in json_logs]
+    for json_log, log_dict in zip(json_logs, log_dicts):
+        with open(json_log, 'r') as log_file:
+            for line in log_file:
+                log = json.loads(line.strip())
+                # skip lines without `epoch` field
+                if 'epoch' not in log:
+                    continue
+                epoch = log.pop('epoch')
+                if epoch not in log_dict:
+                    log_dict[epoch] = defaultdict(list)
+                for k, v in log.items():
+                    log_dict[epoch][k].append(v)
+    return log_dicts
+
+
+def main():
+    args = parse_args()
+
+    json_logs = args.json_logs
+    for json_log in json_logs:
+        assert json_log.endswith('.json')
+
+    log_dicts = load_json_logs(json_logs)
+
+    eval(args.task)(log_dicts, args)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cv/ocr/dbnet/pytorch/tools/benchmark_processing.py b/cv/ocr/dbnet/pytorch/tools/benchmark_processing.py
new file mode 100755
index 00000000..cb499eda
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/tools/benchmark_processing.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This file benchmarks the data loading process. It can also be used to
+refresh the memcached cache. The command line to run this file is:
+
+$ python -m cProfile -o program.prof tools/benchmark_processing.py
+configs/task/method/[config filename]
+
+Note: When debugging, `workers_per_gpu` in the config should be set to 0
+during the benchmark.
+
+It uses cProfile to record CPU running time and writes the output to program.prof.
+To visualize the cProfile output program.prof, use snakeviz and run:
+$ snakeviz program.prof
+"""
+import argparse
+
+import dbnet_cv
+from dbnet_cv import Config
+from dbnet_det.datasets import build_dataloader
+
+from dbnet.datasets import build_dataset
+
+assert build_dataset is not None
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Benchmark data loading')
+    parser.add_argument('config', help='Train config file path.')
+    args = parser.parse_args()
+    cfg = Config.fromfile(args.config)
+
+    dataset = build_dataset(cfg.data.train)
+
+    # prepare data loaders
+    if 'imgs_per_gpu' in cfg.data:
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
+    data_loader = build_dataloader(
+        dataset,
+        cfg.data.samples_per_gpu,
+        cfg.data.workers_per_gpu,
+        1,
+        dist=False,
+        seed=None)
+
+    # start the progress bar after the first 5 batches so warm-up is excluded
+    prog_bar = dbnet_cv.ProgressBar(
+        len(dataset) - 5 * cfg.data.samples_per_gpu, start=False)
+    for i, data in enumerate(data_loader):
+        if i == 5:
+            prog_bar.start()
+        for _ in range(len(data['img'])):
+            if i < 5:
+                continue
+            prog_bar.update()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cv/ocr/dbnet/pytorch/tools/misc/print_config.py b/cv/ocr/dbnet/pytorch/tools/misc/print_config.py
new file mode 100755
index 00000000..7e5c6f49
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/tools/misc/print_config.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
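+# Illustrative usage (not part of the upstream tool's docs; the config path is one of the
+# DBNet configs added in this patch and the override key is only an example):
+#   python tools/misc/print_config.py configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py \
+#       --cfg-options data.samples_per_gpu=8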
+import argparse
+import warnings
+
+from dbnet_cv import Config, DictAction
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Print the whole config')
+    parser.add_argument('config', help='config file path')
+    parser.add_argument(
+        '--options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file (deprecated), '
+        'change to --cfg-options instead.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be like key="[a,b]" or '
+        'key=a,b. It also allows nested list/tuple values, e.g. '
+        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
+        'and that no white space is allowed.')
+    args = parser.parse_args()
+
+    if args.options and args.cfg_options:
+        raise ValueError(
+            '--options and --cfg-options cannot both be '
+            'specified, --options is deprecated in favor of --cfg-options')
+    if args.options:
+        warnings.warn('--options is deprecated in favor of --cfg-options')
+        args.cfg_options = args.options
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    # import modules from string list.
+    if cfg.get('custom_imports', None):
+        from dbnet_cv.utils import import_modules_from_strings
+        import_modules_from_strings(**cfg['custom_imports'])
+    print(f'Config:\n{cfg.pretty_text}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cv/ocr/dbnet/pytorch/tools/publish_model.py b/cv/ocr/dbnet/pytorch/tools/publish_model.py
new file mode 100755
index 00000000..73b8a8cb
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/tools/publish_model.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import subprocess
+
+import torch
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Process a checkpoint to be published')
+    parser.add_argument('in_file', help='input checkpoint filename')
+    parser.add_argument('out_file', help='output checkpoint filename')
+    args = parser.parse_args()
+    return args
+
+
+def process_checkpoint(in_file, out_file):
+    checkpoint = torch.load(in_file, map_location='cpu')
+    # remove optimizer for smaller file size
+    if 'optimizer' in checkpoint:
+        del checkpoint['optimizer']
+    # if it is necessary to remove some sensitive data in checkpoint['meta'],
+    # add the code here.
+    if 'meta' in checkpoint:
+        checkpoint['meta'] = {'CLASSES': 0}
+    torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
+    sha = subprocess.check_output(['sha256sum', out_file]).decode()
+    # note: str.rstrip strips characters, not a suffix, so trim '.pth' explicitly
+    out_file_name = out_file[:-4] if out_file.endswith('.pth') else out_file
+    final_file = out_file_name + '-{}.pth'.format(sha[:8])
+    subprocess.run(['mv', out_file, final_file], check=True)
+
+
+def main():
+    args = parse_args()
+    process_checkpoint(args.in_file, args.out_file)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cv/ocr/dbnet/pytorch/train.py b/cv/ocr/dbnet/pytorch/train.py
new file mode 100755
index 00000000..cf13ad13
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/train.py
@@ -0,0 +1,232 @@
+#!/usr/bin/env python
+# Copyright (c) OpenMMLab. All rights reserved.
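+# Illustrative usage (a hedged sketch, not from the upstream docs; the config path is one of
+# the configs added in this patch and the work dir / seed are arbitrary):
+#   python train.py configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py \
+#       --work-dir ./work_dirs/dbnet_r18 --gpu-id 0 --seed 0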
+import argparse +import copy +import os +import os.path as osp +import time +import warnings +import ssl +ssl._create_default_https_context = ssl._create_unverified_context + +import dbnet_cv +import torch +import torch.distributed as dist +from dbnet_cv import Config, DictAction +from dbnet_cv.runner import get_dist_info, init_dist, set_random_seed +from dbnet_cv.utils import get_git_hash + +from dbnet import __version__ +from dbnet.apis import init_random_seed, train_detector +from dbnet.datasets import build_dataset +from dbnet.models import build_detector +from dbnet.utils import (collect_env, get_root_logger, is_2dlist, + setup_multi_processes) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a detector.') + parser.add_argument('config', help='Train config file path.') + parser.add_argument('--work-dir', help='The dir to save logs and models.') + parser.add_argument( + '--load-from', help='The checkpoint file to load from.') + parser.add_argument( + '--resume-from', help='The checkpoint file to resume from.') + parser.add_argument( + '--no-validate', + action='store_true', + help='Whether not to evaluate the checkpoint during training.') + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument( + '--gpus', + type=int, + help='(Deprecated, please use --gpu-id) number of gpus to use ' + '(only applicable to non-distributed training).') + group_gpus.add_argument( + '--gpu-ids', + type=int, + nargs='+', + help='(Deprecated, please use --gpu-id) ids of gpus to use ' + '(only applicable to non-distributed training)') + group_gpus.add_argument( + '--gpu-id', + type=int, + default=0, + help='id of gpu to use ' + '(only applicable to non-distributed training)') + parser.add_argument('--seed', type=int, default=None, help='Random seed.') + parser.add_argument( + '--diff-seed', + action='store_true', + help='Whether or not set different seeds for different ranks') + parser.add_argument( + '--deterministic', + action='store_true', + help='Whether to set deterministic options for CUDNN backend.') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help='Override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file (deprecate), ' + 'change to --cfg-options instead.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='Override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be of the form of either ' + 'key="[a,b]" or key=a,b .The argument also allows nested list/tuple ' + 'values, e.g. key="[(a,b),(c,d)]". 
Note that the quotation marks ' + 'are necessary and that no white space is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='Options for job launcher.') + parser.add_argument('--local_rank', type=int, default=0) + + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both ' + 'specified, --options is deprecated in favor of --cfg-options') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + setup_multi_processes(cfg) + + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + if args.load_from is not None: + cfg.load_from = args.load_from + if args.resume_from is not None: + cfg.resume_from = args.resume_from + if args.gpus is not None: + cfg.gpu_ids = range(1) + warnings.warn('`--gpus` is deprecated because we only support ' + 'single GPU mode in non-distributed training. ' + 'Use `gpus=1` now.') + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids[0:1] + warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' + 'Because we only support single GPU mode in ' + 'non-distributed training. Use the first GPU ' + 'in `gpu_ids` now.') + if args.gpus is None and args.gpu_ids is None: + cfg.gpu_ids = [args.gpu_id] + + # init distributed env first, since logger depends on the dist info. 
+ if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + + # create work_dir + dbnet_cv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + meta['config'] = cfg.pretty_text + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + seed = init_random_seed(args.seed) + seed = seed + dist.get_rank() if args.diff_seed else seed + logger.info(f'Set random seed to {seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(seed, deterministic=args.deterministic) + cfg.seed = seed + meta['seed'] = seed + meta['exp_name'] = osp.basename(args.config) + + model = build_detector( + cfg.model, + train_cfg=cfg.get('train_cfg'), + test_cfg=cfg.get('test_cfg')) + model.init_weights() + + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + if cfg.data.train.get('pipeline', None) is None: + if is_2dlist(cfg.data.train.datasets): + train_pipeline = cfg.data.train.datasets[0][0].pipeline + else: + train_pipeline = cfg.data.train.datasets[0].pipeline + elif is_2dlist(cfg.data.train.pipeline): + train_pipeline = cfg.data.train.pipeline[0] + else: + train_pipeline = cfg.data.train.pipeline + + if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']: + for dataset in val_dataset['datasets']: + dataset.pipeline = train_pipeline + else: + val_dataset.pipeline = train_pipeline + datasets.append(build_dataset(val_dataset)) + if cfg.get('checkpoint_config', None) is not None: + # save dbnet_det version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmocr_version=__version__ + get_git_hash()[:7], + CLASSES=datasets[0].CLASSES) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + train_detector( + model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + meta=meta) + + +if __name__ == '__main__': + main() -- Gitee From c3e8f6e8115e515da928640a9e089f8d59bd0477 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 9 Mar 2023 02:13:28 +0000 Subject: [PATCH 2/3] Signed-off-by: root modified: dist_train.sh for dbnet pytorch link #I6BFUO modified: dist_train.sh for dbnet pytorch --- cv/ocr/dbnet/pytorch/dist_train.sh | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/cv/ocr/dbnet/pytorch/dist_train.sh b/cv/ocr/dbnet/pytorch/dist_train.sh index 4ef0ea66..727c7431 100755 --- a/cv/ocr/dbnet/pytorch/dist_train.sh +++ b/cv/ocr/dbnet/pytorch/dist_train.sh @@ -1,4 +1,18 @@ #!/usr/bin/env bash +# Copyright (c) 2023, Shanghai 
Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. CONFIG=$1 GPUS=$2 -- Gitee From 099cddab11ef8f25adf50d275112774e3bc29ba8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=B0=B8=E4=B9=90?= Date: Thu, 16 Mar 2023 03:05:32 +0000 Subject: [PATCH 3/3] =?UTF-8?q?Signed-off-by:=20=E5=90=B4=E6=B0=B8?= =?UTF-8?q?=E4=B9=90=20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add dbnet_det for dbnet pytorch link #I6BFUO add dbnet_det for dbnet pytorch --- cv/ocr/dbnet/pytorch/README.md | 12 +- cv/ocr/dbnet/pytorch/dbnet_det/__init__.py | 29 + .../dbnet/pytorch/dbnet_det/apis/__init__.py | 1 + cv/ocr/dbnet/pytorch/dbnet_det/apis/test.py | 206 ++ .../dbnet/pytorch/dbnet_det/core/__init__.py | 10 + .../pytorch/dbnet_det/core/bbox/__init__.py | 28 + .../pytorch/dbnet_det/core/bbox/transforms.py | 270 ++ .../dbnet_det/core/evaluation/__init__.py | 19 + .../core/evaluation/bbox_overlaps.py | 65 + .../dbnet_det/core/evaluation/class_names.py | 332 ++ .../dbnet_det/core/evaluation/eval_hooks.py | 140 + .../dbnet_det/core/evaluation/mean_ap.py | 782 +++++ .../core/evaluation/panoptic_utils.py | 6 + .../dbnet_det/core/evaluation/recall.py | 197 ++ .../pytorch/dbnet_det/core/mask/__init__.py | 9 + .../pytorch/dbnet_det/core/mask/structures.py | 1102 +++++++ .../pytorch/dbnet_det/core/mask/utils.py | 89 + .../pytorch/dbnet_det/core/utils/__init__.py | 13 + .../dbnet_det/core/utils/dist_utils.py | 193 ++ .../pytorch/dbnet_det/core/utils/misc.py | 208 ++ .../dbnet_det/core/visualization/__init__.py | 9 + .../dbnet_det/core/visualization/image.py | 559 ++++ .../dbnet_det/core/visualization/palette.py | 59 + .../pytorch/dbnet_det/datasets/__init__.py | 28 + .../datasets/api_wrappers/__init__.py | 7 + .../datasets/api_wrappers/coco_api.py | 47 + .../api_wrappers/panoptic_evaluation.py | 228 ++ .../pytorch/dbnet_det/datasets/builder.py | 215 ++ .../dbnet/pytorch/dbnet_det/datasets/coco.py | 649 ++++ .../pytorch/dbnet_det/datasets/custom.py | 410 +++ .../dbnet_det/datasets/dataset_wrappers.py | 456 +++ .../dbnet_det/datasets/pipelines/__init__.py | 31 + .../dbnet_det/datasets/pipelines/compose.py | 55 + .../datasets/pipelines/formatting.py | 392 +++ .../dbnet_det/datasets/pipelines/loading.py | 643 ++++ .../datasets/pipelines/test_time_aug.py | 121 + .../datasets/pipelines/transforms.py | 2919 +++++++++++++++++ .../dbnet_det/datasets/samplers/__init__.py | 10 + .../datasets/samplers/class_aware_sampler.py | 176 + .../datasets/samplers/distributed_sampler.py | 54 + .../datasets/samplers/group_sampler.py | 148 + .../datasets/samplers/infinite_sampler.py | 186 ++ .../dbnet/pytorch/dbnet_det/datasets/utils.py | 166 + .../pytorch/dbnet_det/models/__init__.py | 19 + .../dbnet_det/models/backbones/__init__.py | 30 + .../models/backbones/mobilenet_v3.py | 265 ++ .../dbnet_det/models/backbones/resnet.py | 672 ++++ .../dbnet/pytorch/dbnet_det/models/builder.py | 59 + .../dbnet_det/models/detectors/__init__.py | 2 + 
.../dbnet_det/models/detectors/base.py | 360 ++ .../models/detectors/single_stage.py | 171 + .../dbnet_det/models/utils/__init__.py | 34 + .../models/utils/inverted_residual.py | 128 + .../dbnet_det/models/utils/make_divisible.py | 25 + .../dbnet_det/models/utils/res_layer.py | 190 ++ .../dbnet_det/models/utils/se_layer.py | 124 + .../dbnet/pytorch/dbnet_det/utils/__init__.py | 18 + .../dbnet_det/utils/inverted_residual.py | 209 ++ .../dbnet/pytorch/dbnet_det/utils/logger.py | 65 + .../pytorch/dbnet_det/utils/profiling.py | 40 + .../dbnet_det/utils/util_distribution.py | 74 + cv/ocr/dbnet/pytorch/dbnet_det/version.py | 19 + 62 files changed, 13779 insertions(+), 4 deletions(-) create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/apis/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/apis/test.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/transforms.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/bbox_overlaps.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/class_names.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/eval_hooks.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/mean_ap.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/panoptic_utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/recall.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/mask/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/mask/structures.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/mask/utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/utils/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/utils/dist_utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/utils/misc.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/image.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/palette.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/coco_api.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/panoptic_evaluation.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/builder.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/coco.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/custom.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/dataset_wrappers.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/compose.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/formatting.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/loading.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/test_time_aug.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/transforms.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/__init__.py create mode 100755 
cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/class_aware_sampler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/distributed_sampler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/group_sampler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/infinite_sampler.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/datasets/utils.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/mobilenet_v3.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/resnet.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/builder.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/base.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/single_stage.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/utils/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/utils/inverted_residual.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/utils/make_divisible.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/utils/res_layer.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/models/utils/se_layer.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/utils/__init__.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/utils/inverted_residual.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/utils/logger.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/utils/profiling.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/utils/util_distribution.py create mode 100755 cv/ocr/dbnet/pytorch/dbnet_det/version.py diff --git a/cv/ocr/dbnet/pytorch/README.md b/cv/ocr/dbnet/pytorch/README.md index 86641bbb..d92a6cfc 100755 --- a/cv/ocr/dbnet/pytorch/README.md +++ b/cv/ocr/dbnet/pytorch/README.md @@ -24,12 +24,16 @@ mv Challenge4_Test_Task1_GT annotations/test Please [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) download instances_training.json here Please [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) download instances_test.json here +```shell -├── icdar2015 -│   ├── imgs -│   ├── instances_test.json -│   └── instances_training.json +icdar2015/ +├── imgs +│ ├── test +│ └── training +├── instances_test.json +└── instances_training.json +``` ### Build Extension ```shell diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/__init__.py new file mode 100755 index 00000000..6f9c9ce0 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
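+# The version guard below relies on digit_version(), which turns a version string into a
+# comparable list of ints. Worked examples (derived from the parsing logic defined below):
+#   digit_version('1.3.17')   -> [1, 3, 17]
+#   digit_version('1.6.0rc1') -> [1, 6, -1, 1]   # rc releases sort before the final release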
+import dbnet_cv + +from .version import __version__, short_version + + +def digit_version(version_str): + digit_version = [] + for x in version_str.split('.'): + if x.isdigit(): + digit_version.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + digit_version.append(int(patch_version[0]) - 1) + digit_version.append(int(patch_version[1])) + return digit_version + + +dbnet_cv_minimum_version = '1.3.17' +dbnet_cv_maximum_version = '1.6.0' +dbnet_cv_version = digit_version(dbnet_cv.__version__) + + +assert (dbnet_cv_version >= digit_version(dbnet_cv_minimum_version) + and dbnet_cv_version <= digit_version(dbnet_cv_maximum_version)), \ + f'DBNET_CV=={dbnet_cv.__version__} is used but incompatible. ' \ + f'Please install dbnet_cv>={dbnet_cv_minimum_version}, <={dbnet_cv_maximum_version}.' + +__all__ = ['__version__', 'short_version'] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/apis/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/apis/__init__.py new file mode 100755 index 00000000..b9370ea3 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/apis/__init__.py @@ -0,0 +1 @@ +from .test import multi_gpu_test, single_gpu_test \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/apis/test.py b/cv/ocr/dbnet/pytorch/dbnet_det/apis/test.py new file mode 100755 index 00000000..1604818f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/apis/test.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import pickle +import shutil +import tempfile +import time + +import dbnet_cv +import torch +import torch.distributed as dist +from dbnet_cv.image import tensor2imgs +from dbnet_cv.runner import get_dist_info + +from dbnet_det.core import encode_mask_results + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + show_score_thr=0.3): + model.eval() + results = [] + dataset = data_loader.dataset + PALETTE = getattr(dataset, 'PALETTE', None) + prog_bar = dbnet_cv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + + batch_size = len(result) + if show or out_dir: + if batch_size == 1 and isinstance(data['img'][0], torch.Tensor): + img_tensor = data['img'][0] + else: + img_tensor = data['img'][0].data[0] + img_metas = data['img_metas'][0].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + assert len(imgs) == len(img_metas) + + for i, (img, img_meta) in enumerate(zip(imgs, img_metas)): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + + ori_h, ori_w = img_meta['ori_shape'][:-1] + img_show = dbnet_cv.imresize(img_show, (ori_w, ori_h)) + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result[i], + bbox_color=PALETTE, + text_color=PALETTE, + mask_color=PALETTE, + show=show, + out_file=out_file, + score_thr=show_score_thr) + + # encode mask results + if isinstance(result[0], tuple): + result = [(bbox_results, encode_mask_results(mask_results)) + for bbox_results, mask_results in result] + # This logic is only used in panoptic segmentation test. 
+ elif isinstance(result[0], dict) and 'ins_results' in result[0]: + for j in range(len(result)): + bbox_results, mask_results = result[j]['ins_results'] + result[j]['ins_results'] = (bbox_results, + encode_mask_results(mask_results)) + + results.extend(result) + + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): + """Test model with multiple gpus. + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = dbnet_cv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + # encode mask results + if isinstance(result[0], tuple): + result = [(bbox_results, encode_mask_results(mask_results)) + for bbox_results, mask_results in result] + # This logic is only used in panoptic segmentation test. + elif isinstance(result[0], dict) and 'ins_results' in result[0]: + for j in range(len(result)): + bbox_results, mask_results = result[j]['ins_results'] + result[j]['ins_results'] = ( + bbox_results, encode_mask_results(mask_results)) + + results.extend(result) + + if rank == 0: + batch_size = len(result) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + dbnet_cv.mkdir_or_exist('.dist_test') + tmpdir = tempfile.mkdtemp(dir='.dist_test') + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + dbnet_cv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + dbnet_cv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, f'part_{i}.pkl') + part_list.append(dbnet_cv.load(part_file)) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def 
collect_results_gpu(result_part, size): + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_list.append( + pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/__init__.py new file mode 100755 index 00000000..7a0fc45e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# from .anchor import * # noqa: F401, F403 +from .bbox import * # noqa: F401, F403 +# from .data_structures import * # noqa: F401, F403 +from .evaluation import * # noqa: F401, F403 +# from .hook import * # noqa: F401, F403 +from .mask import * # noqa: F401, F403 +# from .optimizers import * # noqa: F401, F403 +# from .post_processing import * # noqa: F401, F403 +from .utils import * # noqa: F401, F403 diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/__init__.py new file mode 100755 index 00000000..677fb895 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/__init__.py @@ -0,0 +1,28 @@ +# # Copyright (c) OpenMMLab. All rights reserved. 
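+# Only the geometric helpers from .transforms are re-exported here; the upstream
+# assigner/sampler/coder machinery stays commented out below. Illustrative check of one
+# helper (the input values are made up for the example):
+#   >>> import torch
+#   >>> from dbnet_det.core.bbox import bbox_xyxy_to_cxcywh
+#   >>> bbox_xyxy_to_cxcywh(torch.tensor([[0., 0., 4., 2.]]))
+#   tensor([[2., 1., 4., 2.]])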
+# from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner, +# MaxIoUAssigner, RegionAssigner) +# from .builder import build_assigner, build_bbox_coder, build_sampler +# from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, DistancePointBBoxCoder, +# PseudoBBoxCoder, TBLRBBoxCoder) +# from .iou_calculators import BboxOverlaps2D, bbox_overlaps +# from .samplers import (BaseSampler, CombinedSampler, +# InstanceBalancedPosSampler, IoUBalancedNegSampler, +# OHEMSampler, PseudoSampler, RandomSampler, +# SamplingResult, ScoreHLRSampler) +from .transforms import (bbox2distance, bbox2result, bbox2roi, + bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping, + bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh, + distance2bbox, find_inside_bboxes, roi2bbox) + +# __all__ = [ +# 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner', +# 'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler', +# 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', +# 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner', +# 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', +# 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance', +# 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', +# 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder', +# 'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy', +# 'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'find_inside_bboxes' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/transforms.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/transforms.py new file mode 100755 index 00000000..9f5d43e7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/bbox/transforms.py @@ -0,0 +1,270 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + + +def find_inside_bboxes(bboxes, img_h, img_w): + """Find bboxes as long as a part of bboxes is inside the image. + + Args: + bboxes (Tensor): Shape (N, 4). + img_h (int): Image height. + img_w (int): Image width. + + Returns: + Tensor: Index of the remaining bboxes. + """ + inside_inds = (bboxes[:, 0] < img_w) & (bboxes[:, 2] > 0) \ + & (bboxes[:, 1] < img_h) & (bboxes[:, 3] > 0) + return inside_inds + + +def bbox_flip(bboxes, img_shape, direction='horizontal'): + """Flip bboxes horizontally or vertically. + + Args: + bboxes (Tensor): Shape (..., 4*k) + img_shape (tuple): Image shape. + direction (str): Flip direction, options are "horizontal", "vertical", + "diagonal". Default: "horizontal" + + Returns: + Tensor: Flipped bboxes. 
+ """ + assert bboxes.shape[-1] % 4 == 0 + assert direction in ['horizontal', 'vertical', 'diagonal'] + flipped = bboxes.clone() + if direction == 'horizontal': + flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4] + flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4] + elif direction == 'vertical': + flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4] + flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4] + else: + flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4] + flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4] + flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4] + flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4] + return flipped + + +def bbox_mapping(bboxes, + img_shape, + scale_factor, + flip, + flip_direction='horizontal'): + """Map bboxes from the original image scale to testing scale.""" + new_bboxes = bboxes * bboxes.new_tensor(scale_factor) + if flip: + new_bboxes = bbox_flip(new_bboxes, img_shape, flip_direction) + return new_bboxes + + +def bbox_mapping_back(bboxes, + img_shape, + scale_factor, + flip, + flip_direction='horizontal'): + """Map bboxes from testing scale to original image scale.""" + new_bboxes = bbox_flip(bboxes, img_shape, + flip_direction) if flip else bboxes + new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor) + return new_bboxes.view(bboxes.shape) + + +def bbox2roi(bbox_list): + """Convert a list of bboxes to roi format. + + Args: + bbox_list (list[Tensor]): a list of bboxes corresponding to a batch + of images. + + Returns: + Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2] + """ + rois_list = [] + for img_id, bboxes in enumerate(bbox_list): + if bboxes.size(0) > 0: + img_inds = bboxes.new_full((bboxes.size(0), 1), img_id) + rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1) + else: + rois = bboxes.new_zeros((0, 5)) + rois_list.append(rois) + rois = torch.cat(rois_list, 0) + return rois + + +def roi2bbox(rois): + """Convert rois to bounding box format. + + Args: + rois (torch.Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + list[torch.Tensor]: Converted boxes of corresponding rois. + """ + bbox_list = [] + img_ids = torch.unique(rois[:, 0].cpu(), sorted=True) + for img_id in img_ids: + inds = (rois[:, 0] == img_id.item()) + bbox = rois[inds, 1:] + bbox_list.append(bbox) + return bbox_list + + +def bbox2result(bboxes, labels, num_classes): + """Convert detection results to a list of numpy arrays. + + Args: + bboxes (torch.Tensor | np.ndarray): shape (n, 5) + labels (torch.Tensor | np.ndarray): shape (n, ) + num_classes (int): class number, including background class + + Returns: + list(ndarray): bbox results of each class + """ + if bboxes.shape[0] == 0: + return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)] + else: + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + return [bboxes[labels == i, :] for i in range(num_classes)] + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (B, N, 2) or (N, 2). + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4) + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). 
If priors shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]] + and the length of max_shape should also be B. + + Returns: + Tensor: Boxes with shape (N, 4) or (B, N, 4) + """ + + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + + bboxes = torch.stack([x1, y1, x2, y2], -1) + + if max_shape is not None: + if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export(): + # speed up + bboxes[:, 0::2].clamp_(min=0, max=max_shape[1]) + bboxes[:, 1::2].clamp_(min=0, max=max_shape[0]) + return bboxes + + # clip bboxes with dynamic `min` and `max` for onnx + if torch.onnx.is_in_onnx_export(): + from dbnet_det.core.export import dynamic_clip_for_onnx + x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape) + bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + return bboxes + if not isinstance(max_shape, torch.Tensor): + max_shape = x1.new_tensor(max_shape) + max_shape = max_shape[..., :2].type_as(x1) + if max_shape.ndim == 2: + assert bboxes.ndim == 3 + assert max_shape.size(0) == bboxes.size(0) + + min_xy = x1.new_tensor(0) + max_xy = torch.cat([max_shape, max_shape], + dim=-1).flip(-1).unsqueeze(-2) + bboxes = torch.where(bboxes < min_xy, min_xy, bboxes) + bboxes = torch.where(bboxes > max_xy, max_xy, bboxes) + + return bboxes + + +def bbox2distance(points, bbox, max_dis=None, eps=0.1): + """Decode bounding box based on distances. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + bbox (Tensor): Shape (n, 4), "xyxy" format + max_dis (float): Upper bound of the distance. + eps (float): a small value to ensure target < max_dis, instead <= + + Returns: + Tensor: Decoded distances. + """ + left = points[:, 0] - bbox[:, 0] + top = points[:, 1] - bbox[:, 1] + right = bbox[:, 2] - points[:, 0] + bottom = bbox[:, 3] - points[:, 1] + if max_dis is not None: + left = left.clamp(min=0, max=max_dis - eps) + top = top.clamp(min=0, max=max_dis - eps) + right = right.clamp(min=0, max=max_dis - eps) + bottom = bottom.clamp(min=0, max=max_dis - eps) + return torch.stack([left, top, right, bottom], -1) + + +def bbox_rescale(bboxes, scale_factor=1.0): + """Rescale bounding box w.r.t. scale_factor. + + Args: + bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois + scale_factor (float): rescale factor + + Returns: + Tensor: Rescaled bboxes. + """ + if bboxes.size(1) == 5: + bboxes_ = bboxes[:, 1:] + inds_ = bboxes[:, 0] + else: + bboxes_ = bboxes + cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5 + cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5 + w = bboxes_[:, 2] - bboxes_[:, 0] + h = bboxes_[:, 3] - bboxes_[:, 1] + w = w * scale_factor + h = h * scale_factor + x1 = cx - 0.5 * w + x2 = cx + 0.5 * w + y1 = cy - 0.5 * h + y2 = cy + 0.5 * h + if bboxes.size(1) == 5: + rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1) + else: + rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + return rescaled_bboxes + + +def bbox_cxcywh_to_xyxy(bbox): + """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2). + + Args: + bbox (Tensor): Shape (n, 4) for bboxes. + + Returns: + Tensor: Converted bboxes. + """ + cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)] + return torch.cat(bbox_new, dim=-1) + + +def bbox_xyxy_to_cxcywh(bbox): + """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h). + + Args: + bbox (Tensor): Shape (n, 4) for bboxes. 
+ + Returns: + Tensor: Converted bboxes. + """ + x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)] + return torch.cat(bbox_new, dim=-1) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/__init__.py new file mode 100755 index 00000000..0dbb2ee6 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .class_names import (cityscapes_classes, coco_classes, dataset_aliases, + get_classes, imagenet_det_classes, + imagenet_vid_classes, oid_challenge_classes, + oid_v6_classes, voc_classes) +from .eval_hooks import DistEvalHook, EvalHook +from .mean_ap import average_precision, eval_map, print_map_summary +from .panoptic_utils import INSTANCE_OFFSET +from .recall import (eval_recalls, plot_iou_recall, plot_num_recall, + print_recall_summary) + +# __all__ = [ +# 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes', +# 'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes', +# 'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map', +# 'print_map_summary', 'eval_recalls', 'print_recall_summary', +# 'plot_num_recall', 'plot_iou_recall', 'oid_v6_classes', +# 'oid_challenge_classes', 'INSTANCE_OFFSET' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/bbox_overlaps.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/bbox_overlaps.py new file mode 100755 index 00000000..d83eece7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/bbox_overlaps.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + + +def bbox_overlaps(bboxes1, + bboxes2, + mode='iou', + eps=1e-6, + use_legacy_coordinate=False): + """Calculate the ious between each bbox of bboxes1 and bboxes2. + + Args: + bboxes1 (ndarray): Shape (n, 4) + bboxes2 (ndarray): Shape (k, 4) + mode (str): IOU (intersection over union) or IOF (intersection + over foreground) + use_legacy_coordinate (bool): Whether to use coordinate system in + dbnet_det v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Note when function is used in `VOCDataset`, it should be + True to align with the official implementation + `http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar` + Default: False. + + Returns: + ious (ndarray): Shape (n, k) + """ + + assert mode in ['iou', 'iof'] + if not use_legacy_coordinate: + extra_length = 0. + else: + extra_length = 1. 
+ bboxes1 = bboxes1.astype(np.float32) + bboxes2 = bboxes2.astype(np.float32) + rows = bboxes1.shape[0] + cols = bboxes2.shape[0] + ious = np.zeros((rows, cols), dtype=np.float32) + if rows * cols == 0: + return ious + exchange = False + if bboxes1.shape[0] > bboxes2.shape[0]: + bboxes1, bboxes2 = bboxes2, bboxes1 + ious = np.zeros((cols, rows), dtype=np.float32) + exchange = True + area1 = (bboxes1[:, 2] - bboxes1[:, 0] + extra_length) * ( + bboxes1[:, 3] - bboxes1[:, 1] + extra_length) + area2 = (bboxes2[:, 2] - bboxes2[:, 0] + extra_length) * ( + bboxes2[:, 3] - bboxes2[:, 1] + extra_length) + for i in range(bboxes1.shape[0]): + x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) + y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) + x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) + y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) + overlap = np.maximum(x_end - x_start + extra_length, 0) * np.maximum( + y_end - y_start + extra_length, 0) + if mode == 'iou': + union = area1[i] + area2 - overlap + else: + union = area1[i] if not exchange else area2 + union = np.maximum(union, eps) + ious[i, :] = overlap / union + if exchange: + ious = ious.T + return ious diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/class_names.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/class_names.py new file mode 100755 index 00000000..444ce82f --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/class_names.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dbnet_cv + + +def wider_face_classes(): + return ['face'] + + +def voc_classes(): + return [ + 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', + 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', + 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' + ] + + +def imagenet_det_classes(): + return [ + 'accordion', 'airplane', 'ant', 'antelope', 'apple', 'armadillo', + 'artichoke', 'axe', 'baby_bed', 'backpack', 'bagel', 'balance_beam', + 'banana', 'band_aid', 'banjo', 'baseball', 'basketball', 'bathing_cap', + 'beaker', 'bear', 'bee', 'bell_pepper', 'bench', 'bicycle', 'binder', + 'bird', 'bookshelf', 'bow_tie', 'bow', 'bowl', 'brassiere', 'burrito', + 'bus', 'butterfly', 'camel', 'can_opener', 'car', 'cart', 'cattle', + 'cello', 'centipede', 'chain_saw', 'chair', 'chime', 'cocktail_shaker', + 'coffee_maker', 'computer_keyboard', 'computer_mouse', 'corkscrew', + 'cream', 'croquet_ball', 'crutch', 'cucumber', 'cup_or_mug', 'diaper', + 'digital_clock', 'dishwasher', 'dog', 'domestic_cat', 'dragonfly', + 'drum', 'dumbbell', 'electric_fan', 'elephant', 'face_powder', 'fig', + 'filing_cabinet', 'flower_pot', 'flute', 'fox', 'french_horn', 'frog', + 'frying_pan', 'giant_panda', 'goldfish', 'golf_ball', 'golfcart', + 'guacamole', 'guitar', 'hair_dryer', 'hair_spray', 'hamburger', + 'hammer', 'hamster', 'harmonica', 'harp', 'hat_with_a_wide_brim', + 'head_cabbage', 'helmet', 'hippopotamus', 'horizontal_bar', 'horse', + 'hotdog', 'iPod', 'isopod', 'jellyfish', 'koala_bear', 'ladle', + 'ladybug', 'lamp', 'laptop', 'lemon', 'lion', 'lipstick', 'lizard', + 'lobster', 'maillot', 'maraca', 'microphone', 'microwave', 'milk_can', + 'miniskirt', 'monkey', 'motorcycle', 'mushroom', 'nail', 'neck_brace', + 'oboe', 'orange', 'otter', 'pencil_box', 'pencil_sharpener', 'perfume', + 'person', 'piano', 'pineapple', 'ping-pong_ball', 'pitcher', 'pizza', + 'plastic_bag', 'plate_rack', 'pomegranate', 'popsicle', 'porcupine', + 'power_drill', 'pretzel', 'printer', 'puck', 'punching_bag', 
'purse', + 'rabbit', 'racket', 'ray', 'red_panda', 'refrigerator', + 'remote_control', 'rubber_eraser', 'rugby_ball', 'ruler', + 'salt_or_pepper_shaker', 'saxophone', 'scorpion', 'screwdriver', + 'seal', 'sheep', 'ski', 'skunk', 'snail', 'snake', 'snowmobile', + 'snowplow', 'soap_dispenser', 'soccer_ball', 'sofa', 'spatula', + 'squirrel', 'starfish', 'stethoscope', 'stove', 'strainer', + 'strawberry', 'stretcher', 'sunglasses', 'swimming_trunks', 'swine', + 'syringe', 'table', 'tape_player', 'tennis_ball', 'tick', 'tie', + 'tiger', 'toaster', 'traffic_light', 'train', 'trombone', 'trumpet', + 'turtle', 'tv_or_monitor', 'unicycle', 'vacuum', 'violin', + 'volleyball', 'waffle_iron', 'washer', 'water_bottle', 'watercraft', + 'whale', 'wine_bottle', 'zebra' + ] + + +def imagenet_vid_classes(): + return [ + 'airplane', 'antelope', 'bear', 'bicycle', 'bird', 'bus', 'car', + 'cattle', 'dog', 'domestic_cat', 'elephant', 'fox', 'giant_panda', + 'hamster', 'horse', 'lion', 'lizard', 'monkey', 'motorcycle', 'rabbit', + 'red_panda', 'sheep', 'snake', 'squirrel', 'tiger', 'train', 'turtle', + 'watercraft', 'whale', 'zebra' + ] + + +def coco_classes(): + return [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', + 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard', + 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush' + ] + + +def cityscapes_classes(): + return [ + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def oid_challenge_classes(): + return [ + 'Footwear', 'Jeans', 'House', 'Tree', 'Woman', 'Man', 'Land vehicle', + 'Person', 'Wheel', 'Bus', 'Human face', 'Bird', 'Dress', 'Girl', + 'Vehicle', 'Building', 'Cat', 'Car', 'Belt', 'Elephant', 'Dessert', + 'Butterfly', 'Train', 'Guitar', 'Poster', 'Book', 'Boy', 'Bee', + 'Flower', 'Window', 'Hat', 'Human head', 'Dog', 'Human arm', 'Drink', + 'Human mouth', 'Human hair', 'Human nose', 'Human hand', 'Table', + 'Marine invertebrates', 'Fish', 'Sculpture', 'Rose', 'Street light', + 'Glasses', 'Fountain', 'Skyscraper', 'Swimwear', 'Brassiere', 'Drum', + 'Duck', 'Countertop', 'Furniture', 'Ball', 'Human leg', 'Boat', + 'Balloon', 'Bicycle helmet', 'Goggles', 'Door', 'Human eye', 'Shirt', + 'Toy', 'Teddy bear', 'Pasta', 'Tomato', 'Human ear', + 'Vehicle registration plate', 'Microphone', 'Musical keyboard', + 'Tower', 'Houseplant', 'Flowerpot', 'Fruit', 'Vegetable', + 'Musical instrument', 'Suit', 'Motorcycle', 'Bagel', 'French fries', + 'Hamburger', 'Chair', 'Salt and pepper shakers', 'Snail', 'Airplane', + 'Horse', 'Laptop', 'Computer keyboard', 'Football helmet', 'Cocktail', + 'Juice', 'Tie', 'Computer monitor', 'Human beard', 'Bottle', + 'Saxophone', 'Lemon', 'Mouse', 'Sock', 'Cowboy hat', 'Sun hat', + 'Football', 'Porch', 'Sunglasses', 'Lobster', 'Crab', 'Picture frame', + 'Van', 'Crocodile', 
'Surfboard', 'Shorts', 'Helicopter', 'Helmet', + 'Sports uniform', 'Taxi', 'Swan', 'Goose', 'Coat', 'Jacket', 'Handbag', + 'Flag', 'Skateboard', 'Television', 'Tire', 'Spoon', 'Palm tree', + 'Stairs', 'Salad', 'Castle', 'Oven', 'Microwave oven', 'Wine', + 'Ceiling fan', 'Mechanical fan', 'Cattle', 'Truck', 'Box', 'Ambulance', + 'Desk', 'Wine glass', 'Reptile', 'Tank', 'Traffic light', 'Billboard', + 'Tent', 'Insect', 'Spider', 'Treadmill', 'Cupboard', 'Shelf', + 'Seat belt', 'Human foot', 'Bicycle', 'Bicycle wheel', 'Couch', + 'Bookcase', 'Fedora', 'Backpack', 'Bench', 'Oyster', + 'Moths and butterflies', 'Lavender', 'Waffle', 'Fork', 'Animal', + 'Accordion', 'Mobile phone', 'Plate', 'Coffee cup', 'Saucer', + 'Platter', 'Dagger', 'Knife', 'Bull', 'Tortoise', 'Sea turtle', 'Deer', + 'Weapon', 'Apple', 'Ski', 'Taco', 'Traffic sign', 'Beer', 'Necklace', + 'Sunflower', 'Piano', 'Organ', 'Harpsichord', 'Bed', 'Cabinetry', + 'Nightstand', 'Curtain', 'Chest of drawers', 'Drawer', 'Parrot', + 'Sandal', 'High heels', 'Tableware', 'Cart', 'Mushroom', 'Kite', + 'Missile', 'Seafood', 'Camera', 'Paper towel', 'Toilet paper', + 'Sombrero', 'Radish', 'Lighthouse', 'Segway', 'Pig', 'Watercraft', + 'Golf cart', 'studio couch', 'Dolphin', 'Whale', 'Earrings', 'Otter', + 'Sea lion', 'Whiteboard', 'Monkey', 'Gondola', 'Zebra', + 'Baseball glove', 'Scarf', 'Adhesive tape', 'Trousers', 'Scoreboard', + 'Lily', 'Carnivore', 'Power plugs and sockets', 'Office building', + 'Sandwich', 'Swimming pool', 'Headphones', 'Tin can', 'Crown', 'Doll', + 'Cake', 'Frog', 'Beetle', 'Ant', 'Gas stove', 'Canoe', 'Falcon', + 'Blue jay', 'Egg', 'Fire hydrant', 'Raccoon', 'Muffin', 'Wall clock', + 'Coffee', 'Mug', 'Tea', 'Bear', 'Waste container', 'Home appliance', + 'Candle', 'Lion', 'Mirror', 'Starfish', 'Marine mammal', 'Wheelchair', + 'Umbrella', 'Alpaca', 'Violin', 'Cello', 'Brown bear', 'Canary', 'Bat', + 'Ruler', 'Plastic bag', 'Penguin', 'Watermelon', 'Harbor seal', 'Pen', + 'Pumpkin', 'Harp', 'Kitchen appliance', 'Roller skates', 'Bust', + 'Coffee table', 'Tennis ball', 'Tennis racket', 'Ladder', 'Boot', + 'Bowl', 'Stop sign', 'Volleyball', 'Eagle', 'Paddle', 'Chicken', + 'Skull', 'Lamp', 'Beehive', 'Maple', 'Sink', 'Goldfish', 'Tripod', + 'Coconut', 'Bidet', 'Tap', 'Bathroom cabinet', 'Toilet', + 'Filing cabinet', 'Pretzel', 'Table tennis racket', 'Bronze sculpture', + 'Rocket', 'Mouse', 'Hamster', 'Lizard', 'Lifejacket', 'Goat', + 'Washing machine', 'Trumpet', 'Horn', 'Trombone', 'Sheep', + 'Tablet computer', 'Pillow', 'Kitchen & dining room table', + 'Parachute', 'Raven', 'Glove', 'Loveseat', 'Christmas tree', + 'Shellfish', 'Rifle', 'Shotgun', 'Sushi', 'Sparrow', 'Bread', + 'Toaster', 'Watch', 'Asparagus', 'Artichoke', 'Suitcase', 'Antelope', + 'Broccoli', 'Ice cream', 'Racket', 'Banana', 'Cookie', 'Cucumber', + 'Dragonfly', 'Lynx', 'Caterpillar', 'Light bulb', 'Office supplies', + 'Miniskirt', 'Skirt', 'Fireplace', 'Potato', 'Light switch', + 'Croissant', 'Cabbage', 'Ladybug', 'Handgun', 'Luggage and bags', + 'Window blind', 'Snowboard', 'Baseball bat', 'Digital clock', + 'Serving tray', 'Infant bed', 'Sofa bed', 'Guacamole', 'Fox', 'Pizza', + 'Snowplow', 'Jet ski', 'Refrigerator', 'Lantern', 'Convenience store', + 'Sword', 'Rugby ball', 'Owl', 'Ostrich', 'Pancake', 'Strawberry', + 'Carrot', 'Tart', 'Dice', 'Turkey', 'Rabbit', 'Invertebrate', 'Vase', + 'Stool', 'Swim cap', 'Shower', 'Clock', 'Jellyfish', 'Aircraft', + 'Chopsticks', 'Orange', 'Snake', 'Sewing machine', 'Kangaroo', 'Mixer', + 'Food processor', 'Shrimp', 
'Towel', 'Porcupine', 'Jaguar', 'Cannon', + 'Limousine', 'Mule', 'Squirrel', 'Kitchen knife', 'Tiara', 'Tiger', + 'Bow and arrow', 'Candy', 'Rhinoceros', 'Shark', 'Cricket ball', + 'Doughnut', 'Plumbing fixture', 'Camel', 'Polar bear', 'Coin', + 'Printer', 'Blender', 'Giraffe', 'Billiard table', 'Kettle', + 'Dinosaur', 'Pineapple', 'Zucchini', 'Jug', 'Barge', 'Teapot', + 'Golf ball', 'Binoculars', 'Scissors', 'Hot dog', 'Door handle', + 'Seahorse', 'Bathtub', 'Leopard', 'Centipede', 'Grapefruit', 'Snowman', + 'Cheetah', 'Alarm clock', 'Grape', 'Wrench', 'Wok', 'Bell pepper', + 'Cake stand', 'Barrel', 'Woodpecker', 'Flute', 'Corded phone', + 'Willow', 'Punching bag', 'Pomegranate', 'Telephone', 'Pear', + 'Common fig', 'Bench', 'Wood-burning stove', 'Burrito', 'Nail', + 'Turtle', 'Submarine sandwich', 'Drinking straw', 'Peach', 'Popcorn', + 'Frying pan', 'Picnic basket', 'Honeycomb', 'Envelope', 'Mango', + 'Cutting board', 'Pitcher', 'Stationary bicycle', 'Dumbbell', + 'Personal care', 'Dog bed', 'Snowmobile', 'Oboe', 'Briefcase', + 'Squash', 'Tick', 'Slow cooker', 'Coffeemaker', 'Measuring cup', + 'Crutch', 'Stretcher', 'Screwdriver', 'Flashlight', 'Spatula', + 'Pressure cooker', 'Ring binder', 'Beaker', 'Torch', 'Winter melon' + ] + + +def oid_v6_classes(): + return [ + 'Tortoise', 'Container', 'Magpie', 'Sea turtle', 'Football', + 'Ambulance', 'Ladder', 'Toothbrush', 'Syringe', 'Sink', 'Toy', + 'Organ (Musical Instrument)', 'Cassette deck', 'Apple', 'Human eye', + 'Cosmetics', 'Paddle', 'Snowman', 'Beer', 'Chopsticks', 'Human beard', + 'Bird', 'Parking meter', 'Traffic light', 'Croissant', 'Cucumber', + 'Radish', 'Towel', 'Doll', 'Skull', 'Washing machine', 'Glove', 'Tick', + 'Belt', 'Sunglasses', 'Banjo', 'Cart', 'Ball', 'Backpack', 'Bicycle', + 'Home appliance', 'Centipede', 'Boat', 'Surfboard', 'Boot', + 'Headphones', 'Hot dog', 'Shorts', 'Fast food', 'Bus', 'Boy', + 'Screwdriver', 'Bicycle wheel', 'Barge', 'Laptop', 'Miniskirt', + 'Drill (Tool)', 'Dress', 'Bear', 'Waffle', 'Pancake', 'Brown bear', + 'Woodpecker', 'Blue jay', 'Pretzel', 'Bagel', 'Tower', 'Teapot', + 'Person', 'Bow and arrow', 'Swimwear', 'Beehive', 'Brassiere', 'Bee', + 'Bat (Animal)', 'Starfish', 'Popcorn', 'Burrito', 'Chainsaw', + 'Balloon', 'Wrench', 'Tent', 'Vehicle registration plate', 'Lantern', + 'Toaster', 'Flashlight', 'Billboard', 'Tiara', 'Limousine', 'Necklace', + 'Carnivore', 'Scissors', 'Stairs', 'Computer keyboard', 'Printer', + 'Traffic sign', 'Chair', 'Shirt', 'Poster', 'Cheese', 'Sock', + 'Fire hydrant', 'Land vehicle', 'Earrings', 'Tie', 'Watercraft', + 'Cabinetry', 'Suitcase', 'Muffin', 'Bidet', 'Snack', 'Snowmobile', + 'Clock', 'Medical equipment', 'Cattle', 'Cello', 'Jet ski', 'Camel', + 'Coat', 'Suit', 'Desk', 'Cat', 'Bronze sculpture', 'Juice', 'Gondola', + 'Beetle', 'Cannon', 'Computer mouse', 'Cookie', 'Office building', + 'Fountain', 'Coin', 'Calculator', 'Cocktail', 'Computer monitor', + 'Box', 'Stapler', 'Christmas tree', 'Cowboy hat', 'Hiking equipment', + 'Studio couch', 'Drum', 'Dessert', 'Wine rack', 'Drink', 'Zucchini', + 'Ladle', 'Human mouth', 'Dairy Product', 'Dice', 'Oven', 'Dinosaur', + 'Ratchet (Device)', 'Couch', 'Cricket ball', 'Winter melon', 'Spatula', + 'Whiteboard', 'Pencil sharpener', 'Door', 'Hat', 'Shower', 'Eraser', + 'Fedora', 'Guacamole', 'Dagger', 'Scarf', 'Dolphin', 'Sombrero', + 'Tin can', 'Mug', 'Tap', 'Harbor seal', 'Stretcher', 'Can opener', + 'Goggles', 'Human body', 'Roller skates', 'Coffee cup', + 'Cutting board', 'Blender', 'Plumbing fixture', 'Stop sign', + 
'Office supplies', 'Volleyball (Ball)', 'Vase', 'Slow cooker', + 'Wardrobe', 'Coffee', 'Whisk', 'Paper towel', 'Personal care', 'Food', + 'Sun hat', 'Tree house', 'Flying disc', 'Skirt', 'Gas stove', + 'Salt and pepper shakers', 'Mechanical fan', 'Face powder', 'Fax', + 'Fruit', 'French fries', 'Nightstand', 'Barrel', 'Kite', 'Tart', + 'Treadmill', 'Fox', 'Flag', 'French horn', 'Window blind', + 'Human foot', 'Golf cart', 'Jacket', 'Egg (Food)', 'Street light', + 'Guitar', 'Pillow', 'Human leg', 'Isopod', 'Grape', 'Human ear', + 'Power plugs and sockets', 'Panda', 'Giraffe', 'Woman', 'Door handle', + 'Rhinoceros', 'Bathtub', 'Goldfish', 'Houseplant', 'Goat', + 'Baseball bat', 'Baseball glove', 'Mixing bowl', + 'Marine invertebrates', 'Kitchen utensil', 'Light switch', 'House', + 'Horse', 'Stationary bicycle', 'Hammer', 'Ceiling fan', 'Sofa bed', + 'Adhesive tape', 'Harp', 'Sandal', 'Bicycle helmet', 'Saucer', + 'Harpsichord', 'Human hair', 'Heater', 'Harmonica', 'Hamster', + 'Curtain', 'Bed', 'Kettle', 'Fireplace', 'Scale', 'Drinking straw', + 'Insect', 'Hair dryer', 'Kitchenware', 'Indoor rower', 'Invertebrate', + 'Food processor', 'Bookcase', 'Refrigerator', 'Wood-burning stove', + 'Punching bag', 'Common fig', 'Cocktail shaker', 'Jaguar (Animal)', + 'Golf ball', 'Fashion accessory', 'Alarm clock', 'Filing cabinet', + 'Artichoke', 'Table', 'Tableware', 'Kangaroo', 'Koala', 'Knife', + 'Bottle', 'Bottle opener', 'Lynx', 'Lavender (Plant)', 'Lighthouse', + 'Dumbbell', 'Human head', 'Bowl', 'Humidifier', 'Porch', 'Lizard', + 'Billiard table', 'Mammal', 'Mouse', 'Motorcycle', + 'Musical instrument', 'Swim cap', 'Frying pan', 'Snowplow', + 'Bathroom cabinet', 'Missile', 'Bust', 'Man', 'Waffle iron', 'Milk', + 'Ring binder', 'Plate', 'Mobile phone', 'Baked goods', 'Mushroom', + 'Crutch', 'Pitcher (Container)', 'Mirror', 'Personal flotation device', + 'Table tennis racket', 'Pencil case', 'Musical keyboard', 'Scoreboard', + 'Briefcase', 'Kitchen knife', 'Nail (Construction)', 'Tennis ball', + 'Plastic bag', 'Oboe', 'Chest of drawers', 'Ostrich', 'Piano', 'Girl', + 'Plant', 'Potato', 'Hair spray', 'Sports equipment', 'Pasta', + 'Penguin', 'Pumpkin', 'Pear', 'Infant bed', 'Polar bear', 'Mixer', + 'Cupboard', 'Jacuzzi', 'Pizza', 'Digital clock', 'Pig', 'Reptile', + 'Rifle', 'Lipstick', 'Skateboard', 'Raven', 'High heels', 'Red panda', + 'Rose', 'Rabbit', 'Sculpture', 'Saxophone', 'Shotgun', 'Seafood', + 'Submarine sandwich', 'Snowboard', 'Sword', 'Picture frame', 'Sushi', + 'Loveseat', 'Ski', 'Squirrel', 'Tripod', 'Stethoscope', 'Submarine', + 'Scorpion', 'Segway', 'Training bench', 'Snake', 'Coffee table', + 'Skyscraper', 'Sheep', 'Television', 'Trombone', 'Tea', 'Tank', 'Taco', + 'Telephone', 'Torch', 'Tiger', 'Strawberry', 'Trumpet', 'Tree', + 'Tomato', 'Train', 'Tool', 'Picnic basket', 'Cooking spray', + 'Trousers', 'Bowling equipment', 'Football helmet', 'Truck', + 'Measuring cup', 'Coffeemaker', 'Violin', 'Vehicle', 'Handbag', + 'Paper cutter', 'Wine', 'Weapon', 'Wheel', 'Worm', 'Wok', 'Whale', + 'Zebra', 'Auto part', 'Jug', 'Pizza cutter', 'Cream', 'Monkey', 'Lion', + 'Bread', 'Platter', 'Chicken', 'Eagle', 'Helicopter', 'Owl', 'Duck', + 'Turtle', 'Hippopotamus', 'Crocodile', 'Toilet', 'Toilet paper', + 'Squid', 'Clothing', 'Footwear', 'Lemon', 'Spider', 'Deer', 'Frog', + 'Banana', 'Rocket', 'Wine glass', 'Countertop', 'Tablet computer', + 'Waste container', 'Swimming pool', 'Dog', 'Book', 'Elephant', 'Shark', + 'Candle', 'Leopard', 'Axe', 'Hand dryer', 'Soap dispenser', + 'Porcupine', 
'Flower', 'Canary', 'Cheetah', 'Palm tree', 'Hamburger', + 'Maple', 'Building', 'Fish', 'Lobster', 'Garden Asparagus', + 'Furniture', 'Hedgehog', 'Airplane', 'Spoon', 'Otter', 'Bull', + 'Oyster', 'Horizontal bar', 'Convenience store', 'Bomb', 'Bench', + 'Ice cream', 'Caterpillar', 'Butterfly', 'Parachute', 'Orange', + 'Antelope', 'Beaker', 'Moths and butterflies', 'Window', 'Closet', + 'Castle', 'Jellyfish', 'Goose', 'Mule', 'Swan', 'Peach', 'Coconut', + 'Seat belt', 'Raccoon', 'Chisel', 'Fork', 'Lamp', 'Camera', + 'Squash (Plant)', 'Racket', 'Human face', 'Human arm', 'Vegetable', + 'Diaper', 'Unicycle', 'Falcon', 'Chime', 'Snail', 'Shellfish', + 'Cabbage', 'Carrot', 'Mango', 'Jeans', 'Flowerpot', 'Pineapple', + 'Drawer', 'Stool', 'Envelope', 'Cake', 'Dragonfly', 'Common sunflower', + 'Microwave oven', 'Honeycomb', 'Marine mammal', 'Sea lion', 'Ladybug', + 'Shelf', 'Watch', 'Candy', 'Salad', 'Parrot', 'Handgun', 'Sparrow', + 'Van', 'Grinder', 'Spice rack', 'Light bulb', 'Corded phone', + 'Sports uniform', 'Tennis racket', 'Wall clock', 'Serving tray', + 'Kitchen & dining room table', 'Dog bed', 'Cake stand', + 'Cat furniture', 'Bathroom accessory', 'Facial tissue holder', + 'Pressure cooker', 'Kitchen appliance', 'Tire', 'Ruler', + 'Luggage and bags', 'Microphone', 'Broccoli', 'Umbrella', 'Pastry', + 'Grapefruit', 'Band-aid', 'Animal', 'Bell pepper', 'Turkey', 'Lily', + 'Pomegranate', 'Doughnut', 'Glasses', 'Human nose', 'Pen', 'Ant', + 'Car', 'Aircraft', 'Human hand', 'Skunk', 'Teddy bear', 'Watermelon', + 'Cantaloupe', 'Dishwasher', 'Flute', 'Balance beam', 'Sandwich', + 'Shrimp', 'Sewing machine', 'Binoculars', 'Rays and skates', 'Ipod', + 'Accordion', 'Willow', 'Crab', 'Crown', 'Seahorse', 'Perfume', + 'Alpaca', 'Taxi', 'Canoe', 'Remote control', 'Wheelchair', + 'Rugby ball', 'Armadillo', 'Maracas', 'Helmet' + ] + + +dataset_aliases = { + 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'], + 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'], + 'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'], + 'coco': ['coco', 'mscoco', 'ms_coco'], + 'wider_face': ['WIDERFaceDataset', 'wider_face', 'WIDERFace'], + 'cityscapes': ['cityscapes'], + 'oid_challenge': ['oid_challenge', 'openimages_challenge'], + 'oid_v6': ['oid_v6', 'openimages_v6'] +} + + +def get_classes(dataset): + """Get class names of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if dbnet_cv.is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_classes()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/eval_hooks.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/eval_hooks.py new file mode 100755 index 00000000..5ba682d7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/eval_hooks.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
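+# --- Illustrative note on the class-name helpers defined in class_names.py
+# above (added for clarity; not part of the upstream eval_hooks.py file):
+# get_classes() resolves any alias listed in dataset_aliases to the matching
+# *_classes() function. A minimal sketch, assuming the module path
+# dbnet_det.core.evaluation.class_names:
+#
+#   >>> from dbnet_det.core.evaluation.class_names import get_classes
+#   >>> names = get_classes('openimages_v6')   # alias of 'oid_v6'
+#   >>> names[0]
+#   'Tortoise'
+#   >>> get_classes(42)   # any non-str input raises TypeError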
+import bisect +import os.path as osp + +import dbnet_cv +import torch.distributed as dist +from dbnet_cv.runner import DistEvalHook as BaseDistEvalHook +from dbnet_cv.runner import EvalHook as BaseEvalHook +from torch.nn.modules.batchnorm import _BatchNorm + + +def _calc_dynamic_intervals(start_interval, dynamic_interval_list): + assert dbnet_cv.is_list_of(dynamic_interval_list, tuple) + + dynamic_milestones = [0] + dynamic_milestones.extend( + [dynamic_interval[0] for dynamic_interval in dynamic_interval_list]) + dynamic_intervals = [start_interval] + dynamic_intervals.extend( + [dynamic_interval[1] for dynamic_interval in dynamic_interval_list]) + return dynamic_milestones, dynamic_intervals + + +class EvalHook(BaseEvalHook): + + def __init__(self, *args, dynamic_intervals=None, **kwargs): + super(EvalHook, self).__init__(*args, **kwargs) + self.latest_results = None + + self.use_dynamic_intervals = dynamic_intervals is not None + if self.use_dynamic_intervals: + self.dynamic_milestones, self.dynamic_intervals = \ + _calc_dynamic_intervals(self.interval, dynamic_intervals) + + def _decide_interval(self, runner): + if self.use_dynamic_intervals: + progress = runner.epoch if self.by_epoch else runner.iter + step = bisect.bisect(self.dynamic_milestones, (progress + 1)) + # Dynamically modify the evaluation interval + self.interval = self.dynamic_intervals[step - 1] + + def before_train_epoch(self, runner): + """Evaluate the model only at the start of training by epoch.""" + self._decide_interval(runner) + super().before_train_epoch(runner) + + def before_train_iter(self, runner): + self._decide_interval(runner) + super().before_train_iter(runner) + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + if not self._should_evaluate(runner): + return + + from dbnet_det.apis import single_gpu_test + + # Changed results to self.results so that dbnet_detWandbHook can access + # the evaluation results and log them to wandb. + results = single_gpu_test(runner.model, self.dataloader, show=False) + self.latest_results = results + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + # the key_score may be `None` so it needs to skip the action to save + # the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) + + +# Note: Considering that DBNET_CV's EvalHook updated its interface in V1.3.16, +# in order to avoid strong version dependency, we did not directly +# inherit EvalHook but BaseDistEvalHook. 
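+# Added note (illustrative values): both hooks accept `dynamic_intervals` as a
+# list of (milestone, new_interval) tuples. For example, with interval=10 and
+# dynamic_intervals=[(100, 5), (300, 1)], _calc_dynamic_intervals returns
+# milestones [0, 100, 300] and intervals [10, 5, 1], so evaluation runs every
+# 10 epochs (or iterations when by_epoch is False) until the training progress
+# reaches 100, every 5 from then on, and every single epoch once it reaches 300.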
+class DistEvalHook(BaseDistEvalHook): + + def __init__(self, *args, dynamic_intervals=None, **kwargs): + super(DistEvalHook, self).__init__(*args, **kwargs) + self.latest_results = None + + self.use_dynamic_intervals = dynamic_intervals is not None + if self.use_dynamic_intervals: + self.dynamic_milestones, self.dynamic_intervals = \ + _calc_dynamic_intervals(self.interval, dynamic_intervals) + + def _decide_interval(self, runner): + if self.use_dynamic_intervals: + progress = runner.epoch if self.by_epoch else runner.iter + step = bisect.bisect(self.dynamic_milestones, (progress + 1)) + # Dynamically modify the evaluation interval + self.interval = self.dynamic_intervals[step - 1] + + def before_train_epoch(self, runner): + """Evaluate the model only at the start of training by epoch.""" + self._decide_interval(runner) + super().before_train_epoch(runner) + + def before_train_iter(self, runner): + self._decide_interval(runner) + super().before_train_iter(runner) + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + # Synchronization of BatchNorm's buffer (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause the inconsistent performance of models in + # different ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this. + if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, + _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + if not self._should_evaluate(runner): + return + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + + from dbnet_det.apis import multi_gpu_test + + # Changed results to self.results so that dbnet_detWandbHook can access + # the evaluation results and log them to wandb. + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect) + self.latest_results = results + if runner.rank == 0: + print('\n') + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + + # the key_score may be `None` so it needs to skip + # the action to save the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/mean_ap.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/mean_ap.py new file mode 100755 index 00000000..3a2216da --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/mean_ap.py @@ -0,0 +1,782 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing import Pool + +import dbnet_cv +import numpy as np +from dbnet_cv.utils import print_log +from terminaltables import AsciiTable + +from .bbox_overlaps import bbox_overlaps +from .class_names import get_classes + + +def average_precision(recalls, precisions, mode='area'): + """Calculate average precision (for single or multiple scales). 
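+
+    Added note: in 'area' mode the AP is the area under the interpolated
+    precision-recall curve (precision is first made monotonically
+    non-increasing from the right); in '11points' mode it is the mean, over
+    the 11 recall thresholds 0.0, 0.1, ..., 1.0, of the highest precision
+    achieved at recall >= that threshold (the classic PASCAL VOC 2007 metric).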
+ + Args: + recalls (ndarray): shape (num_scales, num_dets) or (num_dets, ) + precisions (ndarray): shape (num_scales, num_dets) or (num_dets, ) + mode (str): 'area' or '11points', 'area' means calculating the area + under precision-recall curve, '11points' means calculating + the average precision of recalls at [0, 0.1, ..., 1] + + Returns: + float or ndarray: calculated average precision + """ + no_scale = False + if recalls.ndim == 1: + no_scale = True + recalls = recalls[np.newaxis, :] + precisions = precisions[np.newaxis, :] + assert recalls.shape == precisions.shape and recalls.ndim == 2 + num_scales = recalls.shape[0] + ap = np.zeros(num_scales, dtype=np.float32) + if mode == 'area': + zeros = np.zeros((num_scales, 1), dtype=recalls.dtype) + ones = np.ones((num_scales, 1), dtype=recalls.dtype) + mrec = np.hstack((zeros, recalls, ones)) + mpre = np.hstack((zeros, precisions, zeros)) + for i in range(mpre.shape[1] - 1, 0, -1): + mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i]) + for i in range(num_scales): + ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0] + ap[i] = np.sum( + (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1]) + elif mode == '11points': + for i in range(num_scales): + for thr in np.arange(0, 1 + 1e-3, 0.1): + precs = precisions[i, recalls[i, :] >= thr] + prec = precs.max() if precs.size > 0 else 0 + ap[i] += prec + ap /= 11 + else: + raise ValueError( + 'Unrecognized mode, only "area" and "11points" are supported') + if no_scale: + ap = ap[0] + return ap + + +def tpfp_imagenet(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + default_iou_thr=0.5, + area_ranges=None, + use_legacy_coordinate=False, + **kwargs): + """Check if detected bboxes are true positive or false positive. + + Args: + det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 4). Default: None + default_iou_thr (float): IoU threshold to be considered as matched for + medium and large bboxes (small ones have special rules). + Default: 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be evaluated, + in the format [(min1, max1), (min2, max2), ...]. Default: None. + use_legacy_coordinate (bool): Whether to use coordinate system in + dbnet_det v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Default: False. + + Returns: + tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of + each array is (num_scales, m). + """ + + if not use_legacy_coordinate: + extra_length = 0. + else: + extra_length = 1. + + # an indicator of ignored gts + gt_ignore_inds = np.concatenate( + (np.zeros(gt_bboxes.shape[0], dtype=np.bool), + np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) + # stack gt_bboxes and gt_bboxes_ignore for convenience + gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) + + num_dets = det_bboxes.shape[0] + num_gts = gt_bboxes.shape[0] + if area_ranges is None: + area_ranges = [(None, None)] + num_scales = len(area_ranges) + # tp and fp are of shape (num_scales, num_gts), each row is tp or fp + # of a certain scale. + tp = np.zeros((num_scales, num_dets), dtype=np.float32) + fp = np.zeros((num_scales, num_dets), dtype=np.float32) + if gt_bboxes.shape[0] == 0: + if area_ranges == [(None, None)]: + fp[...] 
= 1 + else: + det_areas = ( + det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * ( + det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length) + for i, (min_area, max_area) in enumerate(area_ranges): + fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 + return tp, fp + ious = bbox_overlaps( + det_bboxes, gt_bboxes - 1, use_legacy_coordinate=use_legacy_coordinate) + gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length + gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length + iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)), + default_iou_thr) + # sort all detections by scores in descending order + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + gt_covered = np.zeros(num_gts, dtype=bool) + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = gt_w * gt_h + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + max_iou = -1 + matched_gt = -1 + # find best overlapped available gt + for j in range(num_gts): + # different from PASCAL VOC: allow finding other gts if the + # best overlapped ones are already matched by other det bboxes + if gt_covered[j]: + continue + elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou: + max_iou = ious[i, j] + matched_gt = j + # there are 4 cases for a det bbox: + # 1. it matches a gt, tp = 1, fp = 0 + # 2. it matches an ignored gt, tp = 0, fp = 0 + # 3. it matches no gt and within area range, tp = 0, fp = 1 + # 4. it matches no gt but is beyond area range, tp = 0, fp = 0 + if matched_gt >= 0: + gt_covered[matched_gt] = 1 + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + tp[k, i] = 1 + elif min_area is None: + fp[k, i] = 1 + else: + bbox = det_bboxes[i, :4] + area = (bbox[2] - bbox[0] + extra_length) * ( + bbox[3] - bbox[1] + extra_length) + if area >= min_area and area < max_area: + fp[k, i] = 1 + return tp, fp + + +def tpfp_default(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + iou_thr=0.5, + area_ranges=None, + use_legacy_coordinate=False, + **kwargs): + """Check if detected bboxes are true positive or false positive. + + Args: + det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 4). Default: None + iou_thr (float): IoU threshold to be considered as matched. + Default: 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be + evaluated, in the format [(min1, max1), (min2, max2), ...]. + Default: None. + use_legacy_coordinate (bool): Whether to use coordinate system in + dbnet_det v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Default: False. + + Returns: + tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of + each array is (num_scales, m). + """ + + if not use_legacy_coordinate: + extra_length = 0. + else: + extra_length = 1. 
+ + # an indicator of ignored gts + gt_ignore_inds = np.concatenate( + (np.zeros(gt_bboxes.shape[0], dtype=np.bool), + np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) + # stack gt_bboxes and gt_bboxes_ignore for convenience + gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) + + num_dets = det_bboxes.shape[0] + num_gts = gt_bboxes.shape[0] + if area_ranges is None: + area_ranges = [(None, None)] + num_scales = len(area_ranges) + # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of + # a certain scale + tp = np.zeros((num_scales, num_dets), dtype=np.float32) + fp = np.zeros((num_scales, num_dets), dtype=np.float32) + + # if there is no gt bboxes in this image, then all det bboxes + # within area range are false positives + if gt_bboxes.shape[0] == 0: + if area_ranges == [(None, None)]: + fp[...] = 1 + else: + det_areas = ( + det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * ( + det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length) + for i, (min_area, max_area) in enumerate(area_ranges): + fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 + return tp, fp + + ious = bbox_overlaps( + det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate) + # for each det, the max iou with all gts + ious_max = ious.max(axis=1) + # for each det, which gt overlaps most with it + ious_argmax = ious.argmax(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + gt_covered = np.zeros(num_gts, dtype=bool) + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length) + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + if ious_max[i] >= iou_thr: + matched_gt = ious_argmax[i] + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + if not gt_covered[matched_gt]: + gt_covered[matched_gt] = True + tp[k, i] = 1 + else: + fp[k, i] = 1 + # otherwise ignore this detected bbox, tp = 0, fp = 0 + elif min_area is None: + fp[k, i] = 1 + else: + bbox = det_bboxes[i, :4] + area = (bbox[2] - bbox[0] + extra_length) * ( + bbox[3] - bbox[1] + extra_length) + if area >= min_area and area < max_area: + fp[k, i] = 1 + return tp, fp + + +def tpfp_openimages(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + iou_thr=0.5, + area_ranges=None, + use_legacy_coordinate=False, + gt_bboxes_group_of=None, + use_group_of=True, + ioa_thr=0.5, + **kwargs): + """Check if detected bboxes are true positive or false positive. + + Args: + det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 4). Default: None + iou_thr (float): IoU threshold to be considered as matched. + Default: 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be + evaluated, in the format [(min1, max1), (min2, max2), ...]. + Default: None. + use_legacy_coordinate (bool): Whether to use coordinate system in + dbnet_det v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Default: False. + gt_bboxes_group_of (ndarray): GT group_of of this image, of shape + (k, 1). 
Default: None
+        use_group_of (bool): Whether to use group-of boxes when calculating
+            TP and FP, which is only used in OpenImages evaluation.
+            Default: True.
+        ioa_thr (float | None): IoA threshold to be considered as matched,
+            which is only used in OpenImages evaluation. Default: 0.5.
+
+    Returns:
+        tuple[np.ndarray]: Returns a tuple (tp, fp, det_bboxes). The elements
+            of tp and fp are 0 and 1, and each array has shape
+            (num_scales, m). det_bboxes contains the detections kept after
+            filtering against group-of gts during Open Images evaluation; its
+            shape is (num_scales, m).
+    """
+
+    if not use_legacy_coordinate:
+        extra_length = 0.
+    else:
+        extra_length = 1.
+
+    # an indicator of ignored gts
+    gt_ignore_inds = np.concatenate(
+        (np.zeros(gt_bboxes.shape[0], dtype=np.bool),
+         np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool)))
+    # stack gt_bboxes and gt_bboxes_ignore for convenience
+    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+    num_dets = det_bboxes.shape[0]
+    num_gts = gt_bboxes.shape[0]
+    if area_ranges is None:
+        area_ranges = [(None, None)]
+    num_scales = len(area_ranges)
+    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
+    # a certain scale
+    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+    fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+
+    # if there is no gt bboxes in this image, then all det bboxes
+    # within area range are false positives
+    if gt_bboxes.shape[0] == 0:
+        if area_ranges == [(None, None)]:
+            fp[...] = 1
+        else:
+            det_areas = (
+                det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
+                    det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
+            for i, (min_area, max_area) in enumerate(area_ranges):
+                fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+        return tp, fp, det_bboxes
+
+    if gt_bboxes_group_of is not None and use_group_of:
+        # When handling group-of boxes, divide the gt boxes into two parts
+        # (non-group-of and group-of), then compute ious against the
+        # non-group-of gts and ioas against the group-of gts respectively.
+        # This is only used in OpenImages evaluation.
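+        # Added note: a group-of gt box in Open Images marks a single region
+        # containing several instances of the same class. Detections are
+        # matched to these boxes with IoA (intersection over the detection's
+        # own area, mode='iof' below) rather than IoU, because a correct
+        # detection usually covers only part of the group box.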
+ assert gt_bboxes_group_of.shape[0] == gt_bboxes.shape[0] + non_group_gt_bboxes = gt_bboxes[~gt_bboxes_group_of] + group_gt_bboxes = gt_bboxes[gt_bboxes_group_of] + num_gts_group = group_gt_bboxes.shape[0] + ious = bbox_overlaps(det_bboxes, non_group_gt_bboxes) + ioas = bbox_overlaps(det_bboxes, group_gt_bboxes, mode='iof') + else: + # if not consider group-of boxes, only calculate ious through gt boxes + ious = bbox_overlaps( + det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate) + ioas = None + + if ious.shape[1] > 0: + # for each det, the max iou with all gts + ious_max = ious.max(axis=1) + # for each det, which gt overlaps most with it + ious_argmax = ious.argmax(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + gt_covered = np.zeros(num_gts, dtype=bool) + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = ( + gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length) + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + if ious_max[i] >= iou_thr: + matched_gt = ious_argmax[i] + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + if not gt_covered[matched_gt]: + gt_covered[matched_gt] = True + tp[k, i] = 1 + else: + fp[k, i] = 1 + # otherwise ignore this detected bbox, tp = 0, fp = 0 + elif min_area is None: + fp[k, i] = 1 + else: + bbox = det_bboxes[i, :4] + area = (bbox[2] - bbox[0] + extra_length) * ( + bbox[3] - bbox[1] + extra_length) + if area >= min_area and area < max_area: + fp[k, i] = 1 + else: + # if there is no no-group-of gt bboxes in this image, + # then all det bboxes within area range are false positives. + # Only used in OpenImages evaluation. + if area_ranges == [(None, None)]: + fp[...] = 1 + else: + det_areas = ( + det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * ( + det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length) + for i, (min_area, max_area) in enumerate(area_ranges): + fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 + + if ioas is None or ioas.shape[1] <= 0: + return tp, fp, det_bboxes + else: + # The evaluation of group-of TP and FP are done in two stages: + # 1. All detections are first matched to non group-of boxes; true + # positives are determined. + # 2. Detections that are determined as false positives are matched + # against group-of boxes and calculated group-of TP and FP. + # Only used in OpenImages evaluation. 
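+        # Added note: det_bboxes_group keeps, for every group-of gt box, the
+        # highest-scoring detection matched to it; these representatives are
+        # concatenated back with the unmatched detections at the end so the
+        # returned boxes stay aligned with the returned tp/fp arrays.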
+ det_bboxes_group = np.zeros( + (num_scales, ioas.shape[1], det_bboxes.shape[1]), dtype=float) + match_group_of = np.zeros((num_scales, num_dets), dtype=bool) + tp_group = np.zeros((num_scales, num_gts_group), dtype=np.float32) + ioas_max = ioas.max(axis=1) + # for each det, which gt overlaps most with it + ioas_argmax = ioas.argmax(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + box_is_covered = tp[k] + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + matched_gt = ioas_argmax[i] + if not box_is_covered[i]: + if ioas_max[i] >= ioa_thr: + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + if not tp_group[k, matched_gt]: + tp_group[k, matched_gt] = 1 + match_group_of[k, i] = True + else: + match_group_of[k, i] = True + + if det_bboxes_group[k, matched_gt, -1] < \ + det_bboxes[i, -1]: + det_bboxes_group[k, matched_gt] = \ + det_bboxes[i] + + fp_group = (tp_group <= 0).astype(float) + tps = [] + fps = [] + # concatenate tp, fp, and det-boxes which not matched group of + # gt boxes and tp_group, fp_group, and det_bboxes_group which + # matched group of boxes respectively. + for i in range(num_scales): + tps.append( + np.concatenate((tp[i][~match_group_of[i]], tp_group[i]))) + fps.append( + np.concatenate((fp[i][~match_group_of[i]], fp_group[i]))) + det_bboxes = np.concatenate( + (det_bboxes[~match_group_of[i]], det_bboxes_group[i])) + + tp = np.vstack(tps) + fp = np.vstack(fps) + return tp, fp, det_bboxes + + +def get_cls_results(det_results, annotations, class_id): + """Get det results and gt information of a certain class. + + Args: + det_results (list[list]): Same as `eval_map()`. + annotations (list[dict]): Same as `eval_map()`. + class_id (int): ID of a specific class. + + Returns: + tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes + """ + cls_dets = [img_res[class_id] for img_res in det_results] + cls_gts = [] + cls_gts_ignore = [] + for ann in annotations: + gt_inds = ann['labels'] == class_id + cls_gts.append(ann['bboxes'][gt_inds, :]) + + if ann.get('labels_ignore', None) is not None: + ignore_inds = ann['labels_ignore'] == class_id + cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :]) + else: + cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32)) + + return cls_dets, cls_gts, cls_gts_ignore + + +def get_cls_group_ofs(annotations, class_id): + """Get `gt_group_of` of a certain class, which is used in Open Images. + + Args: + annotations (list[dict]): Same as `eval_map()`. + class_id (int): ID of a specific class. + + Returns: + list[np.ndarray]: `gt_group_of` of a certain class. + """ + gt_group_ofs = [] + for ann in annotations: + gt_inds = ann['labels'] == class_id + if ann.get('gt_is_group_ofs', None) is not None: + gt_group_ofs.append(ann['gt_is_group_ofs'][gt_inds]) + else: + gt_group_ofs.append(np.empty((0, 1), dtype=np.bool)) + + return gt_group_ofs + + +def eval_map(det_results, + annotations, + scale_ranges=None, + iou_thr=0.5, + ioa_thr=None, + dataset=None, + logger=None, + tpfp_fn=None, + nproc=4, + use_legacy_coordinate=False, + use_group_of=False): + """Evaluate mAP of a dataset. 
+ + Args: + det_results (list[list]): [[cls1_det, cls2_det, ...], ...]. + The outer list indicates images, and the inner list indicates + per-class detected bboxes. + annotations (list[dict]): Ground truth annotations where each item of + the list indicates an image. Keys of annotations are: + + - `bboxes`: numpy array of shape (n, 4) + - `labels`: numpy array of shape (n, ) + - `bboxes_ignore` (optional): numpy array of shape (k, 4) + - `labels_ignore` (optional): numpy array of shape (k, ) + scale_ranges (list[tuple] | None): Range of scales to be evaluated, + in the format [(min1, max1), (min2, max2), ...]. A range of + (32, 64) means the area range between (32**2, 64**2). + Default: None. + iou_thr (float): IoU threshold to be considered as matched. + Default: 0.5. + ioa_thr (float | None): IoA threshold to be considered as matched, + which only used in OpenImages evaluation. Default: None. + dataset (list[str] | str | None): Dataset name or dataset classes, + there are minor differences in metrics for different datasets, e.g. + "voc07", "imagenet_det", etc. Default: None. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `dbnet_cv.utils.print_log()` for details. Default: None. + tpfp_fn (callable | None): The function used to determine true/ + false positives. If None, :func:`tpfp_default` is used as default + unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this + case). If it is given as a function, then this function is used + to evaluate tp & fp. Default None. + nproc (int): Processes used for computing TP and FP. + Default: 4. + use_legacy_coordinate (bool): Whether to use coordinate system in + dbnet_det v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Default: False. + use_group_of (bool): Whether to use group of when calculate TP and FP, + which only used in OpenImages evaluation. Default: False. + + Returns: + tuple: (mAP, [dict, dict, ...]) + """ + assert len(det_results) == len(annotations) + if not use_legacy_coordinate: + extra_length = 0. + else: + extra_length = 1. + + num_imgs = len(det_results) + num_scales = len(scale_ranges) if scale_ranges is not None else 1 + num_classes = len(det_results[0]) # positive class num + area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges] + if scale_ranges is not None else None) + + # There is no need to use multi processes to process + # when num_imgs = 1 . + if num_imgs > 1: + assert nproc > 0, 'nproc must be at least one.' 
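+        # Added note: a single worker pool is created here and reused for the
+        # per-class TP/FP computation below (and closed after the loop);
+        # capping nproc at num_imgs avoids spawning workers that would sit
+        # idle when there are fewer images than requested processes.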
+ nproc = min(nproc, num_imgs) + pool = Pool(nproc) + + eval_results = [] + for i in range(num_classes): + # get gt and det bboxes of this class + cls_dets, cls_gts, cls_gts_ignore = get_cls_results( + det_results, annotations, i) + # choose proper function according to datasets to compute tp and fp + if tpfp_fn is None: + if dataset in ['det', 'vid']: + tpfp_fn = tpfp_imagenet + elif dataset in ['oid_challenge', 'oid_v6'] \ + or use_group_of is True: + tpfp_fn = tpfp_openimages + else: + tpfp_fn = tpfp_default + if not callable(tpfp_fn): + raise ValueError( + f'tpfp_fn has to be a function or None, but got {tpfp_fn}') + + if num_imgs > 1: + # compute tp and fp for each image with multiple processes + args = [] + if use_group_of: + # used in Open Images Dataset evaluation + gt_group_ofs = get_cls_group_ofs(annotations, i) + args.append(gt_group_ofs) + args.append([use_group_of for _ in range(num_imgs)]) + if ioa_thr is not None: + args.append([ioa_thr for _ in range(num_imgs)]) + + tpfp = pool.starmap( + tpfp_fn, + zip(cls_dets, cls_gts, cls_gts_ignore, + [iou_thr for _ in range(num_imgs)], + [area_ranges for _ in range(num_imgs)], + [use_legacy_coordinate for _ in range(num_imgs)], *args)) + else: + tpfp = tpfp_fn( + cls_dets[0], + cls_gts[0], + cls_gts_ignore[0], + iou_thr, + area_ranges, + use_legacy_coordinate, + gt_bboxes_group_of=(get_cls_group_ofs(annotations, i)[0] + if use_group_of else None), + use_group_of=use_group_of, + ioa_thr=ioa_thr) + tpfp = [tpfp] + + if use_group_of: + tp, fp, cls_dets = tuple(zip(*tpfp)) + else: + tp, fp = tuple(zip(*tpfp)) + # calculate gt number of each scale + # ignored gts or gts beyond the specific scale are not counted + num_gts = np.zeros(num_scales, dtype=int) + for j, bbox in enumerate(cls_gts): + if area_ranges is None: + num_gts[0] += bbox.shape[0] + else: + gt_areas = (bbox[:, 2] - bbox[:, 0] + extra_length) * ( + bbox[:, 3] - bbox[:, 1] + extra_length) + for k, (min_area, max_area) in enumerate(area_ranges): + num_gts[k] += np.sum((gt_areas >= min_area) + & (gt_areas < max_area)) + # sort all det bboxes by score, also sort tp and fp + cls_dets = np.vstack(cls_dets) + num_dets = cls_dets.shape[0] + sort_inds = np.argsort(-cls_dets[:, -1]) + tp = np.hstack(tp)[:, sort_inds] + fp = np.hstack(fp)[:, sort_inds] + # calculate recall and precision with tp and fp + tp = np.cumsum(tp, axis=1) + fp = np.cumsum(fp, axis=1) + eps = np.finfo(np.float32).eps + recalls = tp / np.maximum(num_gts[:, np.newaxis], eps) + precisions = tp / np.maximum((tp + fp), eps) + # calculate AP + if scale_ranges is None: + recalls = recalls[0, :] + precisions = precisions[0, :] + num_gts = num_gts.item() + mode = 'area' if dataset != 'voc07' else '11points' + ap = average_precision(recalls, precisions, mode) + eval_results.append({ + 'num_gts': num_gts, + 'num_dets': num_dets, + 'recall': recalls, + 'precision': precisions, + 'ap': ap + }) + + if num_imgs > 1: + pool.close() + + if scale_ranges is not None: + # shape (num_classes, num_scales) + all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results]) + all_num_gts = np.vstack( + [cls_result['num_gts'] for cls_result in eval_results]) + mean_ap = [] + for i in range(num_scales): + if np.any(all_num_gts[:, i] > 0): + mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean()) + else: + mean_ap.append(0.0) + else: + aps = [] + for cls_result in eval_results: + if cls_result['num_gts'] > 0: + aps.append(cls_result['ap']) + mean_ap = np.array(aps).mean().item() if aps else 0.0 + + print_map_summary( + mean_ap, 
eval_results, dataset, area_ranges, logger=logger) + + return mean_ap, eval_results + + +def print_map_summary(mean_ap, + results, + dataset=None, + scale_ranges=None, + logger=None): + """Print mAP and results of each class. + + A table will be printed to show the gts/dets/recall/AP of each class and + the mAP. + + Args: + mean_ap (float): Calculated from `eval_map()`. + results (list[dict]): Calculated from `eval_map()`. + dataset (list[str] | str | None): Dataset name or dataset classes. + scale_ranges (list[tuple] | None): Range of scales to be evaluated. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `dbnet_cv.utils.print_log()` for details. Default: None. + """ + + if logger == 'silent': + return + + if isinstance(results[0]['ap'], np.ndarray): + num_scales = len(results[0]['ap']) + else: + num_scales = 1 + + if scale_ranges is not None: + assert len(scale_ranges) == num_scales + + num_classes = len(results) + + recalls = np.zeros((num_scales, num_classes), dtype=np.float32) + aps = np.zeros((num_scales, num_classes), dtype=np.float32) + num_gts = np.zeros((num_scales, num_classes), dtype=int) + for i, cls_result in enumerate(results): + if cls_result['recall'].size > 0: + recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] + aps[:, i] = cls_result['ap'] + num_gts[:, i] = cls_result['num_gts'] + + if dataset is None: + label_names = [str(i) for i in range(num_classes)] + elif dbnet_cv.is_str(dataset): + label_names = get_classes(dataset) + else: + label_names = dataset + + if not isinstance(mean_ap, list): + mean_ap = [mean_ap] + + header = ['class', 'gts', 'dets', 'recall', 'ap'] + for i in range(num_scales): + if scale_ranges is not None: + print_log(f'Scale range {scale_ranges[i]}', logger=logger) + table_data = [header] + for j in range(num_classes): + row_data = [ + label_names[j], num_gts[i, j], results[j]['num_dets'], + f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}' + ] + table_data.append(row_data) + table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}']) + table = AsciiTable(table_data) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger=logger) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/panoptic_utils.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/panoptic_utils.py new file mode 100755 index 00000000..45643d76 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/panoptic_utils.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# A custom value to distinguish instance ID and category ID; need to +# be greater than the number of categories. +# For a pixel in the panoptic result map: +# pan_id = ins_id * INSTANCE_OFFSET + cat_id +INSTANCE_OFFSET = 1000 \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/recall.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/recall.py new file mode 100755 index 00000000..4b41c4ae --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/evaluation/recall.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
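+# --- Illustrative usage of eval_map() from mean_ap.py above (added for
+# clarity; not part of the upstream recall.py file). A minimal single-image,
+# single-class sketch with made-up boxes, assuming the module path
+# dbnet_det.core.evaluation.mean_ap:
+#
+#   >>> import numpy as np
+#   >>> from dbnet_det.core.evaluation.mean_ap import eval_map
+#   >>> det_results = [[np.array([[10., 10., 50., 50., 0.9]])]]  # 1 img, 1 cls
+#   >>> annotations = [dict(bboxes=np.array([[12., 12., 48., 48.]]),
+#   ...                     labels=np.array([0]))]
+#   >>> mean_ap, per_class = eval_map(det_results, annotations, iou_thr=0.5)
+#   >>> # the detection overlaps its gt with IoU ~0.81, so mean_ap == 1.0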
+from collections.abc import Sequence + +import numpy as np +from dbnet_cv.utils import print_log +from terminaltables import AsciiTable + +from .bbox_overlaps import bbox_overlaps + + +def _recalls(all_ious, proposal_nums, thrs): + + img_num = all_ious.shape[0] + total_gt_num = sum([ious.shape[0] for ious in all_ious]) + + _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32) + for k, proposal_num in enumerate(proposal_nums): + tmp_ious = np.zeros(0) + for i in range(img_num): + ious = all_ious[i][:, :proposal_num].copy() + gt_ious = np.zeros((ious.shape[0])) + if ious.size == 0: + tmp_ious = np.hstack((tmp_ious, gt_ious)) + continue + for j in range(ious.shape[0]): + gt_max_overlaps = ious.argmax(axis=1) + max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps] + gt_idx = max_ious.argmax() + gt_ious[j] = max_ious[gt_idx] + box_idx = gt_max_overlaps[gt_idx] + ious[gt_idx, :] = -1 + ious[:, box_idx] = -1 + tmp_ious = np.hstack((tmp_ious, gt_ious)) + _ious[k, :] = tmp_ious + + _ious = np.fliplr(np.sort(_ious, axis=1)) + recalls = np.zeros((proposal_nums.size, thrs.size)) + for i, thr in enumerate(thrs): + recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num) + + return recalls + + +def set_recall_param(proposal_nums, iou_thrs): + """Check proposal_nums and iou_thrs and set correct format.""" + if isinstance(proposal_nums, Sequence): + _proposal_nums = np.array(proposal_nums) + elif isinstance(proposal_nums, int): + _proposal_nums = np.array([proposal_nums]) + else: + _proposal_nums = proposal_nums + + if iou_thrs is None: + _iou_thrs = np.array([0.5]) + elif isinstance(iou_thrs, Sequence): + _iou_thrs = np.array(iou_thrs) + elif isinstance(iou_thrs, float): + _iou_thrs = np.array([iou_thrs]) + else: + _iou_thrs = iou_thrs + + return _proposal_nums, _iou_thrs + + +def eval_recalls(gts, + proposals, + proposal_nums=None, + iou_thrs=0.5, + logger=None, + use_legacy_coordinate=False): + """Calculate recalls. + + Args: + gts (list[ndarray]): a list of arrays of shape (n, 4) + proposals (list[ndarray]): a list of arrays of shape (k, 4) or (k, 5) + proposal_nums (int | Sequence[int]): Top N proposals to be evaluated. + iou_thrs (float | Sequence[float]): IoU thresholds. Default: 0.5. + logger (logging.Logger | str | None): The way to print the recall + summary. See `dbnet_cv.utils.print_log()` for details. Default: None. + use_legacy_coordinate (bool): Whether use coordinate system + in dbnet_det v1.x. "1" was added to both height and width + which means w, h should be + computed as 'x2 - x1 + 1` and 'y2 - y1 + 1'. Default: False. 
+ + + Returns: + ndarray: recalls of different ious and proposal nums + """ + + img_num = len(gts) + assert img_num == len(proposals) + proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs) + all_ious = [] + for i in range(img_num): + if proposals[i].ndim == 2 and proposals[i].shape[1] == 5: + scores = proposals[i][:, 4] + sort_idx = np.argsort(scores)[::-1] + img_proposal = proposals[i][sort_idx, :] + else: + img_proposal = proposals[i] + prop_num = min(img_proposal.shape[0], proposal_nums[-1]) + if gts[i] is None or gts[i].shape[0] == 0: + ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32) + else: + ious = bbox_overlaps( + gts[i], + img_proposal[:prop_num, :4], + use_legacy_coordinate=use_legacy_coordinate) + all_ious.append(ious) + all_ious = np.array(all_ious) + recalls = _recalls(all_ious, proposal_nums, iou_thrs) + + print_recall_summary(recalls, proposal_nums, iou_thrs, logger=logger) + return recalls + + +def print_recall_summary(recalls, + proposal_nums, + iou_thrs, + row_idxs=None, + col_idxs=None, + logger=None): + """Print recalls in a table. + + Args: + recalls (ndarray): calculated from `bbox_recalls` + proposal_nums (ndarray or list): top N proposals + iou_thrs (ndarray or list): iou thresholds + row_idxs (ndarray): which rows(proposal nums) to print + col_idxs (ndarray): which cols(iou thresholds) to print + logger (logging.Logger | str | None): The way to print the recall + summary. See `dbnet_cv.utils.print_log()` for details. Default: None. + """ + proposal_nums = np.array(proposal_nums, dtype=np.int32) + iou_thrs = np.array(iou_thrs) + if row_idxs is None: + row_idxs = np.arange(proposal_nums.size) + if col_idxs is None: + col_idxs = np.arange(iou_thrs.size) + row_header = [''] + iou_thrs[col_idxs].tolist() + table_data = [row_header] + for i, num in enumerate(proposal_nums[row_idxs]): + row = [f'{val:.3f}' for val in recalls[row_idxs[i], col_idxs].tolist()] + row.insert(0, num) + table_data.append(row) + table = AsciiTable(table_data) + print_log('\n' + table.table, logger=logger) + + +def plot_num_recall(recalls, proposal_nums): + """Plot Proposal_num-Recalls curve. + + Args: + recalls(ndarray or list): shape (k,) + proposal_nums(ndarray or list): same shape as `recalls` + """ + if isinstance(proposal_nums, np.ndarray): + _proposal_nums = proposal_nums.tolist() + else: + _proposal_nums = proposal_nums + if isinstance(recalls, np.ndarray): + _recalls = recalls.tolist() + else: + _recalls = recalls + + import matplotlib.pyplot as plt + f = plt.figure() + plt.plot([0] + _proposal_nums, [0] + _recalls) + plt.xlabel('Proposal num') + plt.ylabel('Recall') + plt.axis([0, proposal_nums.max(), 0, 1]) + f.show() + + +def plot_iou_recall(recalls, iou_thrs): + """Plot IoU-Recalls curve. 
+ + Args: + recalls(ndarray or list): shape (k,) + iou_thrs(ndarray or list): same shape as `recalls` + """ + if isinstance(iou_thrs, np.ndarray): + _iou_thrs = iou_thrs.tolist() + else: + _iou_thrs = iou_thrs + if isinstance(recalls, np.ndarray): + _recalls = recalls.tolist() + else: + _recalls = recalls + + import matplotlib.pyplot as plt + f = plt.figure() + plt.plot(_iou_thrs + [1.0], _recalls + [0.]) + plt.xlabel('IoU') + plt.ylabel('Recall') + plt.axis([iou_thrs.min(), 1, 0, 1]) + f.show() diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/__init__.py new file mode 100755 index 00000000..506dd4de --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# from .mask_target import mask_target +from .structures import BaseInstanceMasks, BitmapMasks, PolygonMasks +from .utils import encode_mask_results, mask2bbox, split_combined_polys + +# __all__ = [ +# 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks', +# 'PolygonMasks', 'encode_mask_results', 'mask2bbox' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/structures.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/structures.py new file mode 100755 index 00000000..bd665259 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/structures.py @@ -0,0 +1,1102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +import cv2 +import dbnet_cv +import numpy as np +import pycocotools.mask as maskUtils +import torch +from dbnet_cv.ops.roi_align import roi_align + + +class BaseInstanceMasks(metaclass=ABCMeta): + """Base class for instance masks.""" + + @abstractmethod + def rescale(self, scale, interpolation='nearest'): + """Rescale masks as large as possible while keeping the aspect ratio. + For details can refer to `dbnet_cv.imrescale`. + + Args: + scale (tuple[int]): The maximum size (h, w) of rescaled mask. + interpolation (str): Same as :func:`dbnet_cv.imrescale`. + + Returns: + BaseInstanceMasks: The rescaled masks. + """ + + @abstractmethod + def resize(self, out_shape, interpolation='nearest'): + """Resize masks to the given out_shape. + + Args: + out_shape: Target (h, w) of resized mask. + interpolation (str): See :func:`dbnet_cv.imresize`. + + Returns: + BaseInstanceMasks: The resized masks. + """ + + @abstractmethod + def flip(self, flip_direction='horizontal'): + """Flip masks alone the given direction. + + Args: + flip_direction (str): Either 'horizontal' or 'vertical'. + + Returns: + BaseInstanceMasks: The flipped masks. + """ + + @abstractmethod + def pad(self, out_shape, pad_val): + """Pad masks to the given size of (h, w). + + Args: + out_shape (tuple[int]): Target (h, w) of padded mask. + pad_val (int): The padded value. + + Returns: + BaseInstanceMasks: The padded masks. + """ + + @abstractmethod + def crop(self, bbox): + """Crop each mask by the given bbox. + + Args: + bbox (ndarray): Bbox in format [x1, y1, x2, y2], shape (4, ). + + Return: + BaseInstanceMasks: The cropped masks. + """ + + @abstractmethod + def crop_and_resize(self, + bboxes, + out_shape, + inds, + device, + interpolation='bilinear', + binarize=True): + """Crop and resize masks by the given bboxes. + + This function is mainly used in mask targets computation. 
+ It firstly align mask to bboxes by assigned_inds, then crop mask by the + assigned bbox and resize to the size of (mask_h, mask_w) + + Args: + bboxes (Tensor): Bboxes in format [x1, y1, x2, y2], shape (N, 4) + out_shape (tuple[int]): Target (h, w) of resized mask + inds (ndarray): Indexes to assign masks to each bbox, + shape (N,) and values should be between [0, num_masks - 1]. + device (str): Device of bboxes + interpolation (str): See `dbnet_cv.imresize` + binarize (bool): if True fractional values are rounded to 0 or 1 + after the resize operation. if False and unsupported an error + will be raised. Defaults to True. + + Return: + BaseInstanceMasks: the cropped and resized masks. + """ + + @abstractmethod + def expand(self, expanded_h, expanded_w, top, left): + """see :class:`Expand`.""" + + @property + @abstractmethod + def areas(self): + """ndarray: areas of each instance.""" + + @abstractmethod + def to_ndarray(self): + """Convert masks to the format of ndarray. + + Return: + ndarray: Converted masks in the format of ndarray. + """ + + @abstractmethod + def to_tensor(self, dtype, device): + """Convert masks to the format of Tensor. + + Args: + dtype (str): Dtype of converted mask. + device (torch.device): Device of converted masks. + + Returns: + Tensor: Converted masks in the format of Tensor. + """ + + @abstractmethod + def translate(self, + out_shape, + offset, + direction='horizontal', + fill_val=0, + interpolation='bilinear'): + """Translate the masks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + offset (int | float): The offset for translate. + direction (str): The translate direction, either "horizontal" + or "vertical". + fill_val (int | float): Border value. Default 0. + interpolation (str): Same as :func:`dbnet_cv.imtranslate`. + + Returns: + Translated masks. + """ + + def shear(self, + out_shape, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Shear the masks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + magnitude (int | float): The magnitude used for shear. + direction (str): The shear direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. Default 0. + interpolation (str): Same as in :func:`dbnet_cv.imshear`. + + Returns: + ndarray: Sheared masks. + """ + + @abstractmethod + def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0): + """Rotate the masks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + angle (int | float): Rotation angle in degrees. Positive values + mean counter-clockwise rotation. + center (tuple[float], optional): Center point (w, h) of the + rotation in source image. If not specified, the center of + the image will be used. + scale (int | float): Isotropic scale factor. + fill_val (int | float): Border value. Default 0 for masks. + + Returns: + Rotated masks. + """ + + +class BitmapMasks(BaseInstanceMasks): + """This class represents masks in the form of bitmaps. + + Args: + masks (ndarray): ndarray of masks in shape (N, H, W), where N is + the number of objects. 
+ height (int): height of masks + width (int): width of masks + + Example: + >>> from dbnet_det.core.mask.structures import * # NOQA + >>> num_masks, H, W = 3, 32, 32 + >>> rng = np.random.RandomState(0) + >>> masks = (rng.rand(num_masks, H, W) > 0.1).astype(np.int) + >>> self = BitmapMasks(masks, height=H, width=W) + + >>> # demo crop_and_resize + >>> num_boxes = 5 + >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes) + >>> out_shape = (14, 14) + >>> inds = torch.randint(0, len(self), size=(num_boxes,)) + >>> device = 'cpu' + >>> interpolation = 'bilinear' + >>> new = self.crop_and_resize( + ... bboxes, out_shape, inds, device, interpolation) + >>> assert len(new) == num_boxes + >>> assert new.height, new.width == out_shape + """ + + def __init__(self, masks, height, width): + self.height = height + self.width = width + if len(masks) == 0: + self.masks = np.empty((0, self.height, self.width), dtype=np.uint8) + else: + assert isinstance(masks, (list, np.ndarray)) + if isinstance(masks, list): + assert isinstance(masks[0], np.ndarray) + assert masks[0].ndim == 2 # (H, W) + else: + assert masks.ndim == 3 # (N, H, W) + + self.masks = np.stack(masks).reshape(-1, height, width) + assert self.masks.shape[1] == self.height + assert self.masks.shape[2] == self.width + + def __getitem__(self, index): + """Index the BitmapMask. + + Args: + index (int | ndarray): Indices in the format of integer or ndarray. + + Returns: + :obj:`BitmapMasks`: Indexed bitmap masks. + """ + masks = self.masks[index].reshape(-1, self.height, self.width) + return BitmapMasks(masks, self.height, self.width) + + def __iter__(self): + return iter(self.masks) + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += f'num_masks={len(self.masks)}, ' + s += f'height={self.height}, ' + s += f'width={self.width})' + return s + + def __len__(self): + """Number of masks.""" + return len(self.masks) + + def rescale(self, scale, interpolation='nearest'): + """See :func:`BaseInstanceMasks.rescale`.""" + if len(self.masks) == 0: + new_w, new_h = dbnet_cv.rescale_size((self.width, self.height), scale) + rescaled_masks = np.empty((0, new_h, new_w), dtype=np.uint8) + else: + rescaled_masks = np.stack([ + dbnet_cv.imrescale(mask, scale, interpolation=interpolation) + for mask in self.masks + ]) + height, width = rescaled_masks.shape[1:] + return BitmapMasks(rescaled_masks, height, width) + + def resize(self, out_shape, interpolation='nearest'): + """See :func:`BaseInstanceMasks.resize`.""" + if len(self.masks) == 0: + resized_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + resized_masks = np.stack([ + dbnet_cv.imresize( + mask, out_shape[::-1], interpolation=interpolation) + for mask in self.masks + ]) + return BitmapMasks(resized_masks, *out_shape) + + def flip(self, flip_direction='horizontal'): + """See :func:`BaseInstanceMasks.flip`.""" + assert flip_direction in ('horizontal', 'vertical', 'diagonal') + + if len(self.masks) == 0: + flipped_masks = self.masks + else: + flipped_masks = np.stack([ + dbnet_cv.imflip(mask, direction=flip_direction) + for mask in self.masks + ]) + return BitmapMasks(flipped_masks, self.height, self.width) + + def pad(self, out_shape, pad_val=0): + """See :func:`BaseInstanceMasks.pad`.""" + if len(self.masks) == 0: + padded_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + padded_masks = np.stack([ + dbnet_cv.impad(mask, shape=out_shape, pad_val=pad_val) + for mask in self.masks + ]) + return BitmapMasks(padded_masks, *out_shape) + + def crop(self, bbox): + """See 
:func:`BaseInstanceMasks.crop`.""" + assert isinstance(bbox, np.ndarray) + assert bbox.ndim == 1 + + # clip the boundary + bbox = bbox.copy() + bbox[0::2] = np.clip(bbox[0::2], 0, self.width) + bbox[1::2] = np.clip(bbox[1::2], 0, self.height) + x1, y1, x2, y2 = bbox + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) + + if len(self.masks) == 0: + cropped_masks = np.empty((0, h, w), dtype=np.uint8) + else: + cropped_masks = self.masks[:, y1:y1 + h, x1:x1 + w] + return BitmapMasks(cropped_masks, h, w) + + def crop_and_resize(self, + bboxes, + out_shape, + inds, + device='cpu', + interpolation='bilinear', + binarize=True): + """See :func:`BaseInstanceMasks.crop_and_resize`.""" + if len(self.masks) == 0: + empty_masks = np.empty((0, *out_shape), dtype=np.uint8) + return BitmapMasks(empty_masks, *out_shape) + + # convert bboxes to tensor + if isinstance(bboxes, np.ndarray): + bboxes = torch.from_numpy(bboxes).to(device=device) + if isinstance(inds, np.ndarray): + inds = torch.from_numpy(inds).to(device=device) + + num_bbox = bboxes.shape[0] + fake_inds = torch.arange( + num_bbox, device=device).to(dtype=bboxes.dtype)[:, None] + rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5 + rois = rois.to(device=device) + if num_bbox > 0: + gt_masks_th = torch.from_numpy(self.masks).to(device).index_select( + 0, inds).to(dtype=rois.dtype) + targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape, + 1.0, 0, 'avg', True).squeeze(1) + if binarize: + resized_masks = (targets >= 0.5).cpu().numpy() + else: + resized_masks = targets.cpu().numpy() + else: + resized_masks = [] + return BitmapMasks(resized_masks, *out_shape) + + def expand(self, expanded_h, expanded_w, top, left): + """See :func:`BaseInstanceMasks.expand`.""" + if len(self.masks) == 0: + expanded_mask = np.empty((0, expanded_h, expanded_w), + dtype=np.uint8) + else: + expanded_mask = np.zeros((len(self), expanded_h, expanded_w), + dtype=np.uint8) + expanded_mask[:, top:top + self.height, + left:left + self.width] = self.masks + return BitmapMasks(expanded_mask, expanded_h, expanded_w) + + def translate(self, + out_shape, + offset, + direction='horizontal', + fill_val=0, + interpolation='bilinear'): + """Translate the BitmapMasks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + offset (int | float): The offset for translate. + direction (str): The translate direction, either "horizontal" + or "vertical". + fill_val (int | float): Border value. Default 0 for masks. + interpolation (str): Same as :func:`dbnet_cv.imtranslate`. + + Returns: + BitmapMasks: Translated BitmapMasks. 
+ + Example: + >>> from dbnet_det.core.mask.structures import BitmapMasks + >>> self = BitmapMasks.random(dtype=np.uint8) + >>> out_shape = (32, 32) + >>> offset = 4 + >>> direction = 'horizontal' + >>> fill_val = 0 + >>> interpolation = 'bilinear' + >>> # Note, There seem to be issues when: + >>> # * out_shape is different than self's shape + >>> # * the mask dtype is not supported by cv2.AffineWarp + >>> new = self.translate(out_shape, offset, direction, fill_val, + >>> interpolation) + >>> assert len(new) == len(self) + >>> assert new.height, new.width == out_shape + """ + if len(self.masks) == 0: + translated_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + translated_masks = dbnet_cv.imtranslate( + self.masks.transpose((1, 2, 0)), + offset, + direction, + border_value=fill_val, + interpolation=interpolation) + if translated_masks.ndim == 2: + translated_masks = translated_masks[:, :, None] + translated_masks = translated_masks.transpose( + (2, 0, 1)).astype(self.masks.dtype) + return BitmapMasks(translated_masks, *out_shape) + + def shear(self, + out_shape, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Shear the BitmapMasks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + magnitude (int | float): The magnitude used for shear. + direction (str): The shear direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as in :func:`dbnet_cv.imshear`. + + Returns: + BitmapMasks: The sheared masks. + """ + if len(self.masks) == 0: + sheared_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + sheared_masks = dbnet_cv.imshear( + self.masks.transpose((1, 2, 0)), + magnitude, + direction, + border_value=border_value, + interpolation=interpolation) + if sheared_masks.ndim == 2: + sheared_masks = sheared_masks[:, :, None] + sheared_masks = sheared_masks.transpose( + (2, 0, 1)).astype(self.masks.dtype) + return BitmapMasks(sheared_masks, *out_shape) + + def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0): + """Rotate the BitmapMasks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + angle (int | float): Rotation angle in degrees. Positive values + mean counter-clockwise rotation. + center (tuple[float], optional): Center point (w, h) of the + rotation in source image. If not specified, the center of + the image will be used. + scale (int | float): Isotropic scale factor. + fill_val (int | float): Border value. Default 0 for masks. + + Returns: + BitmapMasks: Rotated BitmapMasks. 
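+
+        Example:
+            >>> # A minimal illustrative sketch using the random() helper
+            >>> # defined on this class; rotation keeps the mask count and
+            >>> # the requested output shape.
+            >>> from dbnet_det.core.mask.structures import BitmapMasks
+            >>> self = BitmapMasks.random(num_masks=3, height=32, width=32)
+            >>> new = self.rotate((32, 32), 45.)
+            >>> assert len(new) == 3
+            >>> assert (new.height, new.width) == (32, 32)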
+ """ + if len(self.masks) == 0: + rotated_masks = np.empty((0, *out_shape), dtype=self.masks.dtype) + else: + rotated_masks = dbnet_cv.imrotate( + self.masks.transpose((1, 2, 0)), + angle, + center=center, + scale=scale, + border_value=fill_val) + if rotated_masks.ndim == 2: + # case when only one mask, (h, w) + rotated_masks = rotated_masks[:, :, None] # (h, w, 1) + rotated_masks = rotated_masks.transpose( + (2, 0, 1)).astype(self.masks.dtype) + return BitmapMasks(rotated_masks, *out_shape) + + @property + def areas(self): + """See :py:attr:`BaseInstanceMasks.areas`.""" + return self.masks.sum((1, 2)) + + def to_ndarray(self): + """See :func:`BaseInstanceMasks.to_ndarray`.""" + return self.masks + + def to_tensor(self, dtype, device): + """See :func:`BaseInstanceMasks.to_tensor`.""" + return torch.tensor(self.masks, dtype=dtype, device=device) + + @classmethod + def random(cls, + num_masks=3, + height=32, + width=32, + dtype=np.uint8, + rng=None): + """Generate random bitmap masks for demo / testing purposes. + + Example: + >>> from dbnet_det.core.mask.structures import BitmapMasks + >>> self = BitmapMasks.random() + >>> print('self = {}'.format(self)) + self = BitmapMasks(num_masks=3, height=32, width=32) + """ + from dbnet_det.utils.util_random import ensure_rng + rng = ensure_rng(rng) + masks = (rng.rand(num_masks, height, width) > 0.1).astype(dtype) + self = cls(masks, height=height, width=width) + return self + + def get_bboxes(self): + num_masks = len(self) + boxes = np.zeros((num_masks, 4), dtype=np.float32) + x_any = self.masks.any(axis=1) + y_any = self.masks.any(axis=2) + for idx in range(num_masks): + x = np.where(x_any[idx, :])[0] + y = np.where(y_any[idx, :])[0] + if len(x) > 0 and len(y) > 0: + # use +1 for x_max and y_max so that the right and bottom + # boundary of instance masks are fully included by the box + boxes[idx, :] = np.array([x[0], y[0], x[-1] + 1, y[-1] + 1], + dtype=np.float32) + return boxes + + +class PolygonMasks(BaseInstanceMasks): + """This class represents masks in the form of polygons. + + Polygons is a list of three levels. The first level of the list + corresponds to objects, the second level to the polys that compose the + object, the third level to the poly coordinates + + Args: + masks (list[list[ndarray]]): The first level of the list + corresponds to objects, the second level to the polys that + compose the object, the third level to the poly coordinates + height (int): height of masks + width (int): width of masks + + Example: + >>> from dbnet_det.core.mask.structures import * # NOQA + >>> masks = [ + >>> [ np.array([0, 0, 10, 0, 10, 10., 0, 10, 0, 0]) ] + >>> ] + >>> height, width = 16, 16 + >>> self = PolygonMasks(masks, height, width) + + >>> # demo translate + >>> new = self.translate((16, 16), 4., direction='horizontal') + >>> assert np.all(new.masks[0][0][1::2] == masks[0][0][1::2]) + >>> assert np.all(new.masks[0][0][0::2] == masks[0][0][0::2] + 4) + + >>> # demo crop_and_resize + >>> num_boxes = 3 + >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes) + >>> out_shape = (16, 16) + >>> inds = torch.randint(0, len(self), size=(num_boxes,)) + >>> device = 'cpu' + >>> interpolation = 'bilinear' + >>> new = self.crop_and_resize( + ... 
bboxes, out_shape, inds, device, interpolation) + >>> assert len(new) == num_boxes + >>> assert new.height, new.width == out_shape + """ + + def __init__(self, masks, height, width): + assert isinstance(masks, list) + if len(masks) > 0: + assert isinstance(masks[0], list) + assert isinstance(masks[0][0], np.ndarray) + + self.height = height + self.width = width + self.masks = masks + + def __getitem__(self, index): + """Index the polygon masks. + + Args: + index (ndarray | List): The indices. + + Returns: + :obj:`PolygonMasks`: The indexed polygon masks. + """ + if isinstance(index, np.ndarray): + index = index.tolist() + if isinstance(index, list): + masks = [self.masks[i] for i in index] + else: + try: + masks = self.masks[index] + except Exception: + raise ValueError( + f'Unsupported input of type {type(index)} for indexing!') + if len(masks) and isinstance(masks[0], np.ndarray): + masks = [masks] # ensure a list of three levels + return PolygonMasks(masks, self.height, self.width) + + def __iter__(self): + return iter(self.masks) + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += f'num_masks={len(self.masks)}, ' + s += f'height={self.height}, ' + s += f'width={self.width})' + return s + + def __len__(self): + """Number of masks.""" + return len(self.masks) + + def rescale(self, scale, interpolation=None): + """see :func:`BaseInstanceMasks.rescale`""" + new_w, new_h = dbnet_cv.rescale_size((self.width, self.height), scale) + if len(self.masks) == 0: + rescaled_masks = PolygonMasks([], new_h, new_w) + else: + rescaled_masks = self.resize((new_h, new_w)) + return rescaled_masks + + def resize(self, out_shape, interpolation=None): + """see :func:`BaseInstanceMasks.resize`""" + if len(self.masks) == 0: + resized_masks = PolygonMasks([], *out_shape) + else: + h_scale = out_shape[0] / self.height + w_scale = out_shape[1] / self.width + resized_masks = [] + for poly_per_obj in self.masks: + resized_poly = [] + for p in poly_per_obj: + p = p.copy() + p[0::2] = p[0::2] * w_scale + p[1::2] = p[1::2] * h_scale + resized_poly.append(p) + resized_masks.append(resized_poly) + resized_masks = PolygonMasks(resized_masks, *out_shape) + return resized_masks + + def flip(self, flip_direction='horizontal'): + """see :func:`BaseInstanceMasks.flip`""" + assert flip_direction in ('horizontal', 'vertical', 'diagonal') + if len(self.masks) == 0: + flipped_masks = PolygonMasks([], self.height, self.width) + else: + flipped_masks = [] + for poly_per_obj in self.masks: + flipped_poly_per_obj = [] + for p in poly_per_obj: + p = p.copy() + if flip_direction == 'horizontal': + p[0::2] = self.width - p[0::2] + elif flip_direction == 'vertical': + p[1::2] = self.height - p[1::2] + else: + p[0::2] = self.width - p[0::2] + p[1::2] = self.height - p[1::2] + flipped_poly_per_obj.append(p) + flipped_masks.append(flipped_poly_per_obj) + flipped_masks = PolygonMasks(flipped_masks, self.height, + self.width) + return flipped_masks + + def crop(self, bbox): + """see :func:`BaseInstanceMasks.crop`""" + assert isinstance(bbox, np.ndarray) + assert bbox.ndim == 1 + + # clip the boundary + bbox = bbox.copy() + bbox[0::2] = np.clip(bbox[0::2], 0, self.width) + bbox[1::2] = np.clip(bbox[1::2], 0, self.height) + x1, y1, x2, y2 = bbox + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) + + if len(self.masks) == 0: + cropped_masks = PolygonMasks([], h, w) + else: + cropped_masks = [] + for poly_per_obj in self.masks: + cropped_poly_per_obj = [] + for p in poly_per_obj: + # pycocotools will clip the boundary + p = 
p.copy()
+                    p[0::2] = p[0::2] - bbox[0]
+                    p[1::2] = p[1::2] - bbox[1]
+                    cropped_poly_per_obj.append(p)
+                cropped_masks.append(cropped_poly_per_obj)
+            cropped_masks = PolygonMasks(cropped_masks, h, w)
+        return cropped_masks
+
+    def pad(self, out_shape, pad_val=0):
+        """Padding has no effect on polygons."""
+        return PolygonMasks(self.masks, *out_shape)
+
+    def expand(self, *args, **kwargs):
+        """TODO: Add expand for polygon"""
+        raise NotImplementedError
+
+    def crop_and_resize(self,
+                        bboxes,
+                        out_shape,
+                        inds,
+                        device='cpu',
+                        interpolation='bilinear',
+                        binarize=True):
+        """see :func:`BaseInstanceMasks.crop_and_resize`"""
+        out_h, out_w = out_shape
+        if len(self.masks) == 0:
+            return PolygonMasks([], out_h, out_w)
+
+        if not binarize:
+            raise ValueError('Polygons are always binary, '
+                             'setting binarize=False is unsupported')
+
+        resized_masks = []
+        for i in range(len(bboxes)):
+            mask = self.masks[inds[i]]
+            bbox = bboxes[i, :]
+            x1, y1, x2, y2 = bbox
+            w = np.maximum(x2 - x1, 1)
+            h = np.maximum(y2 - y1, 1)
+            h_scale = out_h / max(h, 0.1)  # avoid too large scale
+            w_scale = out_w / max(w, 0.1)
+
+            resized_mask = []
+            for p in mask:
+                p = p.copy()
+                # crop
+                # pycocotools will clip the boundary
+                p[0::2] = p[0::2] - bbox[0]
+                p[1::2] = p[1::2] - bbox[1]
+
+                # resize
+                p[0::2] = p[0::2] * w_scale
+                p[1::2] = p[1::2] * h_scale
+                resized_mask.append(p)
+            resized_masks.append(resized_mask)
+        return PolygonMasks(resized_masks, *out_shape)
+
+    def translate(self,
+                  out_shape,
+                  offset,
+                  direction='horizontal',
+                  fill_val=None,
+                  interpolation=None):
+        """Translate the PolygonMasks.
+
+        Example:
+            >>> self = PolygonMasks.random(dtype=np.int64)
+            >>> out_shape = (self.height, self.width)
+            >>> new = self.translate(out_shape, 4., direction='horizontal')
+            >>> assert np.all(new.masks[0][0][1::2] == self.masks[0][0][1::2])
+            >>> assert np.all(new.masks[0][0][0::2] == self.masks[0][0][0::2] + 4)  # noqa: E501
+        """
+        assert fill_val is None or fill_val == 0, 'Here fill_val is not ' \
+            f'used, and should be None or 0 by default, got {fill_val}.'
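+        # Polygon translation below simply shifts the x (or y) coordinates by
+        # `offset` and clips them to the output shape; unlike
+        # BitmapMasks.translate there is no interpolation or border filling.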
+ if len(self.masks) == 0: + translated_masks = PolygonMasks([], *out_shape) + else: + translated_masks = [] + for poly_per_obj in self.masks: + translated_poly_per_obj = [] + for p in poly_per_obj: + p = p.copy() + if direction == 'horizontal': + p[0::2] = np.clip(p[0::2] + offset, 0, out_shape[1]) + elif direction == 'vertical': + p[1::2] = np.clip(p[1::2] + offset, 0, out_shape[0]) + translated_poly_per_obj.append(p) + translated_masks.append(translated_poly_per_obj) + translated_masks = PolygonMasks(translated_masks, *out_shape) + return translated_masks + + def shear(self, + out_shape, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """See :func:`BaseInstanceMasks.shear`.""" + if len(self.masks) == 0: + sheared_masks = PolygonMasks([], *out_shape) + else: + sheared_masks = [] + if direction == 'horizontal': + shear_matrix = np.stack([[1, magnitude], + [0, 1]]).astype(np.float32) + elif direction == 'vertical': + shear_matrix = np.stack([[1, 0], [magnitude, + 1]]).astype(np.float32) + for poly_per_obj in self.masks: + sheared_poly = [] + for p in poly_per_obj: + p = np.stack([p[0::2], p[1::2]], axis=0) # [2, n] + new_coords = np.matmul(shear_matrix, p) # [2, n] + new_coords[0, :] = np.clip(new_coords[0, :], 0, + out_shape[1]) + new_coords[1, :] = np.clip(new_coords[1, :], 0, + out_shape[0]) + sheared_poly.append( + new_coords.transpose((1, 0)).reshape(-1)) + sheared_masks.append(sheared_poly) + sheared_masks = PolygonMasks(sheared_masks, *out_shape) + return sheared_masks + + def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0): + """See :func:`BaseInstanceMasks.rotate`.""" + if len(self.masks) == 0: + rotated_masks = PolygonMasks([], *out_shape) + else: + rotated_masks = [] + rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale) + for poly_per_obj in self.masks: + rotated_poly = [] + for p in poly_per_obj: + p = p.copy() + coords = np.stack([p[0::2], p[1::2]], axis=1) # [n, 2] + # pad 1 to convert from format [x, y] to homogeneous + # coordinates format [x, y, 1] + coords = np.concatenate( + (coords, np.ones((coords.shape[0], 1), coords.dtype)), + axis=1) # [n, 3] + rotated_coords = np.matmul( + rotate_matrix[None, :, :], + coords[:, :, None])[..., 0] # [n, 2, 1] -> [n, 2] + rotated_coords[:, 0] = np.clip(rotated_coords[:, 0], 0, + out_shape[1]) + rotated_coords[:, 1] = np.clip(rotated_coords[:, 1], 0, + out_shape[0]) + rotated_poly.append(rotated_coords.reshape(-1)) + rotated_masks.append(rotated_poly) + rotated_masks = PolygonMasks(rotated_masks, *out_shape) + return rotated_masks + + def to_bitmap(self): + """convert polygon masks to bitmap masks.""" + bitmap_masks = self.to_ndarray() + return BitmapMasks(bitmap_masks, self.height, self.width) + + @property + def areas(self): + """Compute areas of masks. + + This func is modified from `detectron2 + `_. + The function only works with Polygons using the shoelace formula. + + Return: + ndarray: areas of each instance + """ # noqa: W501 + area = [] + for polygons_per_obj in self.masks: + area_per_obj = 0 + for p in polygons_per_obj: + area_per_obj += self._polygon_area(p[0::2], p[1::2]) + area.append(area_per_obj) + return np.asarray(area) + + def _polygon_area(self, x, y): + """Compute the area of a component of a polygon. 
+
+        Using the shoelace formula:
+        https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
+
+        Args:
+            x (ndarray): x coordinates of the component
+            y (ndarray): y coordinates of the component
+
+        Return:
+            float: the area of the component
+        """  # noqa: 501
+        return 0.5 * np.abs(
+            np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
+
+    def to_ndarray(self):
+        """Convert masks to the format of ndarray."""
+        if len(self.masks) == 0:
+            return np.empty((0, self.height, self.width), dtype=np.uint8)
+        bitmap_masks = []
+        for poly_per_obj in self.masks:
+            bitmap_masks.append(
+                polygon_to_bitmap(poly_per_obj, self.height, self.width))
+        return np.stack(bitmap_masks)
+
+    def to_tensor(self, dtype, device):
+        """See :func:`BaseInstanceMasks.to_tensor`."""
+        if len(self.masks) == 0:
+            return torch.empty((0, self.height, self.width),
+                               dtype=dtype,
+                               device=device)
+        ndarray_masks = self.to_ndarray()
+        return torch.tensor(ndarray_masks, dtype=dtype, device=device)
+
+    @classmethod
+    def random(cls,
+               num_masks=3,
+               height=32,
+               width=32,
+               n_verts=5,
+               dtype=np.float32,
+               rng=None):
+        """Generate random polygon masks for demo / testing purposes.
+
+        Adapted from [1]_
+
+        References:
+            .. [1] https://gitlab.kitware.com/computer-vision/kwimage/-/blob/928cae35ca8/kwimage/structs/polygon.py#L379  # noqa: E501
+
+        Example:
+            >>> from dbnet_det.core.mask.structures import PolygonMasks
+            >>> self = PolygonMasks.random()
+            >>> print('self = {}'.format(self))
+        """
+        from dbnet_det.utils.util_random import ensure_rng
+        rng = ensure_rng(rng)
+
+        def _gen_polygon(n, irregularity, spikeyness):
+            """Creates the polygon by sampling points on a circle around the
+            centre. Random noise is added by varying the angular spacing
+            between sequential points, and by varying the radial distance of
+            each point from the centre.
+
+            Based on original code by Mike Ounsworth
+
+            Args:
+                n (int): number of vertices
+                irregularity (float): [0,1] indicating how much variance there
+                    is in the angular spacing of vertices. [0,1] will map to
+                    [0, 2pi/numberOfVerts]
+                spikeyness (float): [0,1] indicating how much variance there is
+                    in each vertex from the circle of radius aveRadius. [0,1]
+                    will map to [0, aveRadius]
+
+            Returns:
+                a list of vertices, in CCW order.
+ """ + from scipy.stats import truncnorm + + # Generate around the unit circle + cx, cy = (0.0, 0.0) + radius = 1 + + tau = np.pi * 2 + + irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / n + spikeyness = np.clip(spikeyness, 1e-9, 1) + + # generate n angle steps + lower = (tau / n) - irregularity + upper = (tau / n) + irregularity + angle_steps = rng.uniform(lower, upper, n) + + # normalize the steps so that point 0 and point n+1 are the same + k = angle_steps.sum() / (2 * np.pi) + angles = (angle_steps / k).cumsum() + rng.uniform(0, tau) + + # Convert high and low values to be wrt the standard normal range + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html + low = 0 + high = 2 * radius + mean = radius + std = spikeyness + a = (low - mean) / std + b = (high - mean) / std + tnorm = truncnorm(a=a, b=b, loc=mean, scale=std) + + # now generate the points + radii = tnorm.rvs(n, random_state=rng) + x_pts = cx + radii * np.cos(angles) + y_pts = cy + radii * np.sin(angles) + + points = np.hstack([x_pts[:, None], y_pts[:, None]]) + + # Scale to 0-1 space + points = points - points.min(axis=0) + points = points / points.max(axis=0) + + # Randomly place within 0-1 space + points = points * (rng.rand() * .8 + .2) + min_pt = points.min(axis=0) + max_pt = points.max(axis=0) + + high = (1 - max_pt) + low = (0 - min_pt) + offset = (rng.rand(2) * (high - low)) + low + points = points + offset + return points + + def _order_vertices(verts): + """ + References: + https://stackoverflow.com/questions/1709283/how-can-i-sort-a-coordinate-list-for-a-rectangle-counterclockwise + """ + mlat = verts.T[0].sum() / len(verts) + mlng = verts.T[1].sum() / len(verts) + + tau = np.pi * 2 + angle = (np.arctan2(mlat - verts.T[0], verts.T[1] - mlng) + + tau) % tau + sortx = angle.argsort() + verts = verts.take(sortx, axis=0) + return verts + + # Generate a random exterior for each requested mask + masks = [] + for _ in range(num_masks): + exterior = _order_vertices(_gen_polygon(n_verts, 0.9, 0.9)) + exterior = (exterior * [(width, height)]).astype(dtype) + masks.append([exterior.ravel()]) + + self = cls(masks, height, width) + return self + + def get_bboxes(self): + num_masks = len(self) + boxes = np.zeros((num_masks, 4), dtype=np.float32) + for idx, poly_per_obj in enumerate(self.masks): + # simply use a number that is big enough for comparison with + # coordinates + xy_min = np.array([self.width * 2, self.height * 2], + dtype=np.float32) + xy_max = np.zeros(2, dtype=np.float32) + for p in poly_per_obj: + xy = np.array(p).reshape(-1, 2).astype(np.float32) + xy_min = np.minimum(xy_min, np.min(xy, axis=0)) + xy_max = np.maximum(xy_max, np.max(xy, axis=0)) + boxes[idx, :2] = xy_min + boxes[idx, 2:] = xy_max + + return boxes + + +def polygon_to_bitmap(polygons, height, width): + """Convert masks from the form of polygons to bitmaps. + + Args: + polygons (list[ndarray]): masks in polygon representation + height (int): mask height + width (int): mask width + + Return: + ndarray: the converted masks in bitmap representation + """ + rles = maskUtils.frPyObjects(polygons, height, width) + rle = maskUtils.merge(rles) + bitmap_mask = maskUtils.decode(rle).astype(np.bool) + return bitmap_mask + + +def bitmap_to_polygon(bitmap): + """Convert masks from the form of bitmaps to polygons. + + Args: + bitmap (ndarray): masks in bitmap representation. + + Return: + list[ndarray]: the converted mask in polygon representation. + bool: whether the mask has holes. 
+ """ + bitmap = np.ascontiguousarray(bitmap).astype(np.uint8) + # cv2.RETR_CCOMP: retrieves all of the contours and organizes them + # into a two-level hierarchy. At the top level, there are external + # boundaries of the components. At the second level, there are + # boundaries of the holes. If there is another contour inside a hole + # of a connected component, it is still put at the top level. + # cv2.CHAIN_APPROX_NONE: stores absolutely all the contour points. + outs = cv2.findContours(bitmap, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + contours = outs[-2] + hierarchy = outs[-1] + if hierarchy is None: + return [], False + # hierarchy[i]: 4 elements, for the indexes of next, previous, + # parent, or nested contours. If there is no corresponding contour, + # it will be -1. + with_hole = (hierarchy.reshape(-1, 4)[:, 3] >= 0).any() + contours = [c.reshape(-1, 2) for c in contours] + return contours, with_hole diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/utils.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/utils.py new file mode 100755 index 00000000..d38a95b4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/mask/utils.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dbnet_cv +import numpy as np +import pycocotools.mask as mask_util +import torch + + +def split_combined_polys(polys, poly_lens, polys_per_mask): + """Split the combined 1-D polys into masks. + + A mask is represented as a list of polys, and a poly is represented as + a 1-D array. In dataset, all masks are concatenated into a single 1-D + tensor. Here we need to split the tensor into original representations. + + Args: + polys (list): a list (length = image num) of 1-D tensors + poly_lens (list): a list (length = image num) of poly length + polys_per_mask (list): a list (length = image num) of poly number + of each mask + + Returns: + list: a list (length = image num) of list (length = mask num) of \ + list (length = poly num) of numpy array. + """ + mask_polys_list = [] + for img_id in range(len(polys)): + polys_single = polys[img_id] + polys_lens_single = poly_lens[img_id].tolist() + polys_per_mask_single = polys_per_mask[img_id].tolist() + + split_polys = dbnet_cv.slice_list(polys_single, polys_lens_single) + mask_polys = dbnet_cv.slice_list(split_polys, polys_per_mask_single) + mask_polys_list.append(mask_polys) + return mask_polys_list + + +# TODO: move this function to more proper place +def encode_mask_results(mask_results): + """Encode bitmap mask to RLE code. + + Args: + mask_results (list | tuple[list]): bitmap mask results. + In mask scoring rcnn, mask_results is a tuple of (segm_results, + segm_cls_score). + + Returns: + list | tuple: RLE encoded mask. + """ + if isinstance(mask_results, tuple): # mask scoring + cls_segms, cls_mask_scores = mask_results + else: + cls_segms = mask_results + num_classes = len(cls_segms) + encoded_mask_results = [[] for _ in range(num_classes)] + for i in range(len(cls_segms)): + for cls_segm in cls_segms[i]: + encoded_mask_results[i].append( + mask_util.encode( + np.array( + cls_segm[:, :, np.newaxis], order='F', + dtype='uint8'))[0]) # encoded with RLE + if isinstance(mask_results, tuple): + return encoded_mask_results, cls_mask_scores + else: + return encoded_mask_results + + +def mask2bbox(masks): + """Obtain tight bounding boxes of binary masks. + + Args: + masks (Tensor): Binary mask of shape (n, h, w). + + Returns: + Tensor: Bboxe with shape (n, 4) of \ + positive region in binary mask. 
+ """ + N = masks.shape[0] + bboxes = masks.new_zeros((N, 4), dtype=torch.float32) + x_any = torch.any(masks, dim=1) + y_any = torch.any(masks, dim=2) + for i in range(N): + x = torch.where(x_any[i, :])[0] + y = torch.where(y_any[i, :])[0] + if len(x) > 0 and len(y) > 0: + bboxes[i, :] = bboxes.new_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1]) + + return bboxes diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/__init__.py new file mode 100755 index 00000000..3f0d0708 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads, + reduce_mean, sync_random_seed) +from .misc import (center_of_mass, filter_scores_and_topk, flip_tensor, + generate_coordinate, mask2ndarray, multi_apply, + select_single_mlvl, unmap) + +__all__ = [ + 'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply', + 'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict', + 'center_of_mass', 'generate_coordinate', 'select_single_mlvl', + 'filter_scores_and_topk', 'sync_random_seed' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/dist_utils.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/dist_utils.py new file mode 100755 index 00000000..3681d000 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/dist_utils.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import pickle +import warnings +from collections import OrderedDict + +import numpy as np +import torch +import torch.distributed as dist +from dbnet_cv.runner import OptimizerHook, get_dist_info +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + """Allreduce gradients. + + Args: + params (list[torch.Parameters]): List of parameters of a model + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. 
+ """ + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + world_size = dist.get_world_size() + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +class DistOptimizerHook(OptimizerHook): + """Deprecated optimizer hook for distributed training.""" + + def __init__(self, *args, **kwargs): + warnings.warn('"DistOptimizerHook" is deprecated, please switch to' + '"dbnet_cv.runner.OptimizerHook".') + super().__init__(*args, **kwargs) + + +def reduce_mean(tensor): + """"Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +def obj2tensor(pyobj, device='cuda'): + """Serialize picklable python object to tensor.""" + storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj)) + return torch.ByteTensor(storage).to(device=device) + + +def tensor2obj(tensor): + """Deserialize tensor to picklable python object.""" + return pickle.loads(tensor.cpu().numpy().tobytes()) + + +@functools.lru_cache() +def _get_global_gloo_group(): + """Return a process group based on gloo backend, containing all the ranks + The result is cached.""" + if dist.get_backend() == 'nccl': + return dist.new_group(backend='gloo') + else: + return dist.group.WORLD + + +def all_reduce_dict(py_dict, op='sum', group=None, to_float=True): + """Apply all reduce function for python dict object. + + The code is modified from https://github.com/Megvii- + BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py. + + NOTE: make sure that py_dict in different ranks has the same keys and + the values should be in the same shape. Currently only supports + nccl backend. + + Args: + py_dict (dict): Dict to be applied all reduce op. + op (str): Operator, could be 'sum' or 'mean'. Default: 'sum' + group (:obj:`torch.distributed.group`, optional): Distributed group, + Default: None. + to_float (bool): Whether to convert all values of dict to float. + Default: True. + + Returns: + OrderedDict: reduced python dict object. + """ + warnings.warn( + 'group` is deprecated. Currently only supports NCCL backend.') + _, world_size = get_dist_info() + if world_size == 1: + return py_dict + + # all reduce logic across different devices. + py_key = list(py_dict.keys()) + if not isinstance(py_dict, OrderedDict): + py_key_tensor = obj2tensor(py_key) + dist.broadcast(py_key_tensor, src=0) + py_key = tensor2obj(py_key_tensor) + + tensor_shapes = [py_dict[k].shape for k in py_key] + tensor_numels = [py_dict[k].numel() for k in py_key] + + if to_float: + warnings.warn('Note: the "to_float" is True, you need to ' + 'ensure that the behavior is reasonable.') + flatten_tensor = torch.cat( + [py_dict[k].flatten().float() for k in py_key]) + else: + flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key]) + + dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM) + if op == 'mean': + flatten_tensor /= world_size + + split_tensors = [ + x.reshape(shape) for x, shape in zip( + torch.split(flatten_tensor, tensor_numels), tensor_shapes) + ] + out_dict = {k: v for k, v in zip(py_key, split_tensors)} + if isinstance(py_dict, OrderedDict): + out_dict = OrderedDict(out_dict) + return out_dict + + +def sync_random_seed(seed=None, device='cuda'): + """Make sure different ranks share the same seed. 
+ + All workers must call this function, otherwise it will deadlock. + This method is generally used in `DistributedSampler`, + because the seed should be identical across all processes + in the distributed group. + + In distributed sampling, different ranks should sample non-overlapped + data in the dataset. Therefore, this function is used to make sure that + each rank shuffles the data indices in the same order based + on the same seed. Then different ranks could use different indices + to select non-overlapped data from the same data list. + + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + + Returns: + int: Seed to be used. + """ + if seed is None: + seed = np.random.randint(2**31) + assert isinstance(seed, int) + + rank, world_size = get_dist_info() + + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/misc.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/misc.py new file mode 100755 index 00000000..14cb745e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/utils/misc.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial + +import numpy as np +import torch +from six.moves import map, zip + +from ..mask.structures import BitmapMasks, PolygonMasks + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. + + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def unmap(data, count, inds, fill=0): + """Unmap a subset of item (data) back to the original set of items (of size + count)""" + if data.dim() == 1: + ret = data.new_full((count, ), fill) + ret[inds.type(torch.bool)] = data + else: + new_size = (count, ) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds.type(torch.bool), :] = data + return ret + + +def mask2ndarray(mask): + """Convert Mask to ndarray.. + + Args: + mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or + torch.Tensor or np.ndarray): The mask to be converted. + + Returns: + np.ndarray: Ndarray mask of shape (n, h, w) that has been converted + """ + if isinstance(mask, (BitmapMasks, PolygonMasks)): + mask = mask.to_ndarray() + elif isinstance(mask, torch.Tensor): + mask = mask.detach().cpu().numpy() + elif not isinstance(mask, np.ndarray): + raise TypeError(f'Unsupported {type(mask)} data type') + return mask + + +def flip_tensor(src_tensor, flip_direction): + """flip tensor base on flip_direction. + + Args: + src_tensor (Tensor): input feature map, shape (B, C, H, W). + flip_direction (str): The flipping direction. Options are + 'horizontal', 'vertical', 'diagonal'. + + Returns: + out_tensor (Tensor): Flipped tensor. 
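+
+    Example:
+        >>> # Illustrative sketch with a 1x1x2x2 feature map.
+        >>> t = torch.arange(4.).reshape(1, 1, 2, 2)
+        >>> out = flip_tensor(t, 'horizontal')
+        >>> assert torch.equal(out, torch.tensor([[[[1., 0.], [3., 2.]]]]))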
+ """ + assert src_tensor.ndim == 4 + valid_directions = ['horizontal', 'vertical', 'diagonal'] + assert flip_direction in valid_directions + if flip_direction == 'horizontal': + out_tensor = torch.flip(src_tensor, [3]) + elif flip_direction == 'vertical': + out_tensor = torch.flip(src_tensor, [2]) + else: + out_tensor = torch.flip(src_tensor, [2, 3]) + return out_tensor + + +def select_single_mlvl(mlvl_tensors, batch_id, detach=True): + """Extract a multi-scale single image tensor from a multi-scale batch + tensor based on batch index. + + Note: The default value of detach is True, because the proposal gradient + needs to be detached during the training of the two-stage model. E.g + Cascade Mask R-CNN. + + Args: + mlvl_tensors (list[Tensor]): Batch tensor for all scale levels, + each is a 4D-tensor. + batch_id (int): Batch index. + detach (bool): Whether detach gradient. Default True. + + Returns: + list[Tensor]: Multi-scale single image tensor. + """ + assert isinstance(mlvl_tensors, (list, tuple)) + num_levels = len(mlvl_tensors) + + if detach: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id].detach() for i in range(num_levels) + ] + else: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id] for i in range(num_levels) + ] + return mlvl_tensor_list + + +def filter_scores_and_topk(scores, score_thr, topk, results=None): + """Filter results using score threshold and topk candidates. + + Args: + scores (Tensor): The scores, shape (num_bboxes, K). + score_thr (float): The score filter threshold. + topk (int): The number of topk candidates. + results (dict or list or Tensor, Optional): The results to + which the filtering rule is to be applied. The shape + of each item is (num_bboxes, N). + + Returns: + tuple: Filtered results + + - scores (Tensor): The scores after being filtered, \ + shape (num_bboxes_filtered, ). + - labels (Tensor): The class labels, shape \ + (num_bboxes_filtered, ). + - anchor_idxs (Tensor): The anchor indexes, shape \ + (num_bboxes_filtered, ). + - filtered_results (dict or list or Tensor, Optional): \ + The filtered results. The shape of each item is \ + (num_bboxes_filtered, N). + """ + valid_mask = scores > score_thr + scores = scores[valid_mask] + valid_idxs = torch.nonzero(valid_mask) + + num_topk = min(topk, valid_idxs.size(0)) + # torch.sort is actually faster than .topk (at least on GPUs) + scores, idxs = scores.sort(descending=True) + scores = scores[:num_topk] + topk_idxs = valid_idxs[idxs[:num_topk]] + keep_idxs, labels = topk_idxs.unbind(dim=1) + + filtered_results = None + if results is not None: + if isinstance(results, dict): + filtered_results = {k: v[keep_idxs] for k, v in results.items()} + elif isinstance(results, list): + filtered_results = [result[keep_idxs] for result in results] + elif isinstance(results, torch.Tensor): + filtered_results = results[keep_idxs] + else: + raise NotImplementedError(f'Only supports dict or list or Tensor, ' + f'but get {type(results)}.') + return scores, labels, keep_idxs, filtered_results + + +def center_of_mass(mask, esp=1e-6): + """Calculate the centroid coordinates of the mask. + + Args: + mask (Tensor): The mask to be calculated, shape (h, w). + esp (float): Avoid dividing by zero. Default: 1e-6. + + Returns: + tuple[Tensor]: the coordinates of the center point of the mask. + + - center_h (Tensor): the center point of the height. + - center_w (Tensor): the center point of the width. 
+ """ + h, w = mask.shape + grid_h = torch.arange(h, device=mask.device)[:, None] + grid_w = torch.arange(w, device=mask.device) + normalizer = mask.sum().float().clamp(min=esp) + center_h = (mask * grid_h).sum() / normalizer + center_w = (mask * grid_w).sum() / normalizer + return center_h, center_w + + +def generate_coordinate(featmap_sizes, device='cuda'): + """Generate the coordinate. + + Args: + featmap_sizes (tuple): The feature to be calculated, + of shape (N, C, W, H). + device (str): The device where the feature will be put on. + Returns: + coord_feat (Tensor): The coordinate feature, of shape (N, 2, W, H). + """ + + x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device) + y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device) + y, x = torch.meshgrid(y_range, x_range) + y = y.expand([featmap_sizes[0], 1, -1, -1]) + x = x.expand([featmap_sizes[0], 1, -1, -1]) + coord_feat = torch.cat([x, y], 1) + + return coord_feat diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/__init__.py new file mode 100755 index 00000000..cc1e786a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/__init__.py @@ -0,0 +1,9 @@ +# # Copyright (c) OpenMMLab. All rights reserved. +from .image import (color_val_matplotlib, imshow_det_bboxes, + imshow_gt_det_bboxes) +from .palette import get_palette, palette_val + +__all__ = [ + 'imshow_det_bboxes', 'imshow_gt_det_bboxes', 'color_val_matplotlib', + 'palette_val', 'get_palette' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/image.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/image.py new file mode 100755 index 00000000..b732d718 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/image.py @@ -0,0 +1,559 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import matplotlib.pyplot as plt +import dbnet_cv +import numpy as np +import pycocotools.mask as mask_util +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon + +from dbnet_det.core.evaluation.panoptic_utils import INSTANCE_OFFSET +from ..mask.structures import bitmap_to_polygon +from ..utils import mask2ndarray +from .palette import get_palette, palette_val + +__all__ = [ + 'color_val_matplotlib', 'draw_masks', 'draw_bboxes', 'draw_labels', + 'imshow_det_bboxes', 'imshow_gt_det_bboxes' +] + +EPS = 1e-2 + + +def color_val_matplotlib(color): + """Convert various input in BGR order to normalized RGB matplotlib color + tuples. + + Args: + color (:obj`Color` | str | tuple | int | ndarray): Color inputs. + + Returns: + tuple[float]: A tuple of 3 normalized floats indicating RGB channels. + """ + color = dbnet_cv.color_val(color) + color = [color / 255 for color in color[::-1]] + return tuple(color) + + +def _get_adaptive_scales(areas, min_area=800, max_area=30000): + """Get adaptive scales according to areas. + + The scale range is [0.5, 1.0]. When the area is less than + ``'min_area'``, the scale is 0.5 while the area is larger than + ``'max_area'``, the scale is 1.0. + + Args: + areas (ndarray): The areas of bboxes or masks with the + shape of (n, ). + min_area (int): Lower bound areas for adaptive scales. + Default: 800. + max_area (int): Upper bound areas for adaptive scales. + Default: 30000. + + Returns: + ndarray: The adaotive scales with the shape of (n, ). 
+ """ + scales = 0.5 + (areas - min_area) / (max_area - min_area) + scales = np.clip(scales, 0.5, 1.0) + return scales + + +def _get_bias_color(base, max_dist=30): + """Get different colors for each masks. + + Get different colors for each masks by adding a bias + color to the base category color. + Args: + base (ndarray): The base category color with the shape + of (3, ). + max_dist (int): The max distance of bias. Default: 30. + + Returns: + ndarray: The new color for a mask with the shape of (3, ). + """ + new_color = base + np.random.randint( + low=-max_dist, high=max_dist + 1, size=3) + return np.clip(new_color, 0, 255, new_color) + + +def draw_bboxes(ax, bboxes, color='g', alpha=0.8, thickness=2): + """Draw bounding boxes on the axes. + + Args: + ax (matplotlib.Axes): The input axes. + bboxes (ndarray): The input bounding boxes with the shape + of (n, 4). + color (list[tuple] | matplotlib.color): the colors for each + bounding boxes. + alpha (float): Transparency of bounding boxes. Default: 0.8. + thickness (int): Thickness of lines. Default: 2. + + Returns: + matplotlib.Axes: The result axes. + """ + polygons = [] + for i, bbox in enumerate(bboxes): + bbox_int = bbox.astype(np.int32) + poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]], + [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]] + np_poly = np.array(poly).reshape((4, 2)) + polygons.append(Polygon(np_poly)) + p = PatchCollection( + polygons, + facecolor='none', + edgecolors=color, + linewidths=thickness, + alpha=alpha) + ax.add_collection(p) + + return ax + + +def draw_labels(ax, + labels, + positions, + scores=None, + class_names=None, + color='w', + font_size=8, + scales=None, + horizontal_alignment='left'): + """Draw labels on the axes. + + Args: + ax (matplotlib.Axes): The input axes. + labels (ndarray): The labels with the shape of (n, ). + positions (ndarray): The positions to draw each labels. + scores (ndarray): The scores for each labels. + class_names (list[str]): The class names. + color (list[tuple] | matplotlib.color): The colors for labels. + font_size (int): Font size of texts. Default: 8. + scales (list[float]): Scales of texts. Default: None. + horizontal_alignment (str): The horizontal alignment method of + texts. Default: 'left'. + + Returns: + matplotlib.Axes: The result axes. + """ + for i, (pos, label) in enumerate(zip(positions, labels)): + label_text = class_names[ + label] if class_names is not None else f'class {label}' + if scores is not None: + label_text += f'|{scores[i]:.02f}' + text_color = color[i] if isinstance(color, list) else color + + font_size_mask = font_size if scales is None else font_size * scales[i] + ax.text( + pos[0], + pos[1], + f'{label_text}', + bbox={ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }, + color=text_color, + fontsize=font_size_mask, + verticalalignment='top', + horizontalalignment=horizontal_alignment) + + return ax + + +def draw_masks(ax, img, masks, color=None, with_edge=True, alpha=0.8): + """Draw masks on the image and their edges on the axes. + + Args: + ax (matplotlib.Axes): The input axes. + img (ndarray): The image with the shape of (3, h, w). + masks (ndarray): The masks with the shape of (n, h, w). + color (ndarray): The colors for each masks with the shape + of (n, 3). + with_edge (bool): Whether to draw edges. Default: True. + alpha (float): Transparency of bounding boxes. Default: 0.8. + + Returns: + matplotlib.Axes: The result axes. + ndarray: The result image. 
+ """ + taken_colors = set([0, 0, 0]) + if color is None: + random_colors = np.random.randint(0, 255, (masks.size(0), 3)) + color = [tuple(c) for c in random_colors] + color = np.array(color, dtype=np.uint8) + polygons = [] + for i, mask in enumerate(masks): + if with_edge: + contours, _ = bitmap_to_polygon(mask) + polygons += [Polygon(c) for c in contours] + + color_mask = color[i] + while tuple(color_mask) in taken_colors: + color_mask = _get_bias_color(color_mask) + taken_colors.add(tuple(color_mask)) + + mask = mask.astype(bool) + img[mask] = img[mask] * (1 - alpha) + color_mask * alpha + + p = PatchCollection( + polygons, facecolor='none', edgecolors='w', linewidths=1, alpha=0.8) + ax.add_collection(p) + + return ax, img + + +def imshow_det_bboxes(img, + bboxes=None, + labels=None, + segms=None, + class_names=None, + score_thr=0, + bbox_color='green', + text_color='green', + mask_color=None, + thickness=2, + font_size=8, + win_name='', + show=True, + wait_time=0, + out_file=None): + """Draw bboxes and class labels (with scores) on an image. + + Args: + img (str | ndarray): The image to be displayed. + bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or + (n, 5). + labels (ndarray): Labels of bboxes. + segms (ndarray | None): Masks, shaped (n,h,w) or None. + class_names (list[str]): Names of each classes. + score_thr (float): Minimum score of bboxes to be shown. Default: 0. + bbox_color (list[tuple] | tuple | str | None): Colors of bbox lines. + If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: 'green'. + text_color (list[tuple] | tuple | str | None): Colors of texts. + If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: 'green'. + mask_color (list[tuple] | tuple | str | None, optional): Colors of + masks. If a single color is given, it will be applied to all + classes. The tuple of color should be in RGB order. + Default: None. + thickness (int): Thickness of lines. Default: 2. + font_size (int): Font size of texts. Default: 13. + show (bool): Whether to show the image. Default: True. + win_name (str): The window name. Default: ''. + wait_time (float): Value of waitKey param. Default: 0. + out_file (str, optional): The filename to write the image. + Default: None. + + Returns: + ndarray: The image with bboxes drawn on it. + """ + assert bboxes is None or bboxes.ndim == 2, \ + f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.' + assert labels.ndim == 1, \ + f' labels ndim should be 1, but its ndim is {labels.ndim}.' + assert bboxes is None or bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \ + f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.' + assert bboxes is None or bboxes.shape[0] <= labels.shape[0], \ + 'labels.shape[0] should not be less than bboxes.shape[0].' + assert segms is None or segms.shape[0] == labels.shape[0], \ + 'segms.shape[0] and labels.shape[0] should have the same length.' + assert segms is not None or bboxes is not None, \ + 'segms and bboxes should not be None at the same time.' + + img = dbnet_cv.imread(img).astype(np.uint8) + + if score_thr > 0: + assert bboxes is not None and bboxes.shape[1] == 5 + scores = bboxes[:, -1] + inds = scores > score_thr + bboxes = bboxes[inds, :] + labels = labels[inds] + if segms is not None: + segms = segms[inds, ...] 
+ + img = dbnet_cv.bgr2rgb(img) + width, height = img.shape[1], img.shape[0] + img = np.ascontiguousarray(img) + + fig = plt.figure(win_name, frameon=False) + plt.title(win_name) + canvas = fig.canvas + dpi = fig.get_dpi() + # add a small EPS to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi) + + # remove white edges by set subplot margin + plt.subplots_adjust(left=0, right=1, bottom=0, top=1) + ax = plt.gca() + ax.axis('off') + + max_label = int(max(labels) if len(labels) > 0 else 0) + text_palette = palette_val(get_palette(text_color, max_label + 1)) + text_colors = [text_palette[label] for label in labels] + + num_bboxes = 0 + if bboxes is not None: + num_bboxes = bboxes.shape[0] + bbox_palette = palette_val(get_palette(bbox_color, max_label + 1)) + colors = [bbox_palette[label] for label in labels[:num_bboxes]] + draw_bboxes(ax, bboxes, colors, alpha=0.8, thickness=thickness) + + horizontal_alignment = 'left' + positions = bboxes[:, :2].astype(np.int32) + thickness + areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0]) + scales = _get_adaptive_scales(areas) + scores = bboxes[:, 4] if bboxes.shape[1] == 5 else None + draw_labels( + ax, + labels[:num_bboxes], + positions, + scores=scores, + class_names=class_names, + color=text_colors, + font_size=font_size, + scales=scales, + horizontal_alignment=horizontal_alignment) + + if segms is not None: + mask_palette = get_palette(mask_color, max_label + 1) + colors = [mask_palette[label] for label in labels] + colors = np.array(colors, dtype=np.uint8) + draw_masks(ax, img, segms, colors, with_edge=True) + + if num_bboxes < segms.shape[0]: + segms = segms[num_bboxes:] + horizontal_alignment = 'center' + areas = [] + positions = [] + for mask in segms: + _, _, stats, centroids = cv2.connectedComponentsWithStats( + mask.astype(np.uint8), connectivity=8) + largest_id = np.argmax(stats[1:, -1]) + 1 + positions.append(centroids[largest_id]) + areas.append(stats[largest_id, -1]) + areas = np.stack(areas, axis=0) + scales = _get_adaptive_scales(areas) + draw_labels( + ax, + labels[num_bboxes:], + positions, + class_names=class_names, + color=text_colors, + font_size=font_size, + scales=scales, + horizontal_alignment=horizontal_alignment) + + plt.imshow(img) + + stream, _ = canvas.print_to_buffer() + buffer = np.frombuffer(stream, dtype='uint8') + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + img = rgb.astype('uint8') + img = dbnet_cv.rgb2bgr(img) + + if show: + # We do not use cv2 for display because in some cases, opencv will + # conflict with Qt, it will output a warning: Current thread + # is not the object's thread. You can refer to + # https://github.com/opencv/opencv-python/issues/46 for details + if wait_time == 0: + plt.show() + else: + plt.show(block=False) + plt.pause(wait_time) + if out_file is not None: + dbnet_cv.imwrite(img, out_file) + + plt.close() + + return img + + +def imshow_gt_det_bboxes(img, + annotation, + result, + class_names=None, + score_thr=0, + gt_bbox_color=(61, 102, 255), + gt_text_color=(200, 200, 200), + gt_mask_color=(61, 102, 255), + det_bbox_color=(241, 101, 72), + det_text_color=(200, 200, 200), + det_mask_color=(241, 101, 72), + thickness=2, + font_size=13, + win_name='', + show=True, + wait_time=0, + out_file=None, + overlay_gt_pred=True): + """General visualization GT and result function. 
+ + Args: + img (str | ndarray): The image to be displayed. + annotation (dict): Ground truth annotations where contain keys of + 'gt_bboxes' and 'gt_labels' or 'gt_masks'. + result (tuple[list] | list): The detection result, can be either + (bbox, segm) or just bbox. + class_names (list[str]): Names of each classes. + score_thr (float): Minimum score of bboxes to be shown. Default: 0. + gt_bbox_color (list[tuple] | tuple | str | None): Colors of bbox lines. + If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: (61, 102, 255). + gt_text_color (list[tuple] | tuple | str | None): Colors of texts. + If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: (200, 200, 200). + gt_mask_color (list[tuple] | tuple | str | None, optional): Colors of + masks. If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: (61, 102, 255). + det_bbox_color (list[tuple] | tuple | str | None):Colors of bbox lines. + If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: (241, 101, 72). + det_text_color (list[tuple] | tuple | str | None):Colors of texts. + If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: (200, 200, 200). + det_mask_color (list[tuple] | tuple | str | None, optional): Color of + masks. If a single color is given, it will be applied to all classes. + The tuple of color should be in RGB order. Default: (241, 101, 72). + thickness (int): Thickness of lines. Default: 2. + font_size (int): Font size of texts. Default: 13. + win_name (str): The window name. Default: ''. + show (bool): Whether to show the image. Default: True. + wait_time (float): Value of waitKey param. Default: 0. + out_file (str, optional): The filename to write the image. + Default: None. + overlay_gt_pred (bool): Whether to plot gts and predictions on the + same image. If False, predictions and gts will be plotted on two same + image which will be concatenated in vertical direction. The image + above is drawn with gt, and the image below is drawn with the + prediction result. Default: True. + + Returns: + ndarray: The image with bboxes or masks drawn on it. 
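+
+    Example:
+        >>> # Hedged usage sketch; 'demo.jpg' and the annotation contents are
+        >>> # placeholders rather than files shipped with this repository.
+        >>> # annotation = dict(gt_bboxes=gt_bboxes, gt_labels=gt_labels)
+        >>> # img = imshow_gt_det_bboxes('demo.jpg', annotation, result,
+        >>> #                            class_names=['text'], show=False)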
+ """ + assert 'gt_bboxes' in annotation + assert 'gt_labels' in annotation + assert isinstance(result, (tuple, list, dict)), 'Expected ' \ + f'tuple or list or dict, but get {type(result)}' + + gt_bboxes = annotation['gt_bboxes'] + gt_labels = annotation['gt_labels'] + gt_masks = annotation.get('gt_masks', None) + if gt_masks is not None: + gt_masks = mask2ndarray(gt_masks) + + gt_seg = annotation.get('gt_semantic_seg', None) + if gt_seg is not None: + pad_value = 255 # the padding value of gt_seg + sem_labels = np.unique(gt_seg) + all_labels = np.concatenate((gt_labels, sem_labels), axis=0) + all_labels, counts = np.unique(all_labels, return_counts=True) + stuff_labels = all_labels[np.logical_and(counts < 2, + all_labels != pad_value)] + stuff_masks = gt_seg[None] == stuff_labels[:, None, None] + gt_labels = np.concatenate((gt_labels, stuff_labels), axis=0) + gt_masks = np.concatenate((gt_masks, stuff_masks.astype(np.uint8)), + axis=0) + # If you need to show the bounding boxes, + # please comment the following line + # gt_bboxes = None + + img = dbnet_cv.imread(img) + + img_with_gt = imshow_det_bboxes( + img, + gt_bboxes, + gt_labels, + gt_masks, + class_names=class_names, + bbox_color=gt_bbox_color, + text_color=gt_text_color, + mask_color=gt_mask_color, + thickness=thickness, + font_size=font_size, + win_name=win_name, + show=False) + + if not isinstance(result, dict): + if isinstance(result, tuple): + bbox_result, segm_result = result + if isinstance(segm_result, tuple): + segm_result = segm_result[0] # ms rcnn + else: + bbox_result, segm_result = result, None + + bboxes = np.vstack(bbox_result) + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) + + segms = None + if segm_result is not None and len(labels) > 0: # non empty + segms = dbnet_cv.concat_list(segm_result) + segms = mask_util.decode(segms) + segms = segms.transpose(2, 0, 1) + else: + assert class_names is not None, 'We need to know the number ' \ + 'of classes.' 
+ VOID = len(class_names) + bboxes = None + pan_results = result['pan_results'] + # keep objects ahead + ids = np.unique(pan_results)[::-1] + legal_indices = ids != VOID + ids = ids[legal_indices] + labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) + segms = (pan_results[None] == ids[:, None, None]) + + if overlay_gt_pred: + img = imshow_det_bboxes( + img_with_gt, + bboxes, + labels, + segms=segms, + class_names=class_names, + score_thr=score_thr, + bbox_color=det_bbox_color, + text_color=det_text_color, + mask_color=det_mask_color, + thickness=thickness, + font_size=font_size, + win_name=win_name, + show=show, + wait_time=wait_time, + out_file=out_file) + else: + img_with_det = imshow_det_bboxes( + img, + bboxes, + labels, + segms=segms, + class_names=class_names, + score_thr=score_thr, + bbox_color=det_bbox_color, + text_color=det_text_color, + mask_color=det_mask_color, + thickness=thickness, + font_size=font_size, + win_name=win_name, + show=False) + img = np.concatenate([img_with_gt, img_with_det], axis=0) + + plt.imshow(img) + if show: + if wait_time == 0: + plt.show() + else: + plt.show(block=False) + plt.pause(wait_time) + if out_file is not None: + dbnet_cv.imwrite(img, out_file) + plt.close() + + return img diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/palette.py b/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/palette.py new file mode 100755 index 00000000..b14a78dd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/core/visualization/palette.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dbnet_cv +import numpy as np + + +def palette_val(palette): + """Convert palette to matplotlib palette. + Args: + palette List[tuple]: A list of color tuples. + Returns: + List[tuple[float]]: A list of RGB matplotlib color tuples. + """ + new_palette = [] + for color in palette: + color = [c / 255 for c in color] + new_palette.append(tuple(color)) + return new_palette + + +def get_palette(palette, num_classes): + """Get palette from various inputs. + Args: + palette (list[tuple] | str | tuple | :obj:`Color`): palette inputs. + num_classes (int): the number of classes. + Returns: + list[tuple[int]]: A list of color tuples. + """ + assert isinstance(num_classes, int) + + if isinstance(palette, list): + dataset_palette = palette + elif isinstance(palette, tuple): + dataset_palette = [palette] * num_classes + elif palette == 'random' or palette is None: + state = np.random.get_state() + # random color + np.random.seed(42) + palette = np.random.randint(0, 256, size=(num_classes, 3)) + np.random.set_state(state) + dataset_palette = [tuple(c) for c in palette] + elif palette == 'coco': + from dbnet_det.datasets import CocoDataset, CocoPanopticDataset + dataset_palette = CocoDataset.PALETTE + if len(dataset_palette) < num_classes: + dataset_palette = CocoPanopticDataset.PALETTE + elif palette == 'citys': + from dbnet_det.datasets import CityscapesDataset + dataset_palette = CityscapesDataset.PALETTE + elif palette == 'voc': + from dbnet_det.datasets import VOCDataset + dataset_palette = VOCDataset.PALETTE + elif dbnet_cv.is_str(palette): + dataset_palette = [dbnet_cv.color_val(palette)[::-1]] * num_classes + else: + raise TypeError(f'Invalid type for palette: {type(palette)}') + + assert len(dataset_palette) >= num_classes, \ + 'The length of palette should not be less than `num_classes`.' 
+ return dataset_palette \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/__init__.py new file mode 100755 index 00000000..2fe2b278 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset +# from .cityscapes import CityscapesDataset +from .coco import CocoDataset +# from .coco_panoptic import CocoPanopticDataset +from .custom import CustomDataset +from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, + MultiImageMixDataset, RepeatDataset) +# from .deepfashion import DeepFashionDataset +# from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset +# from .openimages import OpenImagesChallengeDataset, OpenImagesDataset +from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler +from .utils import (NumClassCheckHook, get_loading_pipeline, + replace_ImageToTensor) +# from .voc import VOCDataset +# from .wider_face import WIDERFaceDataset +# from .xml_style import XMLDataset + +# __all__ = [ +# 'CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset', +# 'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset', +# 'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler', +# 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', +# 'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES', +# 'build_dataset', 'replace_ImageToTensor', 'get_loading_pipeline', +# 'NumClassCheckHook', 'CocoPanopticDataset', 'MultiImageMixDataset', +# 'OpenImagesDataset', 'OpenImagesChallengeDataset' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/__init__.py new file mode 100755 index 00000000..af855759 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .coco_api import COCO, COCOeval +from .panoptic_evaluation import pq_compute_multi_core, pq_compute_single_core + +__all__ = [ + 'COCO', 'COCOeval', 'pq_compute_multi_core', 'pq_compute_single_core' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/coco_api.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/coco_api.py new file mode 100755 index 00000000..eef6341e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/coco_api.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This file add snake case alias for coco api + +import warnings + +import pycocotools +from pycocotools.coco import COCO as _COCO +from pycocotools.cocoeval import COCOeval as _COCOeval + + +class COCO(_COCO): + """This class is almost the same as official pycocotools package. + + It implements some snake case function aliases. So that the COCO class has + the same interface as LVIS class. + """ + + def __init__(self, annotation_file=None): + if getattr(pycocotools, '__version__', '0') >= '12.0.2': + warnings.warn( + 'mmpycocotools is deprecated. 
Please install official pycocotools by "pip install pycocotools"', # noqa: E501 + UserWarning) + super().__init__(annotation_file=annotation_file) + self.img_ann_map = self.imgToAnns + self.cat_img_map = self.catToImgs + + def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None): + return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd) + + def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]): + return self.getCatIds(cat_names, sup_names, cat_ids) + + def get_img_ids(self, img_ids=[], cat_ids=[]): + return self.getImgIds(img_ids, cat_ids) + + def load_anns(self, ids): + return self.loadAnns(ids) + + def load_cats(self, ids): + return self.loadCats(ids) + + def load_imgs(self, ids): + return self.loadImgs(ids) + + +# just for the ease of import +COCOeval = _COCOeval diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/panoptic_evaluation.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/panoptic_evaluation.py new file mode 100755 index 00000000..fe4fd322 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/api_wrappers/panoptic_evaluation.py @@ -0,0 +1,228 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Copyright (c) 2018, Alexander Kirillov +# This file supports `file_client` for `panopticapi`, +# the source code is copied from `panopticapi`, +# only the way to load the gt images is modified. +import multiprocessing +import os + +import dbnet_cv +import numpy as np + +try: + from panopticapi.evaluation import OFFSET, VOID, PQStat + from panopticapi.utils import rgb2id +except ImportError: + PQStat = None + rgb2id = None + VOID = 0 + OFFSET = 256 * 256 * 256 + + +def pq_compute_single_core(proc_id, + annotation_set, + gt_folder, + pred_folder, + categories, + file_client=None, + print_log=False): + """The single core function to evaluate the metric of Panoptic + Segmentation. + + Same as the function with the same name in `panopticapi`. Only the function + to load the images is changed to use the file client. + + Args: + proc_id (int): The id of the mini process. + gt_folder (str): The path of the ground truth images. + pred_folder (str): The path of the prediction images. + categories (str): The categories of the dataset. + file_client (object): The file client of the dataset. If None, + the backend will be set to `disk`. + print_log (bool): Whether to print the log. Defaults to False. + """ + if PQStat is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + + if file_client is None: + file_client_args = dict(backend='disk') + file_client = dbnet_cv.FileClient(**file_client_args) + + pq_stat = PQStat() + + idx = 0 + for gt_ann, pred_ann in annotation_set: + if print_log and idx % 100 == 0: + print('Core: {}, {} from {} images processed'.format( + proc_id, idx, len(annotation_set))) + idx += 1 + # The gt images can be on the local disk or `ceph`, so we use + # file_client here. + img_bytes = file_client.get( + os.path.join(gt_folder, gt_ann['file_name'])) + pan_gt = dbnet_cv.imfrombytes(img_bytes, flag='color', channel_order='rgb') + pan_gt = rgb2id(pan_gt) + + # The predictions can only be on the local dist now. 
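+        # Editor's note: ``rgb2id`` (from panopticapi) folds each RGB pixel
+        # into a single integer id, id = R + 256 * G + 256 ** 2 * B, so both
+        # ``pan_gt`` above and ``pan_pred`` below become 2-D maps of segment
+        # ids rather than color images.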
+        pan_pred = dbnet_cv.imread(
+            os.path.join(pred_folder, pred_ann['file_name']),
+            flag='color',
+            channel_order='rgb')
+        pan_pred = rgb2id(pan_pred)
+
+        gt_segms = {el['id']: el for el in gt_ann['segments_info']}
+        pred_segms = {el['id']: el for el in pred_ann['segments_info']}
+
+        # predicted segments area calculation + prediction sanity checks
+        pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
+        labels, labels_cnt = np.unique(pan_pred, return_counts=True)
+        for label, label_cnt in zip(labels, labels_cnt):
+            if label not in pred_segms:
+                if label == VOID:
+                    continue
+                raise KeyError(
+                    'In the image with ID {} segment with ID {} is '
+                    'presented in PNG and not presented in JSON.'.format(
+                        gt_ann['image_id'], label))
+            pred_segms[label]['area'] = label_cnt
+            pred_labels_set.remove(label)
+            if pred_segms[label]['category_id'] not in categories:
+                raise KeyError(
+                    'In the image with ID {} segment with ID {} has '
+                    'unknown category_id {}.'.format(
+                        gt_ann['image_id'], label,
+                        pred_segms[label]['category_id']))
+        if len(pred_labels_set) != 0:
+            raise KeyError(
+                'In the image with ID {} the following segment IDs {} '
+                'are presented in JSON and not presented in PNG.'.format(
+                    gt_ann['image_id'], list(pred_labels_set)))
+
+        # confusion matrix calculation
+        pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(
+            np.uint64)
+        gt_pred_map = {}
+        labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
+        for label, intersection in zip(labels, labels_cnt):
+            gt_id = label // OFFSET
+            pred_id = label % OFFSET
+            gt_pred_map[(gt_id, pred_id)] = intersection
+
+        # count all matched pairs
+        gt_matched = set()
+        pred_matched = set()
+        for label_tuple, intersection in gt_pred_map.items():
+            gt_label, pred_label = label_tuple
+            if gt_label not in gt_segms:
+                continue
+            if pred_label not in pred_segms:
+                continue
+            if gt_segms[gt_label]['iscrowd'] == 1:
+                continue
+            if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][
+                    'category_id']:
+                continue
+
+            union = pred_segms[pred_label]['area'] + gt_segms[gt_label][
+                'area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
+            iou = intersection / union
+            if iou > 0.5:
+                pq_stat[gt_segms[gt_label]['category_id']].tp += 1
+                pq_stat[gt_segms[gt_label]['category_id']].iou += iou
+                gt_matched.add(gt_label)
+                pred_matched.add(pred_label)
+
+        # count false negatives
+        crowd_labels_dict = {}
+        for gt_label, gt_info in gt_segms.items():
+            if gt_label in gt_matched:
+                continue
+            # crowd segments are ignored
+            if gt_info['iscrowd'] == 1:
+                crowd_labels_dict[gt_info['category_id']] = gt_label
+                continue
+            pq_stat[gt_info['category_id']].fn += 1
+
+        # count false positives
+        for pred_label, pred_info in pred_segms.items():
+            if pred_label in pred_matched:
+                continue
+            # intersection of the segment with VOID
+            intersection = gt_pred_map.get((VOID, pred_label), 0)
+            # plus intersection with corresponding CROWD region if it exists
+            if pred_info['category_id'] in crowd_labels_dict:
+                intersection += gt_pred_map.get(
+                    (crowd_labels_dict[pred_info['category_id']], pred_label),
+                    0)
+            # predicted segment is ignored if more than half of
+            # the segment correspond to VOID and CROWD regions
+            if intersection / pred_info['area'] > 0.5:
+                continue
+            pq_stat[pred_info['category_id']].fp += 1
+
+    if print_log:
+        print('Core: {}, all {} images processed'.format(
+            proc_id, len(annotation_set)))
+    return pq_stat
+
+
+def pq_compute_multi_core(matched_annotations_list,
+                          gt_folder,
+                          pred_folder,
+                          categories,
+
file_client=None, + nproc=32): + """Evaluate the metrics of Panoptic Segmentation with multithreading. + + Same as the function with the same name in `panopticapi`. + + Args: + matched_annotations_list (list): The matched annotation list. Each + element is a tuple of annotations of the same image with the + format (gt_anns, pred_anns). + gt_folder (str): The path of the ground truth images. + pred_folder (str): The path of the prediction images. + categories (str): The categories of the dataset. + file_client (object): The file client of the dataset. If None, + the backend will be set to `disk`. + nproc (int): Number of processes for panoptic quality computing. + Defaults to 32. When `nproc` exceeds the number of cpu cores, + the number of cpu cores is used. + """ + if PQStat is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + + if file_client is None: + file_client_args = dict(backend='disk') + file_client = dbnet_cv.FileClient(**file_client_args) + + cpu_num = min(nproc, multiprocessing.cpu_count()) + + annotations_split = np.array_split(matched_annotations_list, cpu_num) + print('Number of cores: {}, images per core: {}'.format( + cpu_num, len(annotations_split[0]))) + workers = multiprocessing.Pool(processes=cpu_num) + processes = [] + for proc_id, annotation_set in enumerate(annotations_split): + p = workers.apply_async(pq_compute_single_core, + (proc_id, annotation_set, gt_folder, + pred_folder, categories, file_client)) + processes.append(p) + + # Close the process pool, otherwise it will lead to memory + # leaking problems. + workers.close() + workers.join() + + pq_stat = PQStat() + for p in processes: + pq_stat += p.get() + + return pq_stat diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/builder.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/builder.py new file mode 100755 index 00000000..6515adbc --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/builder.py @@ -0,0 +1,215 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import platform +import random +import warnings +from functools import partial + +import numpy as np +import torch +from dbnet_cv.parallel import collate +from dbnet_cv.runner import get_dist_info +from dbnet_cv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version +from torch.utils.data import DataLoader + +from .samplers import (ClassAwareSampler, DistributedGroupSampler, + DistributedSampler, GroupSampler, InfiniteBatchSampler, + InfiniteGroupBatchSampler) + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + base_soft_limit = rlimit[0] + hard_limit = rlimit[1] + soft_limit = min(max(4096, base_soft_limit), hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + +DATASETS = Registry('dataset') +PIPELINES = Registry('pipeline') + + +def _concat_dataset(cfg, default_args=None): + from .dataset_wrappers import ConcatDataset + ann_files = cfg['ann_file'] + img_prefixes = cfg.get('img_prefix', None) + seg_prefixes = cfg.get('seg_prefix', None) + proposal_files = cfg.get('proposal_file', None) + separate_eval = cfg.get('separate_eval', True) + + datasets = [] + num_dset = len(ann_files) + for i in range(num_dset): + data_cfg = copy.deepcopy(cfg) + # pop 'separate_eval' since it is not a valid key for common datasets. 
+ if 'separate_eval' in data_cfg: + data_cfg.pop('separate_eval') + data_cfg['ann_file'] = ann_files[i] + if isinstance(img_prefixes, (list, tuple)): + data_cfg['img_prefix'] = img_prefixes[i] + if isinstance(seg_prefixes, (list, tuple)): + data_cfg['seg_prefix'] = seg_prefixes[i] + if isinstance(proposal_files, (list, tuple)): + data_cfg['proposal_file'] = proposal_files[i] + datasets.append(build_dataset(data_cfg, default_args)) + + return ConcatDataset(datasets, separate_eval) + + +def build_dataset(cfg, default_args=None): + from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, + MultiImageMixDataset, RepeatDataset) + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'ConcatDataset': + dataset = ConcatDataset( + [build_dataset(c, default_args) for c in cfg['datasets']], + cfg.get('separate_eval', True)) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset( + build_dataset(cfg['dataset'], default_args), cfg['times']) + elif cfg['type'] == 'ClassBalancedDataset': + dataset = ClassBalancedDataset( + build_dataset(cfg['dataset'], default_args), cfg['oversample_thr']) + elif cfg['type'] == 'MultiImageMixDataset': + cp_cfg = copy.deepcopy(cfg) + cp_cfg['dataset'] = build_dataset(cp_cfg['dataset']) + cp_cfg.pop('type') + dataset = MultiImageMixDataset(**cp_cfg) + elif isinstance(cfg.get('ann_file'), (list, tuple)): + dataset = _concat_dataset(cfg, default_args) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + + return dataset + + +def build_dataloader(dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + runner_type='EpochBasedRunner', + persistent_workers=False, + class_aware_sampler=None, + **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + seed (int, Optional): Seed to be used. Default: None. + runner_type (str): Type of runner. Default: `EpochBasedRunner` + persistent_workers (bool): If True, the data loader will not shutdown + the worker processes after a dataset has been consumed once. + This allows to maintain the workers `Dataset` instances alive. + This argument is only valid when PyTorch>=1.7.0. Default: False. + class_aware_sampler (dict): Whether to use `ClassAwareSampler` + during training. Default: None. + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + + if dist: + # When model is :obj:`DistributedDataParallel`, + # `batch_size` of :obj:`dataloader` is the + # number of training samples on each GPU. 
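+        # Editor's note (illustrative numbers): with 8 GPUs and
+        # samples_per_gpu=2, each process's dataloader yields batches of 2
+        # and the effective global batch size is 16.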
+ batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + # When model is obj:`DataParallel` + # the batch size is samples on all the GPUS + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + if runner_type == 'IterBasedRunner': + # this is a batch sampler, which can yield + # a mini-batch indices each time. + # it can be used in both `DataParallel` and + # `DistributedDataParallel` + if shuffle: + batch_sampler = InfiniteGroupBatchSampler( + dataset, batch_size, world_size, rank, seed=seed) + else: + batch_sampler = InfiniteBatchSampler( + dataset, + batch_size, + world_size, + rank, + seed=seed, + shuffle=False) + batch_size = 1 + sampler = None + else: + if class_aware_sampler is not None: + # ClassAwareSampler can be used in both distributed and + # non-distributed training. + num_sample_class = class_aware_sampler.get('num_sample_class', 1) + sampler = ClassAwareSampler( + dataset, + samples_per_gpu, + world_size, + rank, + seed=seed, + num_sample_class=num_sample_class) + elif dist: + # DistributedGroupSampler will definitely shuffle the data to + # satisfy that images on each GPU are in the same group + if shuffle: + sampler = DistributedGroupSampler( + dataset, samples_per_gpu, world_size, rank, seed=seed) + else: + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=False, seed=seed) + else: + sampler = GroupSampler(dataset, + samples_per_gpu) if shuffle else None + batch_sampler = None + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) >= digit_version('1.7.0')): + kwargs['persistent_workers'] = persistent_workers + elif persistent_workers is True: + warnings.warn('persistent_workers is invalid because your pytorch ' + 'version is lower than 1.7.0') + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + batch_sampler=batch_sampler, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=kwargs.pop('pin_memory', False), + worker_init_fn=init_fn, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/coco.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/coco.py new file mode 100755 index 00000000..ee69fc1a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/coco.py @@ -0,0 +1,649 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
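+# Editor's note: the sketch below is illustrative only and not part of the
+# upstream module; it shows how CocoDataset is typically referenced from a
+# config dict (paths and pipeline are placeholders):
+#
+#   data = dict(
+#       val=dict(
+#           type='CocoDataset',
+#           ann_file='data/coco/annotations/instances_val2017.json',
+#           img_prefix='data/coco/val2017/',
+#           pipeline=[...]))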
+import contextlib +import io +import itertools +import logging +import os.path as osp +import tempfile +import warnings +from collections import OrderedDict + +import dbnet_cv +import numpy as np +from dbnet_cv.utils import print_log +from terminaltables import AsciiTable + +from dbnet_det.core import eval_recalls +from .api_wrappers import COCO, COCOeval +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class CocoDataset(CustomDataset): + + CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') + + PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), + (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), + (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0), + (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255), + (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157), + (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), + (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182), + (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255), + (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255), + (134, 134, 103), (145, 148, 174), (255, 208, 186), + (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255), + (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105), + (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149), + (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205), + (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0), + (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88), + (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118), + (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15), + (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0), + (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122), + (191, 162, 208)] + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + + Returns: + list[dict]: Annotation info from COCO api. + """ + + self.coco = COCO(ann_file) + # The order of returned `cat_ids` will not + # change with the order of the CLASSES + self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES) + + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.img_ids = self.coco.get_img_ids() + data_infos = [] + total_ann_ids = [] + for i in self.img_ids: + info = self.coco.load_imgs([i])[0] + info['filename'] = info['file_name'] + data_infos.append(info) + ann_ids = self.coco.get_ann_ids(img_ids=[i]) + total_ann_ids.extend(ann_ids) + assert len(set(total_ann_ids)) == len( + total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!" 
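+        # Editor's note: each entry of ``data_infos`` is the raw COCO image
+        # record (at least 'file_name', 'height', 'width', 'id') plus the
+        # 'filename' alias added above for the loading pipeline.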
+ return data_infos + + def get_ann_info(self, idx): + """Get COCO annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + img_id = self.data_infos[idx]['id'] + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + ann_info = self.coco.load_anns(ann_ids) + return self._parse_ann_info(self.data_infos[idx], ann_info) + + def get_cat_ids(self, idx): + """Get COCO category ids by index. + + Args: + idx (int): Index of data. + + Returns: + list[int]: All categories in the image of specified index. + """ + + img_id = self.data_infos[idx]['id'] + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + ann_info = self.coco.load_anns(ann_ids) + return [ann['category_id'] for ann in ann_info] + + def _filter_imgs(self, min_size=32): + """Filter images too small or without ground truths.""" + valid_inds = [] + # obtain images that contain annotation + ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.coco.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_img_ids = [] + for i, img_info in enumerate(self.data_infos): + img_id = self.img_ids[i] + if self.filter_empty_gt and img_id not in ids_in_cat: + continue + if min(img_info['width'], img_info['height']) >= min_size: + valid_inds.append(i) + valid_img_ids.append(img_id) + self.img_ids = valid_img_ids + return valid_inds + + def _parse_ann_info(self, img_info, ann_info): + """Parse bbox and mask annotation. + + Args: + ann_info (list[dict]): Annotation info of an image. + with_mask (bool): Whether to parse mask annotations. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore,\ + labels, masks, seg_map. "masks" are raw annotations and not \ + decoded into binary masks. + """ + gt_bboxes = [] + gt_labels = [] + gt_bboxes_ignore = [] + gt_masks_ann = [] + for i, ann in enumerate(ann_info): + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(bbox) + else: + gt_bboxes.append(bbox) + gt_labels.append(self.cat2label[ann['category_id']]) + gt_masks_ann.append(ann.get('segmentation', None)) + + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + seg_map = img_info['filename'].replace('jpg', 'png') + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks=gt_masks_ann, + seg_map=seg_map) + + return ann + + def xyxy2xywh(self, bbox): + """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO + evaluation. + + Args: + bbox (numpy.ndarray): The bounding boxes, shape (4, ), in + ``xyxy`` order. 
+ + Returns: + list[float]: The converted bounding boxes, in ``xywh`` order. + """ + + _bbox = bbox.tolist() + return [ + _bbox[0], + _bbox[1], + _bbox[2] - _bbox[0], + _bbox[3] - _bbox[1], + ] + + def _proposal2json(self, results): + """Convert proposal results to COCO json style.""" + json_results = [] + for idx in range(len(self)): + img_id = self.img_ids[idx] + bboxes = results[idx] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = 1 + json_results.append(data) + return json_results + + def _det2json(self, results): + """Convert detection results to COCO json style.""" + json_results = [] + for idx in range(len(self)): + img_id = self.img_ids[idx] + result = results[idx] + for label in range(len(result)): + bboxes = result[label] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = self.cat_ids[label] + json_results.append(data) + return json_results + + def _segm2json(self, results): + """Convert instance segmentation results to COCO json style.""" + bbox_json_results = [] + segm_json_results = [] + for idx in range(len(self)): + img_id = self.img_ids[idx] + det, seg = results[idx] + for label in range(len(det)): + # bbox results + bboxes = det[label] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = self.cat_ids[label] + bbox_json_results.append(data) + + # segm results + # some detectors use different scores for bbox and mask + if isinstance(seg, tuple): + segms = seg[0][label] + mask_score = seg[1][label] + else: + segms = seg[label] + mask_score = [bbox[4] for bbox in bboxes] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(mask_score[i]) + data['category_id'] = self.cat_ids[label] + if isinstance(segms[i]['counts'], bytes): + segms[i]['counts'] = segms[i]['counts'].decode() + data['segmentation'] = segms[i] + segm_json_results.append(data) + return bbox_json_results, segm_json_results + + def results2json(self, results, outfile_prefix): + """Dump the detection results to a COCO style json file. + + There are 3 types of results: proposals, bbox predictions, mask + predictions, and they have different data types. This method will + automatically recognize the type, and dump them to json files. + + Args: + results (list[list | tuple | ndarray]): Testing results of the + dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.bbox.json", "somepath/xxx.segm.json", + "somepath/xxx.proposal.json". + + Returns: + dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \ + values are corresponding filenames. 
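+
+        Note (editor's illustrative addition): with ``outfile_prefix`` set to
+        ``'out/res'`` and bbox-only results, the returned dict would be
+        ``{'bbox': 'out/res.bbox.json', 'proposal': 'out/res.bbox.json'}``;
+        which keys appear depends on the result type as described above.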
+ """ + result_files = dict() + if isinstance(results[0], list): + json_results = self._det2json(results) + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + dbnet_cv.dump(json_results, result_files['bbox']) + elif isinstance(results[0], tuple): + json_results = self._segm2json(results) + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + result_files['segm'] = f'{outfile_prefix}.segm.json' + dbnet_cv.dump(json_results[0], result_files['bbox']) + dbnet_cv.dump(json_results[1], result_files['segm']) + elif isinstance(results[0], np.ndarray): + json_results = self._proposal2json(results) + result_files['proposal'] = f'{outfile_prefix}.proposal.json' + dbnet_cv.dump(json_results, result_files['proposal']) + else: + raise TypeError('invalid type of results') + return result_files + + def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None): + gt_bboxes = [] + for i in range(len(self.img_ids)): + ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i]) + ann_info = self.coco.load_anns(ann_ids) + if len(ann_info) == 0: + gt_bboxes.append(np.zeros((0, 4))) + continue + bboxes = [] + for ann in ann_info: + if ann.get('ignore', False) or ann['iscrowd']: + continue + x1, y1, w, h = ann['bbox'] + bboxes.append([x1, y1, x1 + w, y1 + h]) + bboxes = np.array(bboxes, dtype=np.float32) + if bboxes.shape[0] == 0: + bboxes = np.zeros((0, 4)) + gt_bboxes.append(bboxes) + + recalls = eval_recalls( + gt_bboxes, results, proposal_nums, iou_thrs, logger=logger) + ar = recalls.mean(axis=1) + return ar + + def format_results(self, results, jsonfile_prefix=None, **kwargs): + """Format the results to json (standard format for COCO evaluation). + + Args: + results (list[tuple | numpy.ndarray]): Testing results of the + dataset. + jsonfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Default: None. + + Returns: + tuple: (result_files, tmp_dir), result_files is a dict containing \ + the json filepaths, tmp_dir is the temporal directory created \ + for saving json files when jsonfile_prefix is not specified. + """ + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(self), ( + 'The length of results is not equal to the dataset len: {} != {}'. + format(len(results), len(self))) + + if jsonfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + jsonfile_prefix = osp.join(tmp_dir.name, 'results') + else: + tmp_dir = None + result_files = self.results2json(results, jsonfile_prefix) + return result_files, tmp_dir + + def evaluate_det_segm(self, + results, + result_files, + coco_gt, + metrics, + logger=None, + classwise=False, + proposal_nums=(100, 300, 1000), + iou_thrs=None, + metric_items=None): + """Instance segmentation and object detection evaluation in COCO + protocol. + + Args: + results (list[list | tuple | dict]): Testing results of the + dataset. + result_files (dict[str, str]): a dict contains json file path. + coco_gt (COCO): COCO API object with ground truth annotation. + metric (str | list[str]): Metrics to be evaluated. Options are + 'bbox', 'segm', 'proposal', 'proposal_fast'. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + classwise (bool): Whether to evaluating the AP for each class. 
+ proposal_nums (Sequence[int]): Proposal number used for evaluating + recalls, such as recall@100, recall@1000. + Default: (100, 300, 1000). + iou_thrs (Sequence[float], optional): IoU threshold used for + evaluating recalls/mAPs. If set to a list, the average of all + IoUs will also be computed. If not specified, [0.50, 0.55, + 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used. + Default: None. + metric_items (list[str] | str, optional): Metric items that will + be returned. If not specified, ``['AR@100', 'AR@300', + 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be + used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75', + 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when + ``metric=='bbox' or metric=='segm'``. + + Returns: + dict[str, float]: COCO style evaluation metric. + """ + if iou_thrs is None: + iou_thrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + if metric_items is not None: + if not isinstance(metric_items, list): + metric_items = [metric_items] + + eval_results = OrderedDict() + for metric in metrics: + msg = f'Evaluating {metric}...' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + if metric == 'proposal_fast': + if isinstance(results[0], tuple): + raise KeyError('proposal_fast is not supported for ' + 'instance segmentation result.') + ar = self.fast_eval_recall( + results, proposal_nums, iou_thrs, logger='silent') + log_msg = [] + for i, num in enumerate(proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}') + log_msg = ''.join(log_msg) + print_log(log_msg, logger=logger) + continue + + iou_type = 'bbox' if metric == 'proposal' else metric + if metric not in result_files: + raise KeyError(f'{metric} is not in results') + try: + predictions = dbnet_cv.load(result_files[metric]) + if iou_type == 'segm': + # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa + # When evaluating mask AP, if the results contain bbox, + # cocoapi will use the box area instead of the mask area + # for calculating the instance area. Though the overall AP + # is not affected, this leads to different + # small/medium/large mask AP results. + for x in predictions: + x.pop('bbox') + warnings.simplefilter('once') + warnings.warn( + 'The key "bbox" is deleted for more accurate mask AP ' + 'of small/medium/large instances since v2.12.0. 
This ' + 'does not change the overall mAP calculation.', + UserWarning) + coco_det = coco_gt.loadRes(predictions) + except IndexError: + print_log( + 'The testing results of the whole dataset is empty.', + logger=logger, + level=logging.ERROR) + break + + cocoEval = COCOeval(coco_gt, coco_det, iou_type) + cocoEval.params.catIds = self.cat_ids + cocoEval.params.imgIds = self.img_ids + cocoEval.params.maxDets = list(proposal_nums) + cocoEval.params.iouThrs = iou_thrs + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + if metric_items is not None: + for metric_item in metric_items: + if metric_item not in coco_metric_names: + raise KeyError( + f'metric item {metric_item} is not supported') + + if metric == 'proposal': + cocoEval.params.useCats = 0 + cocoEval.evaluate() + cocoEval.accumulate() + + # Save coco summarize print information to logger + redirect_string = io.StringIO() + with contextlib.redirect_stdout(redirect_string): + cocoEval.summarize() + print_log('\n' + redirect_string.getvalue(), logger=logger) + + if metric_items is None: + metric_items = [ + 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', + 'AR_m@1000', 'AR_l@1000' + ] + + for item in metric_items: + val = float( + f'{cocoEval.stats[coco_metric_names[item]]:.3f}') + eval_results[item] = val + else: + cocoEval.evaluate() + cocoEval.accumulate() + + # Save coco summarize print information to logger + redirect_string = io.StringIO() + with contextlib.redirect_stdout(redirect_string): + cocoEval.summarize() + print_log('\n' + redirect_string.getvalue(), logger=logger) + + if classwise: # Compute per-category AP + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = cocoEval.eval['precision'] + # precision: (iou, recall, cls, area range, max dets) + assert len(self.cat_ids) == precisions.shape[2] + + results_per_category = [] + for idx, catId in enumerate(self.cat_ids): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = self.coco.loadCats(catId)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + results_per_category.append( + (f'{nm["name"]}', f'{float(ap):0.3f}')) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list( + itertools.chain(*results_per_category)) + headers = ['category', 'AP'] * (num_columns // 2) + results_2d = itertools.zip_longest(*[ + results_flatten[i::num_columns] + for i in range(num_columns) + ]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + print_log('\n' + table.table, logger=logger) + + if metric_items is None: + metric_items = [ + 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' + ] + + for metric_item in metric_items: + key = f'{metric}_{metric_item}' + val = float( + f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}' + ) + eval_results[key] = val + ap = cocoEval.stats[:6] + eval_results[f'{metric}_mAP_copypaste'] = ( + f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' + f'{ap[4]:.3f} {ap[5]:.3f}') + + return eval_results + + def evaluate(self, + results, + metric='bbox', + logger=None, + jsonfile_prefix=None, + classwise=False, + proposal_nums=(100, 300, 1000), + iou_thrs=None, + metric_items=None): + """Evaluation in 
COCO protocol. + + Args: + results (list[list | tuple]): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. Options are + 'bbox', 'segm', 'proposal', 'proposal_fast'. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + jsonfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Default: None. + classwise (bool): Whether to evaluating the AP for each class. + proposal_nums (Sequence[int]): Proposal number used for evaluating + recalls, such as recall@100, recall@1000. + Default: (100, 300, 1000). + iou_thrs (Sequence[float], optional): IoU threshold used for + evaluating recalls/mAPs. If set to a list, the average of all + IoUs will also be computed. If not specified, [0.50, 0.55, + 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used. + Default: None. + metric_items (list[str] | str, optional): Metric items that will + be returned. If not specified, ``['AR@100', 'AR@300', + 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be + used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75', + 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when + ``metric=='bbox' or metric=='segm'``. + + Returns: + dict[str, float]: COCO style evaluation metric. + """ + + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast'] + for metric in metrics: + if metric not in allowed_metrics: + raise KeyError(f'metric {metric} is not supported') + + coco_gt = self.coco + self.cat_ids = coco_gt.get_cat_ids(cat_names=self.CLASSES) + + result_files, tmp_dir = self.format_results(results, jsonfile_prefix) + eval_results = self.evaluate_det_segm(results, result_files, coco_gt, + metrics, logger, classwise, + proposal_nums, iou_thrs, + metric_items) + + if tmp_dir is not None: + tmp_dir.cleanup() + return eval_results diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/custom.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/custom.py new file mode 100755 index 00000000..ce37d099 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/custom.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from collections import OrderedDict + +import dbnet_cv +import numpy as np +from dbnet_cv.utils import print_log +from terminaltables import AsciiTable +from torch.utils.data import Dataset + +from dbnet_det.core import eval_map, eval_recalls +from .builder import DATASETS +from .pipelines import Compose + + +@DATASETS.register_module() +class CustomDataset(Dataset): + """Custom dataset for detection. + + The annotation format is shown as follows. The `ann` field is optional for + testing. + + .. code-block:: none + + [ + { + 'filename': 'a.jpg', + 'width': 1280, + 'height': 720, + 'ann': { + 'bboxes': (n, 4) in (x1, y1, x2, y2) order. + 'labels': (n, ), + 'bboxes_ignore': (k, 4), (optional field) + 'labels_ignore': (k, 4) (optional field) + } + }, + ... + ] + + Args: + ann_file (str): Annotation file path. + pipeline (list[dict]): Processing pipeline. + classes (str | Sequence[str], optional): Specify classes to load. + If is None, ``cls.CLASSES`` will be used. Default: None. + data_root (str, optional): Data root for ``ann_file``, + ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. + test_mode (bool, optional): If set True, annotation will not be loaded. 
+ filter_empty_gt (bool, optional): If set true, images without bounding + boxes of the dataset's classes will be filtered out. This option + only works when `test_mode=False`, i.e., we never filter images + during tests. + """ + + CLASSES = None + + PALETTE = None + + def __init__(self, + ann_file, + pipeline, + classes=None, + data_root=None, + img_prefix='', + seg_prefix=None, + proposal_file=None, + test_mode=False, + filter_empty_gt=True, + file_client_args=dict(backend='disk')): + self.ann_file = ann_file + self.data_root = data_root + self.img_prefix = img_prefix + self.seg_prefix = seg_prefix + self.proposal_file = proposal_file + self.test_mode = test_mode + self.filter_empty_gt = filter_empty_gt + self.file_client = dbnet_cv.FileClient(**file_client_args) + self.CLASSES = self.get_classes(classes) + + # join paths if data_root is specified + if self.data_root is not None: + if not osp.isabs(self.ann_file): + self.ann_file = osp.join(self.data_root, self.ann_file) + if not (self.img_prefix is None or osp.isabs(self.img_prefix)): + self.img_prefix = osp.join(self.data_root, self.img_prefix) + if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)): + self.seg_prefix = osp.join(self.data_root, self.seg_prefix) + if not (self.proposal_file is None + or osp.isabs(self.proposal_file)): + self.proposal_file = osp.join(self.data_root, + self.proposal_file) + # load annotations (and proposals) + if hasattr(self.file_client, 'get_local_path'): + with self.file_client.get_local_path(self.ann_file) as local_path: + self.data_infos = self.load_annotations(local_path) + else: + warnings.warn( + 'The used DBNET_CV version does not have get_local_path. ' + f'We treat the {self.ann_file} as local paths and it ' + 'might cause errors if the path is not a local path. ' + 'Please use DBNET_CV>= 1.3.16 if you meet errors.') + self.data_infos = self.load_annotations(self.ann_file) + + if self.proposal_file is not None: + if hasattr(self.file_client, 'get_local_path'): + with self.file_client.get_local_path( + self.proposal_file) as local_path: + self.proposals = self.load_proposals(local_path) + else: + warnings.warn( + 'The used DBNET_CV version does not have get_local_path. ' + f'We treat the {self.ann_file} as local paths and it ' + 'might cause errors if the path is not a local path. ' + 'Please use DBNET_CV>= 1.3.16 if you meet errors.') + self.proposals = self.load_proposals(self.proposal_file) + else: + self.proposals = None + + # filter images too small and containing no annotations + if not test_mode: + valid_inds = self._filter_imgs() + self.data_infos = [self.data_infos[i] for i in valid_inds] + if self.proposals is not None: + self.proposals = [self.proposals[i] for i in valid_inds] + # set group flag for the sampler + self._set_group_flag() + + # processing pipeline + self.pipeline = Compose(pipeline) + + def __len__(self): + """Total number of samples of data.""" + return len(self.data_infos) + + def load_annotations(self, ann_file): + """Load annotation from annotation file.""" + return dbnet_cv.load(ann_file) + + def load_proposals(self, proposal_file): + """Load proposal from proposal file.""" + return dbnet_cv.load(proposal_file) + + def get_ann_info(self, idx): + """Get annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + return self.data_infos[idx]['ann'] + + def get_cat_ids(self, idx): + """Get category ids by index. + + Args: + idx (int): Index of data. 
+ + Returns: + list[int]: All categories in the image of specified index. + """ + + return self.data_infos[idx]['ann']['labels'].astype(np.int).tolist() + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['img_prefix'] = self.img_prefix + results['seg_prefix'] = self.seg_prefix + results['proposal_file'] = self.proposal_file + results['bbox_fields'] = [] + results['mask_fields'] = [] + results['seg_fields'] = [] + + def _filter_imgs(self, min_size=32): + """Filter images too small.""" + if self.filter_empty_gt: + warnings.warn( + 'CustomDataset does not support filtering empty gt images.') + valid_inds = [] + for i, img_info in enumerate(self.data_infos): + if min(img_info['width'], img_info['height']) >= min_size: + valid_inds.append(i) + return valid_inds + + def _set_group_flag(self): + """Set flag according to image aspect ratio. + + Images with aspect ratio greater than 1 will be set as group 1, + otherwise group 0. + """ + self.flag = np.zeros(len(self), dtype=np.uint8) + for i in range(len(self)): + img_info = self.data_infos[i] + if img_info['width'] / img_info['height'] > 1: + self.flag[i] = 1 + + def _rand_another(self, idx): + """Get another random index from the same group as the given index.""" + pool = np.where(self.flag == self.flag[idx])[0] + return np.random.choice(pool) + + def __getitem__(self, idx): + """Get training/test data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training/test data (with annotation if `test_mode` is set \ + True). + """ + + if self.test_mode: + return self.prepare_test_img(idx) + while True: + data = self.prepare_train_img(idx) + if data is None: + idx = self._rand_another(idx) + continue + return data + + def prepare_train_img(self, idx): + """Get training data and annotations after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys \ + introduced by pipeline. + """ + + img_info = self.data_infos[idx] + ann_info = self.get_ann_info(idx) + results = dict(img_info=img_info, ann_info=ann_info) + if self.proposals is not None: + results['proposals'] = self.proposals[idx] + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, idx): + """Get testing data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by \ + pipeline. + """ + + img_info = self.data_infos[idx] + results = dict(img_info=img_info) + if self.proposals is not None: + results['proposals'] = self.proposals[idx] + self.pre_pipeline(results) + return self.pipeline(results) + + @classmethod + def get_classes(cls, classes=None): + """Get class names of current dataset. + + Args: + classes (Sequence[str] | str | None): If classes is None, use + default CLASSES defined by builtin dataset. If classes is a + string, take it as a file name. The file contains the name of + classes where each line contains one class name. If classes is + a tuple or list, override the CLASSES defined by the dataset. + + Returns: + tuple[str] or list[str]: Names of categories of the dataset. 
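+
+        Example (editor's sketch; the file name is hypothetical):
+            ``get_classes(['cat', 'dog'])`` returns the list unchanged, while
+            ``get_classes('classes.txt')`` reads one class name per line via
+            ``dbnet_cv.list_from_file``.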
+ """ + if classes is None: + return cls.CLASSES + + if isinstance(classes, str): + # take it as a file path + class_names = dbnet_cv.list_from_file(classes) + elif isinstance(classes, (tuple, list)): + class_names = classes + else: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + return class_names + + def get_cat2imgs(self): + """Get a dict with class as key and img_ids as values, which will be + used in :class:`ClassAwareSampler`. + + Returns: + dict[list]: A dict of per-label image list, + the item of the dict indicates a label index, + corresponds to the image index that contains the label. + """ + if self.CLASSES is None: + raise ValueError('self.CLASSES can not be None') + # sort the label index + cat2imgs = {i: [] for i in range(len(self.CLASSES))} + for i in range(len(self)): + cat_ids = set(self.get_cat_ids(i)) + for cat in cat_ids: + cat2imgs[cat].append(i) + return cat2imgs + + def format_results(self, results, **kwargs): + """Place holder to format result to dataset specific output.""" + + def evaluate(self, + results, + metric='mAP', + logger=None, + proposal_nums=(100, 300, 1000), + iou_thr=0.5, + scale_ranges=None): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + proposal_nums (Sequence[int]): Proposal number used for evaluating + recalls, such as recall@100, recall@1000. + Default: (100, 300, 1000). + iou_thr (float | list[float]): IoU threshold. Default: 0.5. + scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP. + Default: None. + """ + + if not isinstance(metric, str): + assert len(metric) == 1 + metric = metric[0] + allowed_metrics = ['mAP', 'recall'] + if metric not in allowed_metrics: + raise KeyError(f'metric {metric} is not supported') + annotations = [self.get_ann_info(i) for i in range(len(self))] + eval_results = OrderedDict() + iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr + if metric == 'mAP': + assert isinstance(iou_thrs, list) + mean_aps = [] + for iou_thr in iou_thrs: + print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}') + mean_ap, _ = eval_map( + results, + annotations, + scale_ranges=scale_ranges, + iou_thr=iou_thr, + dataset=self.CLASSES, + logger=logger) + mean_aps.append(mean_ap) + eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3) + eval_results['mAP'] = sum(mean_aps) / len(mean_aps) + elif metric == 'recall': + gt_bboxes = [ann['bboxes'] for ann in annotations] + recalls = eval_recalls( + gt_bboxes, results, proposal_nums, iou_thr, logger=logger) + for i, num in enumerate(proposal_nums): + for j, iou in enumerate(iou_thrs): + eval_results[f'recall@{num}@{iou}'] = recalls[i, j] + if recalls.shape[1] > 1: + ar = recalls.mean(axis=1) + for i, num in enumerate(proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + return eval_results + + def __repr__(self): + """Print the number of instance number.""" + dataset_type = 'Test' if self.test_mode else 'Train' + result = (f'\n{self.__class__.__name__} {dataset_type} dataset ' + f'with number of images {len(self)}, ' + f'and instance counts: \n') + if self.CLASSES is None: + result += 'Category names are not provided. 
\n' + return result + instance_count = np.zeros(len(self.CLASSES) + 1).astype(int) + # count the instance number in each image + for idx in range(len(self)): + label = self.get_ann_info(idx)['labels'] + unique, counts = np.unique(label, return_counts=True) + if len(unique) > 0: + # add the occurrence number to each class + instance_count[unique] += counts + else: + # background is the last index + instance_count[-1] += 1 + # create a table with category count + table_data = [['category', 'count'] * 5] + row_data = [] + for cls, count in enumerate(instance_count): + if cls < len(self.CLASSES): + row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}'] + else: + # add the background number + row_data += ['-1 background', f'{count}'] + if len(row_data) == 10: + table_data.append(row_data) + row_data = [] + if len(row_data) >= 2: + if row_data[-1] == '0': + row_data = row_data[:-2] + if len(row_data) >= 2: + table_data.append([]) + table_data.append(row_data) + + table = AsciiTable(table_data) + result += table.table + return result diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/dataset_wrappers.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/dataset_wrappers.py new file mode 100755 index 00000000..4938ef8a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/dataset_wrappers.py @@ -0,0 +1,456 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import bisect +import collections +import copy +import math +from collections import defaultdict + +import numpy as np +from dbnet_cv.utils import build_from_cfg, print_log +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + +from .builder import DATASETS, PIPELINES +from .coco import CocoDataset + + +@DATASETS.register_module() +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + separate_eval (bool): Whether to evaluate the results + separately if it is used as validation dataset. + Defaults to True. + """ + + def __init__(self, datasets, separate_eval=True): + super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES + self.PALETTE = getattr(datasets[0], 'PALETTE', None) + self.separate_eval = separate_eval + if not separate_eval: + if any([isinstance(ds, CocoDataset) for ds in datasets]): + raise NotImplementedError( + 'Evaluating concatenated CocoDataset as a whole is not' + ' supported! Please set "separate_eval=True"') + elif len(set([type(ds) for ds in datasets])) != 1: + raise NotImplementedError( + 'All the datasets should have same types') + + if hasattr(datasets[0], 'flag'): + flags = [] + for i in range(0, len(datasets)): + flags.append(datasets[i].flag) + self.flag = np.concatenate(flags) + + def get_cat_ids(self, idx): + """Get category ids of concatenated dataset by index. + + Args: + idx (int): Index of data. + + Returns: + list[int]: All categories in the image of specified index. + """ + + if idx < 0: + if -idx > len(self): + raise ValueError( + 'absolute value of index should not exceed dataset length') + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx].get_cat_ids(sample_idx) + + def get_ann_info(self, idx): + """Get annotation of concatenated dataset by index. + + Args: + idx (int): Index of data. 
+ + Returns: + dict: Annotation info of specified index. + """ + + if idx < 0: + if -idx > len(self): + raise ValueError( + 'absolute value of index should not exceed dataset length') + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx].get_ann_info(sample_idx) + + def evaluate(self, results, logger=None, **kwargs): + """Evaluate the results. + + Args: + results (list[list | tuple]): Testing results of the dataset. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str: float]: AP results of the total dataset or each separate + dataset if `self.separate_eval=True`. + """ + assert len(results) == self.cumulative_sizes[-1], \ + ('Dataset and results have different sizes: ' + f'{self.cumulative_sizes[-1]} v.s. {len(results)}') + + # Check whether all the datasets support evaluation + for dataset in self.datasets: + assert hasattr(dataset, 'evaluate'), \ + f'{type(dataset)} does not implement evaluate function' + + if self.separate_eval: + dataset_idx = -1 + total_eval_results = dict() + for size, dataset in zip(self.cumulative_sizes, self.datasets): + start_idx = 0 if dataset_idx == -1 else \ + self.cumulative_sizes[dataset_idx] + end_idx = self.cumulative_sizes[dataset_idx + 1] + + results_per_dataset = results[start_idx:end_idx] + print_log( + f'\nEvaluateing {dataset.ann_file} with ' + f'{len(results_per_dataset)} images now', + logger=logger) + + eval_results_per_dataset = dataset.evaluate( + results_per_dataset, logger=logger, **kwargs) + dataset_idx += 1 + for k, v in eval_results_per_dataset.items(): + total_eval_results.update({f'{dataset_idx}_{k}': v}) + + return total_eval_results + elif any([isinstance(ds, CocoDataset) for ds in self.datasets]): + raise NotImplementedError( + 'Evaluating concatenated CocoDataset as a whole is not' + ' supported! Please set "separate_eval=True"') + elif len(set([type(ds) for ds in self.datasets])) != 1: + raise NotImplementedError( + 'All the datasets should have same types') + else: + original_data_infos = self.datasets[0].data_infos + self.datasets[0].data_infos = sum( + [dataset.data_infos for dataset in self.datasets], []) + eval_results = self.datasets[0].evaluate( + results, logger=logger, **kwargs) + self.datasets[0].data_infos = original_data_infos + return eval_results + + +@DATASETS.register_module() +class RepeatDataset: + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. + """ + + def __init__(self, dataset, times): + self.dataset = dataset + self.times = times + self.CLASSES = dataset.CLASSES + self.PALETTE = getattr(dataset, 'PALETTE', None) + if hasattr(self.dataset, 'flag'): + self.flag = np.tile(self.dataset.flag, times) + + self._ori_len = len(self.dataset) + + def __getitem__(self, idx): + return self.dataset[idx % self._ori_len] + + def get_cat_ids(self, idx): + """Get category ids of repeat dataset by index. + + Args: + idx (int): Index of data. + + Returns: + list[int]: All categories in the image of specified index. 
+ """ + + return self.dataset.get_cat_ids(idx % self._ori_len) + + def get_ann_info(self, idx): + """Get annotation of repeat dataset by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + return self.dataset.get_ann_info(idx % self._ori_len) + + def __len__(self): + """Length after repetition.""" + return self.times * self._ori_len + + +# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa +@DATASETS.register_module() +class ClassBalancedDataset: + """A wrapper of repeated dataset with repeat factor. + + Suitable for training on class imbalanced datasets like LVIS. Following + the sampling strategy in the `paper `_, + in each epoch, an image may appear multiple times based on its + "repeat factor". + The repeat factor for an image is a function of the frequency the rarest + category labeled in that image. The "frequency of category c" in [0, 1] + is defined by the fraction of images in the training set (without repeats) + in which category c appears. + The dataset needs to instantiate :func:`self.get_cat_ids` to support + ClassBalancedDataset. + + The repeat factor is computed as followed. + + 1. For each category c, compute the fraction # of images + that contain it: :math:`f(c)` + 2. For each category c, compute the category-level repeat factor: + :math:`r(c) = max(1, sqrt(t/f(c)))` + 3. For each image I, compute the image-level repeat factor: + :math:`r(I) = max_{c in I} r(c)` + + Args: + dataset (:obj:`CustomDataset`): The dataset to be repeated. + oversample_thr (float): frequency threshold below which data is + repeated. For categories with ``f_c >= oversample_thr``, there is + no oversampling. For categories with ``f_c < oversample_thr``, the + degree of oversampling following the square-root inverse frequency + heuristic above. + filter_empty_gt (bool, optional): If set true, images without bounding + boxes will not be oversampled. Otherwise, they will be categorized + as the pure background class and involved into the oversampling. + Default: True. + """ + + def __init__(self, dataset, oversample_thr, filter_empty_gt=True): + self.dataset = dataset + self.oversample_thr = oversample_thr + self.filter_empty_gt = filter_empty_gt + self.CLASSES = dataset.CLASSES + self.PALETTE = getattr(dataset, 'PALETTE', None) + + repeat_factors = self._get_repeat_factors(dataset, oversample_thr) + repeat_indices = [] + for dataset_idx, repeat_factor in enumerate(repeat_factors): + repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor)) + self.repeat_indices = repeat_indices + + flags = [] + if hasattr(self.dataset, 'flag'): + for flag, repeat_factor in zip(self.dataset.flag, repeat_factors): + flags.extend([flag] * int(math.ceil(repeat_factor))) + assert len(flags) == len(repeat_indices) + self.flag = np.asarray(flags, dtype=np.uint8) + + def _get_repeat_factors(self, dataset, repeat_thr): + """Get repeat factor for each images in the dataset. + + Args: + dataset (:obj:`CustomDataset`): The dataset + repeat_thr (float): The threshold of frequency. If an image + contains the categories whose frequency below the threshold, + it would be repeated. + + Returns: + list[float]: The repeat factors for each images in the dataset. + """ + + # 1. 
For each category c, compute the fraction # of images + # that contain it: f(c) + category_freq = defaultdict(int) + num_images = len(dataset) + for idx in range(num_images): + cat_ids = set(self.dataset.get_cat_ids(idx)) + if len(cat_ids) == 0 and not self.filter_empty_gt: + cat_ids = set([len(self.CLASSES)]) + for cat_id in cat_ids: + category_freq[cat_id] += 1 + for k, v in category_freq.items(): + category_freq[k] = v / num_images + + # 2. For each category c, compute the category-level repeat factor: + # r(c) = max(1, sqrt(t/f(c))) + category_repeat = { + cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq)) + for cat_id, cat_freq in category_freq.items() + } + + # 3. For each image I, compute the image-level repeat factor: + # r(I) = max_{c in I} r(c) + repeat_factors = [] + for idx in range(num_images): + cat_ids = set(self.dataset.get_cat_ids(idx)) + if len(cat_ids) == 0 and not self.filter_empty_gt: + cat_ids = set([len(self.CLASSES)]) + repeat_factor = 1 + if len(cat_ids) > 0: + repeat_factor = max( + {category_repeat[cat_id] + for cat_id in cat_ids}) + repeat_factors.append(repeat_factor) + + return repeat_factors + + def __getitem__(self, idx): + ori_index = self.repeat_indices[idx] + return self.dataset[ori_index] + + def get_ann_info(self, idx): + """Get annotation of dataset by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + ori_index = self.repeat_indices[idx] + return self.dataset.get_ann_info(ori_index) + + def __len__(self): + """Length after repetition.""" + return len(self.repeat_indices) + + +@DATASETS.register_module() +class MultiImageMixDataset: + """A wrapper of multiple images mixed dataset. + + Suitable for training on multiple images mixed data augmentation like + mosaic and mixup. For the augmentation pipeline of mixed image data, + the `get_indexes` method needs to be provided to obtain the image + indexes, and you can set `skip_flags` to change the pipeline running + process. At the same time, we provide the `dynamic_scale` parameter + to dynamically change the output image size. + + Args: + dataset (:obj:`CustomDataset`): The dataset to be mixed. + pipeline (Sequence[dict]): Sequence of transform object or + config dict to be composed. + dynamic_scale (tuple[int], optional): The image scale can be changed + dynamically. Default to None. It is deprecated. + skip_type_keys (list[str], optional): Sequence of type string to + be skip pipeline. Default to None. + max_refetch (int): The maximum number of retry iterations for getting + valid results from the pipeline. If the number of iterations is + greater than `max_refetch`, but results is still None, then the + iteration is terminated and raise the error. Default: 15. + """ + + def __init__(self, + dataset, + pipeline, + dynamic_scale=None, + skip_type_keys=None, + max_refetch=15): + if dynamic_scale is not None: + raise RuntimeError( + 'dynamic_scale is deprecated. 
Please use Resize pipeline '
+                'to achieve similar functions')
+        assert isinstance(pipeline, collections.abc.Sequence)
+        if skip_type_keys is not None:
+            assert all([
+                isinstance(skip_type_key, str)
+                for skip_type_key in skip_type_keys
+            ])
+        self._skip_type_keys = skip_type_keys
+
+        self.pipeline = []
+        self.pipeline_types = []
+        for transform in pipeline:
+            if isinstance(transform, dict):
+                self.pipeline_types.append(transform['type'])
+                transform = build_from_cfg(transform, PIPELINES)
+                self.pipeline.append(transform)
+            else:
+                raise TypeError('pipeline must be a dict')
+
+        self.dataset = dataset
+        self.CLASSES = dataset.CLASSES
+        self.PALETTE = getattr(dataset, 'PALETTE', None)
+        if hasattr(self.dataset, 'flag'):
+            self.flag = dataset.flag
+        self.num_samples = len(dataset)
+        self.max_refetch = max_refetch
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        results = copy.deepcopy(self.dataset[idx])
+        for (transform, transform_type) in zip(self.pipeline,
+                                               self.pipeline_types):
+            if self._skip_type_keys is not None and \
+                    transform_type in self._skip_type_keys:
+                continue
+
+            if hasattr(transform, 'get_indexes'):
+                for i in range(self.max_refetch):
+                    # Make sure the results passed through the loading
+                    # pipeline of the original dataset are not None.
+                    indexes = transform.get_indexes(self.dataset)
+                    if not isinstance(indexes, collections.abc.Sequence):
+                        indexes = [indexes]
+                    mix_results = [
+                        copy.deepcopy(self.dataset[index]) for index in indexes
+                    ]
+                    if None not in mix_results:
+                        results['mix_results'] = mix_results
+                        break
+                else:
+                    raise RuntimeError(
+                        'The loading pipeline of the original dataset'
+                        ' always returns None. Please check the correctness '
+                        'of the dataset and its pipeline.')
+
+            for i in range(self.max_refetch):
+                # Make sure the results passed through the training
+                # pipeline of the wrapper are not None.
+                updated_results = transform(copy.deepcopy(results))
+                if updated_results is not None:
+                    results = updated_results
+                    break
+            else:
+                raise RuntimeError(
+                    'The training pipeline of the dataset wrapper'
+                    ' always returns None. Please check the correctness '
+                    'of the dataset and its pipeline.')
+
+            if 'mix_results' in results:
+                results.pop('mix_results')
+
+        return results
+
+    def update_skip_type_keys(self, skip_type_keys):
+        """Update skip_type_keys. It is called by an external hook.
+
+        Args:
+            skip_type_keys (list[str], optional): Sequence of type
+                strings to be skipped in the pipeline.
+        """
+        assert all([
+            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
+        ])
+        self._skip_type_keys = skip_type_keys
diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/__init__.py
new file mode 100755
index 00000000..f16d18a7
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
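+# Illustrative sketch (not from any shipped config): the transforms exported
+# below are meant to be assembled from config dicts through the PIPELINES
+# registry, typically via Compose; the scales, normalization values and keys
+# here are placeholders only.
+#
+# from dbnet_det.datasets.pipelines import Compose
+# pipeline = Compose([
+#     dict(type='LoadImageFromFile'),
+#     dict(type='LoadAnnotations', with_bbox=True),
+#     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+#     dict(type='RandomFlip', flip_ratio=0.5),
+#     dict(type='Normalize', mean=[123.675, 116.28, 103.53],
+#          std=[58.395, 57.12, 57.375], to_rgb=True),
+#     dict(type='Pad', size_divisor=32),
+#     dict(type='DefaultFormatBundle'),
+#     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+# ])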
+# from .auto_augment import (AutoAugment, BrightnessTransform, ColorTransform, +# ContrastTransform, EqualizeTransform, Rotate, Shear, +# Translate) +from .compose import Compose +from .formatting import (Collect, DefaultFormatBundle, ImageToTensor, + ToDataContainer, ToTensor, Transpose, to_tensor) +# from .instaboost import InstaBoost +from .loading import (FilterAnnotations, LoadAnnotations, LoadImageFromFile, + LoadImageFromWebcam, LoadMultiChannelImageFromFiles, + LoadPanopticAnnotations, LoadProposals) +from .test_time_aug import MultiScaleFlipAug +from .transforms import (Albu, CopyPaste, CutOut, Expand, MinIoURandomCrop, + MixUp, Mosaic, Normalize, Pad, PhotoMetricDistortion, + RandomAffine, RandomCenterCropPad, RandomCrop, + RandomFlip, RandomShift, Resize, SegRescale, + YOLOXHSVRandomAug) + +# __all__ = [ +# 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', +# 'Transpose', 'Collect', 'DefaultFormatBundle', 'LoadAnnotations', +# 'LoadImageFromFile', 'LoadImageFromWebcam', 'LoadPanopticAnnotations', +# 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'FilterAnnotations', +# 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', +# 'Normalize', 'SegRescale', 'MinIoURandomCrop', 'Expand', +# 'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad', +# 'AutoAugment', 'CutOut', 'Shear', 'Rotate', 'ColorTransform', +# 'EqualizeTransform', 'BrightnessTransform', 'ContrastTransform', +# 'Translate', 'RandomShift', 'Mosaic', 'MixUp', 'RandomAffine', +# 'YOLOXHSVRandomAug', 'CopyPaste' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/compose.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/compose.py new file mode 100755 index 00000000..cbe56b4d --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/compose.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections + +from dbnet_cv.utils import build_from_cfg + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Compose: + """Compose multiple transforms sequentially. + + Args: + transforms (Sequence[dict | callable]): Sequence of transform object or + config dict to be composed. + """ + + def __init__(self, transforms): + assert isinstance(transforms, collections.abc.Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict') + + def __call__(self, data): + """Call function to apply transforms sequentially. + + Args: + data (dict): A result dict contains the data to transform. + + Returns: + dict: Transformed data. + """ + + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + str_ = t.__repr__() + if 'Compose(' in str_: + str_ = str_.replace('\n', '\n ') + format_string += '\n' + format_string += f' {str_}' + format_string += '\n)' + return format_string diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/formatting.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/formatting.py new file mode 100755 index 00000000..9454d849 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/formatting.py @@ -0,0 +1,392 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
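+# Quick reference (illustrative) for the to_tensor helper defined below;
+# shapes and values are examples, not fixed behaviour of any config:
+#   to_tensor(np.ones((2, 3), dtype=np.float32))  -> FloatTensor of shape (2, 3)
+#   to_tensor([1, 2, 3])                          -> tensor([1, 2, 3])
+#   to_tensor(5)                                  -> LongTensor([5])
+#   to_tensor(0.5)                                -> FloatTensor([0.5])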
+from collections.abc import Sequence + +import dbnet_cv +import numpy as np +import torch +from dbnet_cv.parallel import DataContainer as DC + +from ..builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not dbnet_cv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class ToTensor: + """Convert some results to :obj:`torch.Tensor` by given keys. + + Args: + keys (Sequence[str]): Keys that need to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert data in results to :obj:`torch.Tensor`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted + to :obj:`torch.Tensor`. + """ + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class ImageToTensor: + """Convert image to :obj:`torch.Tensor` by given keys. + + The dimension order of input image is (H, W, C). The pipeline will convert + it to (C, H, W). If only 2 dimension (H, W) is given, the output would be + (1, H, W). + + Args: + keys (Sequence[str]): Key of images to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + for key in self.keys: + img = results[key] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous() + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class Transpose: + """Transpose some results by given keys. + + Args: + keys (Sequence[str]): Keys of results to be transposed. + order (Sequence[int]): Order of transpose. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + """Call function to transpose the channel order of data in results. + + Args: + results (dict): Result dict contains the data to transpose. + + Returns: + dict: The result dict contains the data transposed to \ + ``self.order``. + """ + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@PIPELINES.register_module() +class ToDataContainer: + """Convert results to :obj:`dbnet_cv.DataContainer` by given fields. 
+ + Args: + fields (Sequence[dict]): Each field is a dict like + ``dict(key='xxx', **kwargs)``. The ``key`` in result will + be converted to :obj:`dbnet_cv.DataContainer` with ``**kwargs``. + Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'), + dict(key='gt_labels'))``. + """ + + def __init__(self, + fields=(dict(key='img', stack=True), dict(key='gt_bboxes'), + dict(key='gt_labels'))): + self.fields = fields + + def __call__(self, results): + """Call function to convert data in results to + :obj:`dbnet_cv.DataContainer`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted to \ + :obj:`dbnet_cv.DataContainer`. + """ + + for field in self.fields: + field = field.copy() + key = field.pop('key') + results[key] = DC(results[key], **field) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(fields={self.fields})' + + +@PIPELINES.register_module() +class DefaultFormatBundle: + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img", + "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". + These fields are formatted as follows. + + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - proposals: (1)to tensor, (2)to DataContainer + - gt_bboxes: (1)to tensor, (2)to DataContainer + - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer + - gt_labels: (1)to tensor, (2)to DataContainer + - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ + (3)to DataContainer (stack=True) + + Args: + img_to_float (bool): Whether to force the image to be converted to + float type. Default: True. + pad_val (dict): A dict for padding value in batch collating, + the default value is `dict(img=0, masks=0, seg=255)`. + Without this argument, the padding value of "gt_semantic_seg" + will be set to 0 by default, which should be 255. + """ + + def __init__(self, + img_to_float=True, + pad_val=dict(img=0, masks=0, seg=255)): + self.img_to_float = img_to_float + self.pad_val = pad_val + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with \ + default bundle. + """ + + if 'img' in results: + img = results['img'] + if self.img_to_float is True and img.dtype == np.uint8: + # Normally, image is of uint8 type without normalization. + # At this time, it needs to be forced to be converted to + # flot32, otherwise the model training and inference + # will be wrong. Only used for YOLOX currently . 
+ img = img.astype(np.float32) + # add default meta keys + results = self._add_default_meta_keys(results) + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC( + to_tensor(img), padding_value=self.pad_val['img'], stack=True) + for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']: + if key not in results: + continue + results[key] = DC(to_tensor(results[key])) + if 'gt_masks' in results: + results['gt_masks'] = DC( + results['gt_masks'], + padding_value=self.pad_val['masks'], + cpu_only=True) + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, ...]), + padding_value=self.pad_val['seg'], + stack=True) + return results + + def _add_default_meta_keys(self, results): + """Add default meta keys. + + We set default meta keys including `pad_shape`, `scale_factor` and + `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and + `Pad` are implemented during the whole pipeline. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + results (dict): Updated result dict contains the data to convert. + """ + img = results['img'] + results.setdefault('pad_shape', img.shape) + results.setdefault('scale_factor', 1.0) + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results.setdefault( + 'img_norm_cfg', + dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False)) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(img_to_float={self.img_to_float})' + + +@PIPELINES.register_module() +class Collect: + """Collect data from the loader relevant to the specific task. + + This is usually the last stage of the data loader pipeline. Typically keys + is set to some subset of "img", "proposals", "gt_bboxes", + "gt_bboxes_ignore", "gt_labels", and/or "gt_masks". + + The "img_meta" item is always populated. The contents of the "img_meta" + dictionary depends on "meta_keys". By default this includes: + + - "img_shape": shape of the image input to the network as a tuple \ + (h, w, c). Note that images may be zero padded on the \ + bottom/right if the batch tensor is larger than this shape. + + - "scale_factor": a float indicating the preprocessing scale + + - "flip": a boolean indicating if image flip transform was used + + - "filename": path to the image file + + - "ori_shape": original shape of the image as a tuple (h, w, c) + + - "pad_shape": image shape after padding + + - "img_norm_cfg": a dict of normalization information: + + - mean - per channel mean subtraction + - std - per channel std divisor + - to_rgb - bool indicating if bgr was converted to rgb + + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``dbnet_cv.DataContainer`` and collected in ``data[img_metas]``. + Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'img_norm_cfg')`` + """ + + def __init__(self, + keys, + meta_keys=('filename', 'ori_filename', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction', 'img_norm_cfg')): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + """Call function to collect keys in results. The keys in ``meta_keys`` + will be converted to :obj:dbnet_cv.DataContainer. 
+ + Args: + results (dict): Result dict contains the data to collect. + + Returns: + dict: The result dict contains the following keys + + - keys in``self.keys`` + - ``img_metas`` + """ + + data = {} + img_meta = {} + for key in self.meta_keys: + img_meta[key] = results[key] + data['img_metas'] = DC(img_meta, cpu_only=True) + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, meta_keys={self.meta_keys})' + + +@PIPELINES.register_module() +class WrapFieldsToLists: + """Wrap fields of the data dictionary into lists for evaluation. + + This class can be used as a last step of a test or validation + pipeline for single image evaluation or inference. + + Example: + >>> test_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + >>> dict(type='Pad', size_divisor=32), + >>> dict(type='ImageToTensor', keys=['img']), + >>> dict(type='Collect', keys=['img']), + >>> dict(type='WrapFieldsToLists') + >>> ] + """ + + def __call__(self, results): + """Call function to wrap fields into lists. + + Args: + results (dict): Result dict contains the data to wrap. + + Returns: + dict: The result dict where value of ``self.keys`` are wrapped \ + into list. + """ + + # Wrap dict fields into lists + for key, val in results.items(): + results[key] = [val] + return results + + def __repr__(self): + return f'{self.__class__.__name__}()' diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/loading.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/loading.py new file mode 100755 index 00000000..1c29642e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/loading.py @@ -0,0 +1,643 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import dbnet_cv +import numpy as np +import pycocotools.mask as maskUtils + +from dbnet_det.core import BitmapMasks, PolygonMasks +from ..builder import PIPELINES + +try: + from panopticapi.utils import rgb2id +except ImportError: + rgb2id = None + + +@PIPELINES.register_module() +class LoadImageFromFile: + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`dbnet_cv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`dbnet_cv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + to_float32=False, + color_type='color', + channel_order='bgr', + file_client_args=dict(backend='disk')): + self.to_float32 = to_float32 + self.color_type = color_type + self.channel_order = channel_order + self.file_client_args = file_client_args.copy() + self.file_client = None + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. 
+ """ + + if self.file_client is None: + self.file_client = dbnet_cv.FileClient(**self.file_client_args) + + if results['img_prefix'] is not None: + filename = osp.join(results['img_prefix'], + results['img_info']['filename']) + else: + filename = results['img_info']['filename'] + + img_bytes = self.file_client.get(filename) + img = dbnet_cv.imfrombytes( + img_bytes, flag=self.color_type, channel_order=self.channel_order) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['img_fields'] = ['img'] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32}, ' + f"color_type='{self.color_type}', " + f"channel_order='{self.channel_order}', " + f'file_client_args={self.file_client_args})') + return repr_str + + +@PIPELINES.register_module() +class LoadImageFromWebcam(LoadImageFromFile): + """Load an image from webcam. + + Similar with :obj:`LoadImageFromFile`, but the image read from webcam is in + ``results['img']``. + """ + + def __call__(self, results): + """Call functions to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + img = results['img'] + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = None + results['ori_filename'] = None + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['img_fields'] = ['img'] + return results + + +@PIPELINES.register_module() +class LoadMultiChannelImageFromFiles: + """Load multi-channel images from a list of separate channel files. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename", which is expected to be a list of filenames). + Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`dbnet_cv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`dbnet_cv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + to_float32=False, + color_type='unchanged', + file_client_args=dict(backend='disk')): + self.to_float32 = to_float32 + self.color_type = color_type + self.file_client_args = file_client_args.copy() + self.file_client = None + + def __call__(self, results): + """Call functions to load multiple images and get images meta + information. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded images and meta information. 
+ """ + + if self.file_client is None: + self.file_client = dbnet_cv.FileClient(**self.file_client_args) + + if results['img_prefix'] is not None: + filename = [ + osp.join(results['img_prefix'], fname) + for fname in results['img_info']['filename'] + ] + else: + filename = results['img_info']['filename'] + + img = [] + for name in filename: + img_bytes = self.file_client.get(name) + img.append(dbnet_cv.imfrombytes(img_bytes, flag=self.color_type)) + img = np.stack(img, axis=-1) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results['img_norm_cfg'] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False) + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32}, ' + f"color_type='{self.color_type}', " + f'file_client_args={self.file_client_args})') + return repr_str + + +@PIPELINES.register_module() +class LoadAnnotations: + """Load multiple types of annotations. + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Default: True. + with_label (bool): Whether to parse and load the label annotation. + Default: True. + with_mask (bool): Whether to parse and load the mask annotation. + Default: False. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Default: False. + poly2mask (bool): Whether to convert the instance masks from polygons + to bitmaps. Default: True. + denorm_bbox (bool): Whether to convert bbox from relative value to + absolute value. Only used in OpenImage Dataset. + Default: False. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`dbnet_cv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + with_bbox=True, + with_label=True, + with_mask=False, + with_seg=False, + poly2mask=True, + denorm_bbox=False, + file_client_args=dict(backend='disk')): + self.with_bbox = with_bbox + self.with_label = with_label + self.with_mask = with_mask + self.with_seg = with_seg + self.poly2mask = poly2mask + self.denorm_bbox = denorm_bbox + self.file_client_args = file_client_args.copy() + self.file_client = None + + def _load_bboxes(self, results): + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box annotations. 
+ """ + + ann_info = results['ann_info'] + results['gt_bboxes'] = ann_info['bboxes'].copy() + + if self.denorm_bbox: + bbox_num = results['gt_bboxes'].shape[0] + if bbox_num != 0: + h, w = results['img_shape'][:2] + results['gt_bboxes'][:, 0::2] *= w + results['gt_bboxes'][:, 1::2] *= h + + gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) + if gt_bboxes_ignore is not None: + results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() + results['bbox_fields'].append('gt_bboxes_ignore') + results['bbox_fields'].append('gt_bboxes') + + gt_is_group_ofs = ann_info.get('gt_is_group_ofs', None) + if gt_is_group_ofs is not None: + results['gt_is_group_ofs'] = gt_is_group_ofs.copy() + + return results + + def _load_labels(self, results): + """Private function to load label annotations. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded label annotations. + """ + + results['gt_labels'] = results['ann_info']['labels'].copy() + return results + + def _poly2mask(self, mask_ann, img_h, img_w): + """Private function to convert masks represented with polygon to + bitmaps. + + Args: + mask_ann (list | dict): Polygon mask annotation input. + img_h (int): The height of output mask. + img_w (int): The width of output mask. + + Returns: + numpy.ndarray: The decode bitmap mask of shape (img_h, img_w). + """ + + if isinstance(mask_ann, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) + rle = maskUtils.merge(rles) + elif isinstance(mask_ann['counts'], list): + # uncompressed RLE + rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) + else: + # rle + rle = mask_ann + mask = maskUtils.decode(rle) + return mask + + def process_polygons(self, polygons): + """Convert polygons to list of ndarray and filter invalid polygons. + + Args: + polygons (list[list]): Polygons of one instance. + + Returns: + list[numpy.ndarray]: Processed polygons. + """ + + polygons = [np.array(p) for p in polygons] + valid_polygons = [] + for polygon in polygons: + if len(polygon) % 2 == 0 and len(polygon) >= 6: + valid_polygons.append(polygon) + return valid_polygons + + def _load_masks(self, results): + """Private function to load mask annotations. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded mask annotations. + If ``self.poly2mask`` is set ``True``, `gt_mask` will contain + :obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used. + """ + + h, w = results['img_info']['height'], results['img_info']['width'] + gt_masks = results['ann_info']['masks'] + if self.poly2mask: + gt_masks = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) + else: + gt_masks = PolygonMasks( + [self.process_polygons(polygons) for polygons in gt_masks], h, + w) + results['gt_masks'] = gt_masks + results['mask_fields'].append('gt_masks') + return results + + def _load_semantic_seg(self, results): + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. 
+ """ + + if self.file_client is None: + self.file_client = dbnet_cv.FileClient(**self.file_client_args) + + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + img_bytes = self.file_client.get(filename) + results['gt_semantic_seg'] = dbnet_cv.imfrombytes( + img_bytes, flag='unchanged').squeeze() + results['seg_fields'].append('gt_semantic_seg') + return results + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box, label, mask and + semantic segmentation annotations. + """ + + if self.with_bbox: + results = self._load_bboxes(results) + if results is None: + return None + if self.with_label: + results = self._load_labels(results) + if self.with_mask: + results = self._load_masks(results) + if self.with_seg: + results = self._load_semantic_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(with_bbox={self.with_bbox}, ' + repr_str += f'with_label={self.with_label}, ' + repr_str += f'with_mask={self.with_mask}, ' + repr_str += f'with_seg={self.with_seg}, ' + repr_str += f'poly2mask={self.poly2mask}, ' + repr_str += f'poly2mask={self.file_client_args})' + return repr_str + + +@PIPELINES.register_module() +class LoadPanopticAnnotations(LoadAnnotations): + """Load multiple types of panoptic annotations. + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Default: True. + with_label (bool): Whether to parse and load the label annotation. + Default: True. + with_mask (bool): Whether to parse and load the mask annotation. + Default: True. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Default: True. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`dbnet_cv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + with_bbox=True, + with_label=True, + with_mask=True, + with_seg=True, + file_client_args=dict(backend='disk')): + if rgb2id is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + + super(LoadPanopticAnnotations, self).__init__( + with_bbox=with_bbox, + with_label=with_label, + with_mask=with_mask, + with_seg=with_seg, + poly2mask=True, + denorm_bbox=False, + file_client_args=file_client_args) + + def _load_masks_and_semantic_segs(self, results): + """Private function to load mask and semantic segmentation annotations. + + In gt_semantic_seg, the foreground label is from `0` to + `num_things - 1`, the background label is from `num_things` to + `num_things + num_stuff - 1`, 255 means the ignored label (`VOID`). + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded mask and semantic segmentation + annotations. `BitmapMasks` is used for mask annotations. 
+ """ + + if self.file_client is None: + self.file_client = dbnet_cv.FileClient(**self.file_client_args) + + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + img_bytes = self.file_client.get(filename) + pan_png = dbnet_cv.imfrombytes( + img_bytes, flag='color', channel_order='rgb').squeeze() + pan_png = rgb2id(pan_png) + + gt_masks = [] + gt_seg = np.zeros_like(pan_png) + 255 # 255 as ignore + + for mask_info in results['ann_info']['masks']: + mask = (pan_png == mask_info['id']) + gt_seg = np.where(mask, mask_info['category'], gt_seg) + + # The legal thing masks + if mask_info.get('is_thing'): + gt_masks.append(mask.astype(np.uint8)) + + if self.with_mask: + h, w = results['img_info']['height'], results['img_info']['width'] + gt_masks = BitmapMasks(gt_masks, h, w) + results['gt_masks'] = gt_masks + results['mask_fields'].append('gt_masks') + + if self.with_seg: + results['gt_semantic_seg'] = gt_seg + results['seg_fields'].append('gt_semantic_seg') + return results + + def __call__(self, results): + """Call function to load multiple types panoptic annotations. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box, label, mask and + semantic segmentation annotations. + """ + + if self.with_bbox: + results = self._load_bboxes(results) + if results is None: + return None + if self.with_label: + results = self._load_labels(results) + if self.with_mask or self.with_seg: + # The tasks completed by '_load_masks' and '_load_semantic_segs' + # in LoadAnnotations are merged to one function. + results = self._load_masks_and_semantic_segs(results) + + return results + + +@PIPELINES.register_module() +class LoadProposals: + """Load proposal pipeline. + + Required key is "proposals". Updated keys are "proposals", "bbox_fields". + + Args: + num_max_proposals (int, optional): Maximum number of proposals to load. + If not specified, all proposals will be loaded. + """ + + def __init__(self, num_max_proposals=None): + self.num_max_proposals = num_max_proposals + + def __call__(self, results): + """Call function to load proposals from file. + + Args: + results (dict): Result dict from :obj:`dbnet_det.CustomDataset`. + + Returns: + dict: The dict contains loaded proposal annotations. + """ + + proposals = results['proposals'] + if proposals.shape[1] not in (4, 5): + raise AssertionError( + 'proposals should have shapes (n, 4) or (n, 5), ' + f'but found {proposals.shape}') + proposals = proposals[:, :4] + + if self.num_max_proposals is not None: + proposals = proposals[:self.num_max_proposals] + + if len(proposals) == 0: + proposals = np.array([[0, 0, 0, 0]], dtype=np.float32) + results['proposals'] = proposals + results['bbox_fields'].append('proposals') + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(num_max_proposals={self.num_max_proposals})' + + +@PIPELINES.register_module() +class FilterAnnotations: + """Filter invalid annotations. + + Args: + min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth + boxes. Default: (1., 1.) + min_gt_mask_area (int): Minimum foreground area of ground truth masks. + Default: 1 + by_box (bool): Filter instances with bounding boxes not meeting the + min_gt_bbox_wh threshold. Default: True + by_mask (bool): Filter instances with masks not meeting + min_gt_mask_area threshold. Default: False + keep_empty (bool): Whether to return None when it + becomes an empty bbox after filtering. 
Default: True
+    """
+
+    def __init__(self,
+                 min_gt_bbox_wh=(1., 1.),
+                 min_gt_mask_area=1,
+                 by_box=True,
+                 by_mask=False,
+                 keep_empty=True):
+        # TODO: add more filter options
+        assert by_box or by_mask
+        self.min_gt_bbox_wh = min_gt_bbox_wh
+        self.min_gt_mask_area = min_gt_mask_area
+        self.by_box = by_box
+        self.by_mask = by_mask
+        self.keep_empty = keep_empty
+
+    def __call__(self, results):
+        if self.by_box:
+            assert 'gt_bboxes' in results
+            gt_bboxes = results['gt_bboxes']
+            instance_num = gt_bboxes.shape[0]
+        if self.by_mask:
+            assert 'gt_masks' in results
+            gt_masks = results['gt_masks']
+            instance_num = len(gt_masks)
+
+        if instance_num == 0:
+            return results
+
+        tests = []
+        if self.by_box:
+            w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
+            h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
+            tests.append((w > self.min_gt_bbox_wh[0])
+                         & (h > self.min_gt_bbox_wh[1]))
+        if self.by_mask:
+            gt_masks = results['gt_masks']
+            tests.append(gt_masks.areas >= self.min_gt_mask_area)
+
+        keep = tests[0]
+        for t in tests[1:]:
+            keep = keep & t
+
+        keys = ('gt_bboxes', 'gt_labels', 'gt_masks')
+        for key in keys:
+            if key in results:
+                results[key] = results[key][keep]
+        if not keep.any():
+            if self.keep_empty:
+                return None
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+               f'(min_gt_bbox_wh={self.min_gt_bbox_wh}, ' \
+               f'min_gt_mask_area={self.min_gt_mask_area}, ' \
+               f'by_box={self.by_box}, ' \
+               f'by_mask={self.by_mask}, ' \
+               f'keep_empty={self.keep_empty})'
diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/test_time_aug.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/test_time_aug.py
new file mode 100755
index 00000000..853b9712
--- /dev/null
+++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,121 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import dbnet_cv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug:
+    """Test-time augmentation with multiple scales and flipping.
+
+    An example configuration is as follows:
+
+    .. code-block::
+
+        img_scale=[(1333, 400), (1333, 800)],
+        flip=True,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ]
+
+    After MultiScaleFlipAug with the above configuration, the results are
+    wrapped into lists of the same length as follows:
+
+    .. code-block::
+
+        dict(
+            img=[...],
+            img_shape=[...],
+            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
+            flip=[False, True, False, True]
+            ...
+        )
+
+    Args:
+        transforms (list[dict]): Transforms to apply in each augmentation.
+        img_scale (tuple | list[tuple] | None): Images scales for resizing.
+        scale_factor (float | list[float] | None): Scale factors for resizing.
+        flip (bool): Whether to apply flip augmentation. Default: False.
+        flip_direction (str | list[str]): Flip augmentation directions,
+            options are "horizontal", "vertical" and "diagonal". If
+            flip_direction is a list, multiple flip augmentations will be
+            applied. It has no effect when flip == False. Default:
+            "horizontal".
+ """ + + def __init__(self, + transforms, + img_scale=None, + scale_factor=None, + flip=False, + flip_direction='horizontal'): + self.transforms = Compose(transforms) + assert (img_scale is None) ^ (scale_factor is None), ( + 'Must have but only one variable can be set') + if img_scale is not None: + self.img_scale = img_scale if isinstance(img_scale, + list) else [img_scale] + self.scale_key = 'scale' + assert dbnet_cv.is_list_of(self.img_scale, tuple) + else: + self.img_scale = scale_factor if isinstance( + scale_factor, list) else [scale_factor] + self.scale_key = 'scale_factor' + + self.flip = flip + self.flip_direction = flip_direction if isinstance( + flip_direction, list) else [flip_direction] + assert dbnet_cv.is_list_of(self.flip_direction, str) + if not self.flip and self.flip_direction != ['horizontal']: + warnings.warn( + 'flip_direction has no effect when flip is set to False') + if (self.flip + and not any([t['type'] == 'RandomFlip' for t in transforms])): + warnings.warn( + 'flip has no effect when RandomFlip is not in transforms') + + def __call__(self, results): + """Call function to apply test time augment transforms on results. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict[str: list]: The augmented data, where each value is wrapped + into a list. + """ + + aug_data = [] + flip_args = [(False, None)] + if self.flip: + flip_args += [(True, direction) + for direction in self.flip_direction] + for scale in self.img_scale: + for flip, direction in flip_args: + _results = results.copy() + _results[self.scale_key] = scale + _results['flip'] = flip + _results['flip_direction'] = direction + data = self.transforms(_results) + aug_data.append(data) + # list of dict to dict of list + aug_data_dict = {key: [] for key in aug_data[0]} + for data in aug_data: + for key, val in data.items(): + aug_data_dict[key].append(val) + return aug_data_dict + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'img_scale={self.img_scale}, flip={self.flip}, ' + repr_str += f'flip_direction={self.flip_direction})' + return repr_str diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/transforms.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/transforms.py new file mode 100755 index 00000000..d4e69351 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/pipelines/transforms.py @@ -0,0 +1,2919 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect +import math +import warnings + +import cv2 +import dbnet_cv +import numpy as np +from numpy import random + +from dbnet_det.core import BitmapMasks, PolygonMasks, find_inside_bboxes +from dbnet_det.core.evaluation.bbox_overlaps import bbox_overlaps +from dbnet_det.utils import log_img_scale +from ..builder import PIPELINES + +try: + from imagecorruptions import corrupt +except ImportError: + corrupt = None + +try: + import albumentations + from albumentations import Compose +except ImportError: + albumentations = None + Compose = None + + +@PIPELINES.register_module() +class Resize: + """Resize images & bbox & mask. + + This transform resizes the input image to some scale. Bboxes and masks are + then resized with the same scale factor. If the input dict contains the key + "scale", then the scale in the input dict is used, otherwise the specified + scale in the init method is used. 
If the input dict contains the key + "scale_factor" (if MultiScaleFlipAug does not give img_scale but + scale_factor), the actual scale will be computed by image shape and + scale_factor. + + `img_scale` can either be a tuple (single-scale) or a list of tuple + (multi-scale). There are 3 multiscale modes: + + - ``ratio_range is not None``: randomly sample a ratio from the ratio \ + range and multiply it with the image scale. + - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ + sample a scale from the multiscale range. + - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ + sample a scale from multiple scales. + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + override (bool, optional): Whether to override `scale` and + `scale_factor` so as to call resize twice. Default False. If True, + after the first resizing, the existed `scale` and `scale_factor` + will be ignored so the second resizing can be allowed. + This option is a work-around for multiple times of resize in DETR. + Defaults to False. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + bbox_clip_border=True, + backend='cv2', + interpolation='bilinear', + override=False): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert dbnet_cv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given a scale and a range of image ratio + assert len(self.img_scale) == 1 + else: + # mode 2: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.backend = backend + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + # TODO: refactor the override option in Resize + self.interpolation = interpolation + self.override = override + self.bbox_clip_border = bbox_clip_border + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \ + where ``img_scale`` is the selected image scale and \ + ``scale_idx`` is the selected index in the given candidates. + """ + + assert dbnet_cv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. 
+ There must be two tuples in img_scales, which specify the lower + and upper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where \ + ``img_scale`` is sampled scale and None is just a placeholder \ + to be consistent with :func:`random_select`. + """ + + assert dbnet_cv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where \ + ``scale`` is sampled ratio multiplied with ``img_scale`` and \ + None is just a placeholder to be consistent with \ + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into \ + ``results``, which would be used by subsequent pipelines. 
+ """ + + if self.ratio_range is not None: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + for key in results.get('img_fields', ['img']): + if self.keep_ratio: + img, scale_factor = dbnet_cv.imrescale( + results[key], + results['scale'], + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the dbnet_cv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results[key].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = dbnet_cv.imresize( + results[key], + results['scale'], + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + results[key] = img + + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img_shape'] = img.shape + # in case that there is no padding + results['pad_shape'] = img.shape + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_bboxes(self, results): + """Resize bounding boxes with ``results['scale_factor']``.""" + for key in results.get('bbox_fields', []): + bboxes = results[key] * results['scale_factor'] + if self.bbox_clip_border: + img_shape = results['img_shape'] + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + results[key] = bboxes + + def _resize_masks(self, results): + """Resize masks with ``results['scale']``""" + for key in results.get('mask_fields', []): + if results[key] is None: + continue + if self.keep_ratio: + results[key] = results[key].rescale(results['scale']) + else: + results[key] = results[key].resize(results['img_shape'][:2]) + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = dbnet_cv.imrescale( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = dbnet_cv.imresize( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results[key] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ + 'keep_ratio' keys are added into result dict. 
+ """ + + if 'scale' not in results: + if 'scale_factor' in results: + img_shape = results['img'].shape[:2] + scale_factor = results['scale_factor'] + assert isinstance(scale_factor, float) + results['scale'] = tuple( + [int(x * scale_factor) for x in img_shape][::-1]) + else: + self._random_scale(results) + else: + if not self.override: + assert 'scale_factor' not in results, ( + 'scale and scale_factor cannot be both set.') + else: + results.pop('scale') + if 'scale_factor' in results: + results.pop('scale_factor') + self._random_scale(results) + + self._resize_img(results) + self._resize_bboxes(results) + self._resize_masks(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'multiscale_mode={self.multiscale_mode}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@PIPELINES.register_module() +class RandomFlip: + """Flip the image & bbox & mask. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + When random flip is enabled, ``flip_ratio``/``direction`` can either be a + float/string or tuple of float/string. There are 3 flip modes: + + - ``flip_ratio`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``flip_ratio`` . + E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + - ``flip_ratio`` is float, ``direction`` is list of string: the image will + be ``direction[i]``ly flipped with probability of + ``flip_ratio/len(direction)``. + E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + - ``flip_ratio`` is list of float, ``direction`` is list of string: + given ``len(flip_ratio) == len(direction)``, the image will + be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``. + E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with probability + of 0.3, vertically with probability of 0.5. + + Args: + flip_ratio (float | list[float], optional): The flipping probability. + Default: None. + direction(str | list[str], optional): The flipping direction. Options + are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. + If input is a list, the length must equal ``flip_ratio``. Each + element in ``flip_ratio`` indicates the flip probability of + corresponding direction. 
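+
+    Example (an illustrative probability check for the list-``direction``
+    mode described above; the numbers are assumed):
+
+    >>> flip_ratio, directions = 0.5, ['horizontal', 'vertical']
+    >>> flip_ratio / len(directions)
+    0.25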
+ """ + + def __init__(self, flip_ratio=None, direction='horizontal'): + if isinstance(flip_ratio, list): + assert dbnet_cv.is_list_of(flip_ratio, float) + assert 0 <= sum(flip_ratio) <= 1 + elif isinstance(flip_ratio, float): + assert 0 <= flip_ratio <= 1 + elif flip_ratio is None: + pass + else: + raise ValueError('flip_ratios must be None, float, ' + 'or list of float') + self.flip_ratio = flip_ratio + + valid_directions = ['horizontal', 'vertical', 'diagonal'] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert dbnet_cv.is_list_of(direction, str) + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError('direction must be either str or list of str') + self.direction = direction + + if isinstance(flip_ratio, list): + assert len(self.flip_ratio) == len(self.direction) + + def bbox_flip(self, bboxes, img_shape, direction): + """Flip bboxes horizontally. + + Args: + bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) + img_shape (tuple[int]): Image shape (height, width) + direction (str): Flip direction. Options are 'horizontal', + 'vertical'. + + Returns: + numpy.ndarray: Flipped bounding boxes. + """ + + assert bboxes.shape[-1] % 4 == 0 + flipped = bboxes.copy() + if direction == 'horizontal': + w = img_shape[1] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + elif direction == 'vertical': + h = img_shape[0] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + elif direction == 'diagonal': + w = img_shape[1] + h = img_shape[0] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + else: + raise ValueError(f"Invalid flipping direction '{direction}'") + return flipped + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added \ + into result dict. 
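+
+        Example (an illustrative horizontal flip of a single box with
+        assumed numbers, following the formula used by ``bbox_flip``):
+
+        >>> import numpy as np
+        >>> w, box = 100, np.array([10., 20., 30., 40.])  # x1, y1, x2, y2
+        >>> np.array([w - box[2], box[1], w - box[0], box[3]])
+        array([70., 20., 90., 40.])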
+ """ + + if 'flip' not in results: + if isinstance(self.direction, list): + # None means non-flip + direction_list = self.direction + [None] + else: + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.flip_ratio, list): + non_flip_ratio = 1 - sum(self.flip_ratio) + flip_ratio_list = self.flip_ratio + [non_flip_ratio] + else: + non_flip_ratio = 1 - self.flip_ratio + # exclude non-flip + single_ratio = self.flip_ratio / (len(direction_list) - 1) + flip_ratio_list = [single_ratio] * (len(direction_list) - + 1) + [non_flip_ratio] + + cur_dir = np.random.choice(direction_list, p=flip_ratio_list) + + results['flip'] = cur_dir is not None + if 'flip_direction' not in results: + results['flip_direction'] = cur_dir + if results['flip']: + # flip image + for key in results.get('img_fields', ['img']): + results[key] = dbnet_cv.imflip( + results[key], direction=results['flip_direction']) + # flip bboxes + for key in results.get('bbox_fields', []): + results[key] = self.bbox_flip(results[key], + results['img_shape'], + results['flip_direction']) + # flip masks + for key in results.get('mask_fields', []): + results[key] = results[key].flip(results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + results[key] = dbnet_cv.imflip( + results[key], direction=results['flip_direction']) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' + + +@PIPELINES.register_module() +class RandomShift: + """Shift the image and box given shift pixels and probability. + + Args: + shift_ratio (float): Probability of shifts. Default 0.5. + max_shift_px (int): The max pixels for shifting. Default 32. + filter_thr_px (int): The width and height threshold for filtering. + The bbox and the rest of the targets below the width and + height threshold will be filtered. Default 1. + """ + + def __init__(self, shift_ratio=0.5, max_shift_px=32, filter_thr_px=1): + assert 0 <= shift_ratio <= 1 + assert max_shift_px >= 0 + self.shift_ratio = shift_ratio + self.max_shift_px = max_shift_px + self.filter_thr_px = int(filter_thr_px) + # The key correspondence from bboxes to labels. + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + + def __call__(self, results): + """Call function to random shift images, bounding boxes. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Shift results. + """ + if random.random() < self.shift_ratio: + img_shape = results['img'].shape[:2] + + random_shift_x = random.randint(-self.max_shift_px, + self.max_shift_px) + random_shift_y = random.randint(-self.max_shift_px, + self.max_shift_px) + new_x = max(0, random_shift_x) + ori_x = max(0, -random_shift_x) + new_y = max(0, random_shift_y) + ori_y = max(0, -random_shift_y) + + # TODO: support mask and semantic segmentation maps. + for key in results.get('bbox_fields', []): + bboxes = results[key].copy() + bboxes[..., 0::2] += random_shift_x + bboxes[..., 1::2] += random_shift_y + + # clip border + bboxes[..., 0::2] = np.clip(bboxes[..., 0::2], 0, img_shape[1]) + bboxes[..., 1::2] = np.clip(bboxes[..., 1::2], 0, img_shape[0]) + + # remove invalid bboxes + bbox_w = bboxes[..., 2] - bboxes[..., 0] + bbox_h = bboxes[..., 3] - bboxes[..., 1] + valid_inds = (bbox_w > self.filter_thr_px) & ( + bbox_h > self.filter_thr_px) + # If the shift does not contain any gt-bbox area, skip this + # image. 
+ if key == 'gt_bboxes' and not valid_inds.any(): + return results + bboxes = bboxes[valid_inds] + results[key] = bboxes + + # label fields. e.g. gt_labels and gt_labels_ignore + label_key = self.bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][valid_inds] + + for key in results.get('img_fields', ['img']): + img = results[key] + new_img = np.zeros_like(img) + img_h, img_w = img.shape[:2] + new_h = img_h - np.abs(random_shift_y) + new_w = img_w - np.abs(random_shift_x) + new_img[new_y:new_y + new_h, new_x:new_x + new_w] \ + = img[ori_y:ori_y + new_h, ori_x:ori_x + new_w] + results[key] = new_img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(max_shift_px={self.max_shift_px}, ' + return repr_str + + +@PIPELINES.register_module() +class Pad: + """Pad the image & masks & segmentation map. + + There are two padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. + Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", + + Args: + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_to_square (bool): Whether to pad the image into a square. + Currently only used for YOLOX. Default: False. + pad_val (dict, optional): A dict for padding value, the default + value is `dict(img=0, masks=0, seg=255)`. + """ + + def __init__(self, + size=None, + size_divisor=None, + pad_to_square=False, + pad_val=dict(img=0, masks=0, seg=255)): + self.size = size + self.size_divisor = size_divisor + if isinstance(pad_val, float) or isinstance(pad_val, int): + warnings.warn( + 'pad_val of float type is deprecated now, ' + f'please use pad_val=dict(img={pad_val}, ' + f'masks={pad_val}, seg=255) instead.', DeprecationWarning) + pad_val = dict(img=pad_val, masks=pad_val, seg=255) + assert isinstance(pad_val, dict) + self.pad_val = pad_val + self.pad_to_square = pad_to_square + + if pad_to_square: + assert size is None and size_divisor is None, \ + 'The size and size_divisor must be None ' \ + 'when pad2square is True' + else: + assert size is not None or size_divisor is not None, \ + 'only one of size and size_divisor should be valid' + assert size is None or size_divisor is None + + def _pad_img(self, results): + """Pad images according to ``self.size``.""" + pad_val = self.pad_val.get('img', 0) + for key in results.get('img_fields', ['img']): + if self.pad_to_square: + max_size = max(results[key].shape[:2]) + self.size = (max_size, max_size) + if self.size is not None: + padded_img = dbnet_cv.impad( + results[key], shape=self.size, pad_val=pad_val) + elif self.size_divisor is not None: + padded_img = dbnet_cv.impad_to_multiple( + results[key], self.size_divisor, pad_val=pad_val) + results[key] = padded_img + results['pad_shape'] = padded_img.shape + results['pad_fixed_size'] = self.size + results['pad_size_divisor'] = self.size_divisor + + def _pad_masks(self, results): + """Pad masks according to ``results['pad_shape']``.""" + pad_shape = results['pad_shape'][:2] + pad_val = self.pad_val.get('masks', 0) + for key in results.get('mask_fields', []): + results[key] = results[key].pad(pad_shape, pad_val=pad_val) + + def _pad_seg(self, results): + """Pad semantic segmentation map according to + ``results['pad_shape']``.""" + pad_val = self.pad_val.get('seg', 255) + for key in results.get('seg_fields', []): + results[key] = dbnet_cv.impad( + results[key], shape=results['pad_shape'][:2], pad_val=pad_val) + + def 
__call__(self, results): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + self._pad_img(results) + self._pad_masks(results) + self._pad_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, ' + repr_str += f'size_divisor={self.size_divisor}, ' + repr_str += f'pad_to_square={self.pad_to_square}, ' + repr_str += f'pad_val={self.pad_val})' + return repr_str + + +@PIPELINES.register_module() +class Normalize: + """Normalize the image. + + Added key is "img_norm_cfg". + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB, + default is true. + """ + + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, results): + """Call function to normalize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Normalized results, 'img_norm_cfg' key is added into + result dict. + """ + for key in results.get('img_fields', ['img']): + results[key] = dbnet_cv.imnormalize(results[key], self.mean, self.std, + self.to_rgb) + results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=self.to_rgb) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})' + return repr_str + + +@PIPELINES.register_module() +class RandomCrop: + """Random crop the image & bboxes & masks. + + The absolute `crop_size` is sampled based on `crop_type` and `image_size`, + then the cropped results are generated. + + Args: + crop_size (tuple): The relative ratio or absolute pixels of + height and width. + crop_type (str, optional): one of "relative_range", "relative", + "absolute", "absolute_range". "relative" randomly crops + (h * crop_size[0], w * crop_size[1]) part from an input of size + (h, w). "relative_range" uniformly samples relative crop size from + range [crop_size[0], 1] and [crop_size[1], 1] for height and width + respectively. "absolute" crops from an input with absolute size + (crop_size[0], crop_size[1]). "absolute_range" uniformly samples + crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w + in range [crop_size[0], min(w, crop_size[1])]. Default "absolute". + allow_negative_crop (bool, optional): Whether to allow a crop that does + not contain any bbox area. Default False. + recompute_bbox (bool, optional): Whether to re-compute the boxes based + on cropped instance masks. Default False. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + + Note: + - If the image is smaller than the absolute crop size, return the + original image. + - The keys for bboxes, labels and masks must be aligned. That is, + `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and + `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and + `gt_masks_ignore`. + - If the crop does not contain any gt-bbox region and + `allow_negative_crop` is set to False, skip this image. 
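+
+    Example (an illustrative size calculation for the ``'relative'`` mode;
+    the input size is assumed):
+
+    >>> h, w, crop_size = 600, 800, (0.5, 0.5)
+    >>> int(h * crop_size[0] + 0.5), int(w * crop_size[1] + 0.5)
+    (300, 400)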
+ """ + + def __init__(self, + crop_size, + crop_type='absolute', + allow_negative_crop=False, + recompute_bbox=False, + bbox_clip_border=True): + if crop_type not in [ + 'relative_range', 'relative', 'absolute', 'absolute_range' + ]: + raise ValueError(f'Invalid crop_type {crop_type}.') + if crop_type in ['absolute', 'absolute_range']: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance( + crop_size[1], int) + else: + assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 + self.crop_size = crop_size + self.crop_type = crop_type + self.allow_negative_crop = allow_negative_crop + self.bbox_clip_border = bbox_clip_border + self.recompute_bbox = recompute_bbox + # The key correspondence from bboxes to labels and masks. + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + self.bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + + def _crop_data(self, results, crop_size, allow_negative_crop): + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + crop_size (tuple): Expected absolute size after cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. Default to False. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + assert crop_size[0] > 0 and crop_size[1] > 0 + for key in results.get('img_fields', ['img']): + img = results[key] + margin_h = max(img.shape[0] - crop_size[0], 0) + margin_w = max(img.shape[1] - crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] + + # crop the image + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + img_shape = img.shape + results[key] = img + results['img_shape'] = img_shape + + # crop bboxes accordingly and clip to the image boundary + for key in results.get('bbox_fields', []): + # e.g. gt_bboxes and gt_bboxes_ignore + bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h], + dtype=np.float32) + bboxes = results[key] - bbox_offset + if self.bbox_clip_border: + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & ( + bboxes[:, 3] > bboxes[:, 1]) + # If the crop does not contain any gt-bbox area and + # allow_negative_crop is False, skip this image. + if (key == 'gt_bboxes' and not valid_inds.any() + and not allow_negative_crop): + return None + results[key] = bboxes[valid_inds, :] + # label fields. e.g. gt_labels and gt_labels_ignore + label_key = self.bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][valid_inds] + + # mask fields, e.g. 
gt_masks and gt_masks_ignore + mask_key = self.bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][ + valid_inds.nonzero()[0]].crop( + np.asarray([crop_x1, crop_y1, crop_x2, crop_y2])) + if self.recompute_bbox: + results[key] = results[mask_key].get_bboxes() + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2] + + return results + + def _get_crop_size(self, image_size): + """Randomly generates the absolute crop size based on `crop_type` and + `image_size`. + + Args: + image_size (tuple): (h, w). + + Returns: + crop_size (tuple): (crop_h, crop_w) in absolute pixels. + """ + h, w = image_size + if self.crop_type == 'absolute': + return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + elif self.crop_type == 'absolute_range': + assert self.crop_size[0] <= self.crop_size[1] + crop_h = np.random.randint( + min(h, self.crop_size[0]), + min(h, self.crop_size[1]) + 1) + crop_w = np.random.randint( + min(w, self.crop_size[0]), + min(w, self.crop_size[1]) + 1) + return crop_h, crop_w + elif self.crop_type == 'relative': + crop_h, crop_w = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + elif self.crop_type == 'relative_range': + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + def __call__(self, results): + """Call function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + image_size = results['img'].shape[:2] + crop_size = self._get_crop_size(image_size) + results = self._crop_data(results, crop_size, self.allow_negative_crop) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(crop_size={self.crop_size}, ' + repr_str += f'crop_type={self.crop_type}, ' + repr_str += f'allow_negative_crop={self.allow_negative_crop}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@PIPELINES.register_module() +class SegRescale: + """Rescale semantic segmentation maps. + + Args: + scale_factor (float): The scale factor of the final output. + backend (str): Image rescale backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + """ + + def __init__(self, scale_factor=1, backend='cv2'): + self.scale_factor = scale_factor + self.backend = backend + + def __call__(self, results): + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = dbnet_cv.imrescale( + results[key], + self.scale_factor, + interpolation='nearest', + backend=self.backend) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@PIPELINES.register_module() +class PhotoMetricDistortion: + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. 
convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + 8. randomly swap channels + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def __call__(self, results): + """Call function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + img = results['img'] + img = img.astype(np.float32) + # random brightness + if random.randint(2): + delta = random.uniform(-self.brightness_delta, + self.brightness_delta) + img += delta + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + + # convert color from BGR to HSV + img = dbnet_cv.bgr2hsv(img) + + # random saturation + if random.randint(2): + img[..., 1] *= random.uniform(self.saturation_lower, + self.saturation_upper) + + # random hue + if random.randint(2): + img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + + # convert color from HSV to BGR + img = dbnet_cv.hsv2bgr(img) + + # random contrast + if mode == 0: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + + # randomly swap channels + if random.randint(2): + img = img[..., random.permutation(3)] + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(\nbrightness_delta={self.brightness_delta},\n' + repr_str += 'contrast_range=' + repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n' + repr_str += 'saturation_range=' + repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n' + repr_str += f'hue_delta={self.hue_delta})' + return repr_str + + +@PIPELINES.register_module() +class Expand: + """Random expand the image & bboxes. + + Randomly place the original image on a canvas of 'ratio' x original image + size filled with mean values. The ratio is in the range of ratio_range. + + Args: + mean (tuple): mean value of dataset. + to_rgb (bool): if need to convert the order of mean to align with RGB. + ratio_range (tuple): range of expand ratio. + prob (float): probability of applying this transformation + """ + + def __init__(self, + mean=(0, 0, 0), + to_rgb=True, + ratio_range=(1, 4), + seg_ignore_label=None, + prob=0.5): + self.to_rgb = to_rgb + self.ratio_range = ratio_range + if to_rgb: + self.mean = mean[::-1] + else: + self.mean = mean + self.min_ratio, self.max_ratio = ratio_range + self.seg_ignore_label = seg_ignore_label + self.prob = prob + + def __call__(self, results): + """Call function to expand images, bounding boxes. + + Args: + results (dict): Result dict from loading pipeline. 
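+
+        Example (an illustrative paste offset with assumed numbers; every
+        box is shifted by the random ``(left, top)`` of the pasted image):
+
+        >>> import numpy as np
+        >>> left, top = 50, 30
+        >>> np.array([10., 10., 20., 20.]) + np.tile((left, top), 2)
+        array([60., 40., 70., 50.])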
+ + Returns: + dict: Result dict with images, bounding boxes expanded + """ + + if random.uniform(0, 1) > self.prob: + return results + + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + img = results['img'] + + h, w, c = img.shape + ratio = random.uniform(self.min_ratio, self.max_ratio) + # speedup expand when meets large image + if np.all(self.mean == self.mean[0]): + expand_img = np.empty((int(h * ratio), int(w * ratio), c), + img.dtype) + expand_img.fill(self.mean[0]) + else: + expand_img = np.full((int(h * ratio), int(w * ratio), c), + self.mean, + dtype=img.dtype) + left = int(random.uniform(0, w * ratio - w)) + top = int(random.uniform(0, h * ratio - h)) + expand_img[top:top + h, left:left + w] = img + + results['img'] = expand_img + # expand bboxes + for key in results.get('bbox_fields', []): + results[key] = results[key] + np.tile( + (left, top), 2).astype(results[key].dtype) + + # expand masks + for key in results.get('mask_fields', []): + results[key] = results[key].expand( + int(h * ratio), int(w * ratio), top, left) + + # expand segs + for key in results.get('seg_fields', []): + gt_seg = results[key] + expand_gt_seg = np.full((int(h * ratio), int(w * ratio)), + self.seg_ignore_label, + dtype=gt_seg.dtype) + expand_gt_seg[top:top + h, left:left + w] = gt_seg + results[key] = expand_gt_seg + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'seg_ignore_label={self.seg_ignore_label})' + return repr_str + + +@PIPELINES.register_module() +class MinIoURandomCrop: + """Random crop the image & bboxes, the cropped patches have minimum IoU + requirement with original image & bboxes, the IoU threshold is randomly + selected from min_ious. + + Args: + min_ious (tuple): minimum IoU threshold for all intersections with + bounding boxes + min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, + where a >= min_crop_size). + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + + Note: + The keys for bboxes, labels and masks should be paired. That is, \ + `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ + `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. + """ + + def __init__(self, + min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), + min_crop_size=0.3, + bbox_clip_border=True): + # 1: return ori img + self.min_ious = min_ious + self.sample_mode = (1, *min_ious, 0) + self.min_crop_size = min_crop_size + self.bbox_clip_border = bbox_clip_border + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + self.bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + + def __call__(self, results): + """Call function to crop images and bounding boxes with minimum IoU + constraint. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images and bounding boxes cropped, \ + 'img_shape' key is updated. 
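+
+        Example (an illustrative pipeline entry using the default
+        thresholds):
+
+        >>> crop = dict(
+        ...     type='MinIoURandomCrop',
+        ...     min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+        ...     min_crop_size=0.3)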
+ """ + + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + img = results['img'] + assert 'bbox_fields' in results + boxes = [results[key] for key in results['bbox_fields']] + boxes = np.concatenate(boxes, 0) + h, w, c = img.shape + while True: + mode = random.choice(self.sample_mode) + self.mode = mode + if mode == 1: + return results + + min_iou = mode + for i in range(50): + new_w = random.uniform(self.min_crop_size * w, w) + new_h = random.uniform(self.min_crop_size * h, h) + + # h / w in [0.5, 2] + if new_h / new_w < 0.5 or new_h / new_w > 2: + continue + + left = random.uniform(w - new_w) + top = random.uniform(h - new_h) + + patch = np.array( + (int(left), int(top), int(left + new_w), int(top + new_h))) + # Line or point crop is not allowed + if patch[2] == patch[0] or patch[3] == patch[1]: + continue + overlaps = bbox_overlaps( + patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1) + if len(overlaps) > 0 and overlaps.min() < min_iou: + continue + + # center of boxes should inside the crop img + # only adjust boxes and instance masks when the gt is not empty + if len(overlaps) > 0: + # adjust boxes + def is_center_of_bboxes_in_patch(boxes, patch): + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = ((center[:, 0] > patch[0]) * + (center[:, 1] > patch[1]) * + (center[:, 0] < patch[2]) * + (center[:, 1] < patch[3])) + return mask + + mask = is_center_of_bboxes_in_patch(boxes, patch) + if not mask.any(): + continue + for key in results.get('bbox_fields', []): + boxes = results[key].copy() + mask = is_center_of_bboxes_in_patch(boxes, patch) + boxes = boxes[mask] + if self.bbox_clip_border: + boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) + boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) + boxes -= np.tile(patch[:2], 2) + + results[key] = boxes + # labels + label_key = self.bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][mask] + + # mask fields + mask_key = self.bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][ + mask.nonzero()[0]].crop(patch) + # adjust the img no matter whether the gt is empty before crop + img = img[patch[1]:patch[3], patch[0]:patch[2]] + results['img'] = img + results['img_shape'] = img.shape + + # seg fields + for key in results.get('seg_fields', []): + results[key] = results[key][patch[1]:patch[3], + patch[0]:patch[2]] + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_ious={self.min_ious}, ' + repr_str += f'min_crop_size={self.min_crop_size}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@PIPELINES.register_module() +class Corrupt: + """Corruption augmentation. + + Corruption transforms implemented based on + `imagecorruptions `_. + + Args: + corruption (str): Corruption name. + severity (int, optional): The severity of corruption. Default: 1. + """ + + def __init__(self, corruption, severity=1): + self.corruption = corruption + self.severity = severity + + def __call__(self, results): + """Call function to corrupt image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images corrupted. 
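+
+        Example (an illustrative pipeline entry; ``'gaussian_noise'`` is
+        assumed to be one of the corruption names provided by the optional
+        ``imagecorruptions`` package):
+
+        >>> corrupt_aug = dict(
+        ...     type='Corrupt', corruption='gaussian_noise', severity=2)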
+ """ + + if corrupt is None: + raise RuntimeError('imagecorruptions is not installed') + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + results['img'] = corrupt( + results['img'].astype(np.uint8), + corruption_name=self.corruption, + severity=self.severity) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(corruption={self.corruption}, ' + repr_str += f'severity={self.severity})' + return repr_str + + +@PIPELINES.register_module() +class Albu: + """Albumentation augmentation. + + Adds custom transformations from Albumentations library. + Please, visit `https://albumentations.readthedocs.io` + to get more information. + + An example of ``transforms`` is as followed: + + .. code-block:: + + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + + Args: + transforms (list[dict]): A list of albu transformations + bbox_params (dict): Bbox_params for albumentation `Compose` + keymap (dict): Contains {'input key':'albumentation-style key'} + skip_img_without_anno (bool): Whether to skip the image if no ann left + after aug + """ + + def __init__(self, + transforms, + bbox_params=None, + keymap=None, + update_pad_shape=False, + skip_img_without_anno=False): + if Compose is None: + raise RuntimeError('albumentations is not installed') + + # Args will be modified later, copying it will be safer + transforms = copy.deepcopy(transforms) + if bbox_params is not None: + bbox_params = copy.deepcopy(bbox_params) + if keymap is not None: + keymap = copy.deepcopy(keymap) + self.transforms = transforms + self.filter_lost_elements = False + self.update_pad_shape = update_pad_shape + self.skip_img_without_anno = skip_img_without_anno + + # A simple workaround to remove masks without boxes + if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params + and 'filter_lost_elements' in bbox_params): + self.filter_lost_elements = True + self.origin_label_fields = bbox_params['label_fields'] + bbox_params['label_fields'] = ['idx_mapper'] + del bbox_params['filter_lost_elements'] + + self.bbox_params = ( + self.albu_builder(bbox_params) if bbox_params else None) + self.aug = Compose([self.albu_builder(t) for t in self.transforms], + bbox_params=self.bbox_params) + + if not keymap: + self.keymap_to_albu = { + 'img': 'image', + 'gt_masks': 'masks', + 'gt_bboxes': 'bboxes' + } + else: + self.keymap_to_albu = keymap + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg): + """Import a module from albumentations. + + It inherits some of :func:`build_from_cfg` logic. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + + Returns: + obj: The constructed object. 
+ """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + + obj_type = args.pop('type') + if dbnet_cv.is_str(obj_type): + if albumentations is None: + raise RuntimeError('albumentations is not installed') + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d, keymap): + """Dictionary mapper. Renames keys according to keymap provided. + + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, v in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def __call__(self, results): + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + # TODO: add bbox_fields + if 'bboxes' in results: + # to list of boxes + if isinstance(results['bboxes'], np.ndarray): + results['bboxes'] = [x for x in results['bboxes']] + # add pseudo-field for filtration + if self.filter_lost_elements: + results['idx_mapper'] = np.arange(len(results['bboxes'])) + + # TODO: Support mask structure in albu + if 'masks' in results: + if isinstance(results['masks'], PolygonMasks): + raise NotImplementedError( + 'Albu only supports BitMap masks now') + ori_masks = results['masks'] + if albumentations.__version__ < '0.5': + results['masks'] = results['masks'].masks + else: + results['masks'] = [mask for mask in results['masks'].masks] + + results = self.aug(**results) + + if 'bboxes' in results: + if isinstance(results['bboxes'], list): + results['bboxes'] = np.array( + results['bboxes'], dtype=np.float32) + results['bboxes'] = results['bboxes'].reshape(-1, 4) + + # filter label_fields + if self.filter_lost_elements: + + for label in self.origin_label_fields: + results[label] = np.array( + [results[label][i] for i in results['idx_mapper']]) + if 'masks' in results: + results['masks'] = np.array( + [results['masks'][i] for i in results['idx_mapper']]) + results['masks'] = ori_masks.__class__( + results['masks'], results['image'].shape[0], + results['image'].shape[1]) + + if (not len(results['idx_mapper']) + and self.skip_img_without_anno): + return None + + if 'gt_labels' in results: + if isinstance(results['gt_labels'], list): + results['gt_labels'] = np.array(results['gt_labels']) + results['gt_labels'] = results['gt_labels'].astype(np.int64) + + # back to the original format + results = self.mapper(results, self.keymap_back) + + # update final shape + if self.update_pad_shape: + results['pad_shape'] = results['img'].shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(transforms={self.transforms})' + return repr_str + + +@PIPELINES.register_module() +class RandomCenterCropPad: + """Random center crop and random around padding for CornerNet. + + This operation generates randomly cropped image from the original image and + pads it simultaneously. Different from :class:`RandomCrop`, the output + shape may not equal to ``crop_size`` strictly. We choose a random value + from ``ratios`` and the output shape could be larger or smaller than + ``crop_size``. The padding operation is also different from :class:`Pad`, + here we use around padding instead of right-bottom padding. 
+ + The relation between output image (padding image) and original image: + + .. code:: text + + output image + + +----------------------------+ + | padded area | + +------|----------------------------|----------+ + | | cropped area | | + | | +---------------+ | | + | | | . center | | | original image + | | | range | | | + | | +---------------+ | | + +------|----------------------------|----------+ + | padded area | + +----------------------------+ + + There are 5 main areas in the figure: + + - output image: output image of this operation, also called padding + image in following instruction. + - original image: input image of this operation. + - padded area: non-intersect area of output image and original image. + - cropped area: the overlap of output image and original image. + - center range: a smaller area where random center chosen from. + center range is computed by ``border`` and original image's shape + to avoid our random center is too close to original image's border. + + Also this operation act differently in train and test mode, the summary + pipeline is listed below. + + Train pipeline: + + 1. Choose a ``random_ratio`` from ``ratios``, the shape of padding image + will be ``random_ratio * crop_size``. + 2. Choose a ``random_center`` in center range. + 3. Generate padding image with center matches the ``random_center``. + 4. Initialize the padding image with pixel value equals to ``mean``. + 5. Copy the cropped area to padding image. + 6. Refine annotations. + + Test pipeline: + + 1. Compute output shape according to ``test_pad_mode``. + 2. Generate padding image with center matches the original image + center. + 3. Initialize the padding image with pixel value equals to ``mean``. + 4. Copy the ``cropped area`` to padding image. + + Args: + crop_size (tuple | None): expected size after crop, final size will + computed according to ratio. Requires (h, w) in train mode, and + None in test mode. + ratios (tuple): random select a ratio from tuple and crop image to + (crop_size[0] * ratio) * (crop_size[1] * ratio). + Only available in train mode. + border (int): max distance from center select area to image border. + Only available in train mode. + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB. + test_mode (bool): whether involve random variables in transform. + In train mode, crop_size is fixed, center coords and ratio is + random selected from predefined lists. In test mode, crop_size + is image's original shape, center coords and ratio is fixed. + test_pad_mode (tuple): padding method and padding shape value, only + available in test mode. Default is using 'logical_or' with + 127 as padding shape value. + + - 'logical_or': final_shape = input_shape | padding_shape_value + - 'size_divisor': final_shape = int( + ceil(input_shape / padding_shape_value) * padding_shape_value) + test_pad_add_pix (int): Extra padding pixel in test mode. Default 0. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. 
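+
+    Example (an illustrative train-mode entry; the values are assumed and
+    ``mean``/``std``/``to_rgb`` must match the ``Normalize`` settings of
+    the actual pipeline):
+
+    >>> crop_pad = dict(
+    ...     type='RandomCenterCropPad',
+    ...     crop_size=(511, 511),
+    ...     ratios=(0.9, 1.0, 1.1),
+    ...     border=128,
+    ...     mean=[0, 0, 0],
+    ...     std=[1, 1, 1],
+    ...     to_rgb=True,
+    ...     test_pad_mode=None)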
+ """ + + def __init__(self, + crop_size=None, + ratios=(0.9, 1.0, 1.1), + border=128, + mean=None, + std=None, + to_rgb=None, + test_mode=False, + test_pad_mode=('logical_or', 127), + test_pad_add_pix=0, + bbox_clip_border=True): + if test_mode: + assert crop_size is None, 'crop_size must be None in test mode' + assert ratios is None, 'ratios must be None in test mode' + assert border is None, 'border must be None in test mode' + assert isinstance(test_pad_mode, (list, tuple)) + assert test_pad_mode[0] in ['logical_or', 'size_divisor'] + else: + assert isinstance(crop_size, (list, tuple)) + assert crop_size[0] > 0 and crop_size[1] > 0, ( + 'crop_size must > 0 in train mode') + assert isinstance(ratios, (list, tuple)) + assert test_pad_mode is None, ( + 'test_pad_mode must be None in train mode') + + self.crop_size = crop_size + self.ratios = ratios + self.border = border + # We do not set default value to mean, std and to_rgb because these + # hyper-parameters are easy to forget but could affect the performance. + # Please use the same setting as Normalize for performance assurance. + assert mean is not None and std is not None and to_rgb is not None + self.to_rgb = to_rgb + self.input_mean = mean + self.input_std = std + if to_rgb: + self.mean = mean[::-1] + self.std = std[::-1] + else: + self.mean = mean + self.std = std + self.test_mode = test_mode + self.test_pad_mode = test_pad_mode + self.test_pad_add_pix = test_pad_add_pix + self.bbox_clip_border = bbox_clip_border + + def _get_border(self, border, size): + """Get final border for the target size. + + This function generates a ``final_border`` according to image's shape. + The area between ``final_border`` and ``size - final_border`` is the + ``center range``. We randomly choose center from the ``center range`` + to avoid our random center is too close to original image's border. + Also ``center range`` should be larger than 0. + + Args: + border (int): The initial border, default is 128. + size (int): The width or height of original image. + Returns: + int: The final border. + """ + k = 2 * border / size + i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k))) + return border // i + + def _filter_boxes(self, patch, boxes): + """Check whether the center of each box is in the patch. + + Args: + patch (list[int]): The cropped area, [left, top, right, bottom]. + boxes (numpy array, (N x 4)): Ground truth boxes. + + Returns: + mask (numpy array, (N,)): Each box is inside or outside the patch. + """ + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * ( + center[:, 0] < patch[2]) * ( + center[:, 1] < patch[3]) + return mask + + def _crop_image_and_paste(self, image, center, size): + """Crop image with a given center and size, then paste the cropped + image to a blank image with two centers align. + + This function is equivalent to generating a blank image with ``size`` + as its shape. Then cover it on the original image with two centers ( + the center of blank image and the random center of original image) + aligned. The overlap area is paste from the original image and the + outside area is filled with ``mean pixel``. + + Args: + image (np array, H x W x C): Original image. + center (list[int]): Target crop center coord. + size (list[int]): Target crop size. [target_h, target_w] + + Returns: + cropped_img (np array, target_h x target_w x C): Cropped image. 
+ border (np array, 4): The distance of four border of + ``cropped_img`` to the original image area, [top, bottom, + left, right] + patch (list[int]): The cropped area, [left, top, right, bottom]. + """ + center_y, center_x = center + target_h, target_w = size + img_h, img_w, img_c = image.shape + + x0 = max(0, center_x - target_w // 2) + x1 = min(center_x + target_w // 2, img_w) + y0 = max(0, center_y - target_h // 2) + y1 = min(center_y + target_h // 2, img_h) + patch = np.array((int(x0), int(y0), int(x1), int(y1))) + + left, right = center_x - x0, x1 - center_x + top, bottom = center_y - y0, y1 - center_y + + cropped_center_y, cropped_center_x = target_h // 2, target_w // 2 + cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype) + for i in range(img_c): + cropped_img[:, :, i] += self.mean[i] + y_slice = slice(cropped_center_y - top, cropped_center_y + bottom) + x_slice = slice(cropped_center_x - left, cropped_center_x + right) + cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :] + + border = np.array([ + cropped_center_y - top, cropped_center_y + bottom, + cropped_center_x - left, cropped_center_x + right + ], + dtype=np.float32) + + return cropped_img, border, patch + + def _train_aug(self, results): + """Random crop and around padding the original image. + + Args: + results (dict): Image infomations in the augment pipeline. + + Returns: + results (dict): The updated dict. + """ + img = results['img'] + h, w, c = img.shape + boxes = results['gt_bboxes'] + while True: + scale = random.choice(self.ratios) + new_h = int(self.crop_size[0] * scale) + new_w = int(self.crop_size[1] * scale) + h_border = self._get_border(self.border, h) + w_border = self._get_border(self.border, w) + + for i in range(50): + center_x = random.randint(low=w_border, high=w - w_border) + center_y = random.randint(low=h_border, high=h - h_border) + + cropped_img, border, patch = self._crop_image_and_paste( + img, [center_y, center_x], [new_h, new_w]) + + mask = self._filter_boxes(patch, boxes) + # if image do not have valid bbox, any crop patch is valid. + if not mask.any() and len(boxes) > 0: + continue + + results['img'] = cropped_img + results['img_shape'] = cropped_img.shape + results['pad_shape'] = cropped_img.shape + + x0, y0, x1, y1 = patch + + left_w, top_h = center_x - x0, center_y - y0 + cropped_center_x, cropped_center_y = new_w // 2, new_h // 2 + + # crop bboxes accordingly and clip to the image boundary + for key in results.get('bbox_fields', []): + mask = self._filter_boxes(patch, results[key]) + bboxes = results[key][mask] + bboxes[:, 0:4:2] += cropped_center_x - left_w - x0 + bboxes[:, 1:4:2] += cropped_center_y - top_h - y0 + if self.bbox_clip_border: + bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w) + bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h) + keep = (bboxes[:, 2] > bboxes[:, 0]) & ( + bboxes[:, 3] > bboxes[:, 1]) + bboxes = bboxes[keep] + results[key] = bboxes + if key in ['gt_bboxes']: + if 'gt_labels' in results: + labels = results['gt_labels'][mask] + labels = labels[keep] + results['gt_labels'] = labels + if 'gt_masks' in results: + raise NotImplementedError( + 'RandomCenterCropPad only supports bbox.') + + # crop semantic seg + for key in results.get('seg_fields', []): + raise NotImplementedError( + 'RandomCenterCropPad only supports bbox.') + return results + + def _test_aug(self, results): + """Around padding the original image without cropping. + + The padding mode and value are from ``test_pad_mode``. 
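+
+        For instance (illustrative numbers): with
+        ``test_pad_mode=('logical_or', 127)`` and ``test_pad_add_pix=0``,
+        an input height of 500 is padded up to
+
+        >>> (500 | 127) + 0
+        511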
+ + Args: + results (dict): Image infomations in the augment pipeline. + + Returns: + results (dict): The updated dict. + """ + img = results['img'] + h, w, c = img.shape + results['img_shape'] = img.shape + if self.test_pad_mode[0] in ['logical_or']: + # self.test_pad_add_pix is only used for centernet + target_h = (h | self.test_pad_mode[1]) + self.test_pad_add_pix + target_w = (w | self.test_pad_mode[1]) + self.test_pad_add_pix + elif self.test_pad_mode[0] in ['size_divisor']: + divisor = self.test_pad_mode[1] + target_h = int(np.ceil(h / divisor)) * divisor + target_w = int(np.ceil(w / divisor)) * divisor + else: + raise NotImplementedError( + 'RandomCenterCropPad only support two testing pad mode:' + 'logical-or and size_divisor.') + + cropped_img, border, _ = self._crop_image_and_paste( + img, [h // 2, w // 2], [target_h, target_w]) + results['img'] = cropped_img + results['pad_shape'] = cropped_img.shape + results['border'] = border + return results + + def __call__(self, results): + img = results['img'] + assert img.dtype == np.float32, ( + 'RandomCenterCropPad needs the input image of dtype np.float32,' + ' please set "to_float32=True" in "LoadImageFromFile" pipeline') + h, w, c = img.shape + assert c == len(self.mean) + if self.test_mode: + return self._test_aug(results) + else: + return self._train_aug(results) + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(crop_size={self.crop_size}, ' + repr_str += f'ratios={self.ratios}, ' + repr_str += f'border={self.border}, ' + repr_str += f'mean={self.input_mean}, ' + repr_str += f'std={self.input_std}, ' + repr_str += f'to_rgb={self.to_rgb}, ' + repr_str += f'test_mode={self.test_mode}, ' + repr_str += f'test_pad_mode={self.test_pad_mode}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@PIPELINES.register_module() +class CutOut: + """CutOut operation. + + Randomly drop some regions of image used in + `Cutout `_. + + Args: + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + """ + + def __init__(self, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0)): + + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' 
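+        # Illustrative usage (values assumed, not taken from this
+        # repository's configs): a pipeline entry such as
+        #     dict(type='CutOut', n_holes=(1, 3), cutout_ratio=[(0.1, 0.1)])
+        # drops one to three regions, each 10% of the image height and
+        # width, filled with ``fill_in``.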
+ assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + self.n_holes = n_holes + self.fill_in = fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + def __call__(self, results): + """Call function to drop some regions of image.""" + h, w, c = results['img'].shape + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + for _ in range(n_holes): + x1 = np.random.randint(0, w) + y1 = np.random.randint(0, h) + index = np.random.randint(0, len(self.candidates)) + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in})' + return repr_str + + +@PIPELINES.register_module() +class Mosaic: + """Mosaic augmentation. + + Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + + 1. Choose the mosaic center as the intersections of 4 images + 2. Get the left top image according to the index, and randomly + sample another 3 images from the custom dataset. + 3. Sub image will be cropped if image is larger than mosaic patch + + Args: + img_scale (Sequence[int]): Image size after mosaic pipeline of single + image. The shape order should be (height, width). + Default to (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Default to (0.5, 1.5). + min_bbox_size (int | float): The minimum pixel for filtering + invalid bboxes after the mosaic pipeline. Default to 0. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + skip_filter (bool): Whether to skip filtering rules. If it + is True, the filter rule will not be applied, and the + `min_bbox_size` is invalid. Default to True. + pad_val (int): Pad value. Default to 114. + prob (float): Probability of applying this transformation. + Default to 1.0. + """ + + def __init__(self, + img_scale=(640, 640), + center_ratio_range=(0.5, 1.5), + min_bbox_size=0, + bbox_clip_border=True, + skip_filter=True, + pad_val=114, + prob=1.0): + assert isinstance(img_scale, tuple) + assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\ + f'got {prob}.' 
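+        # Illustrative usage (values assumed, not taken from this
+        # repository's configs): Mosaic is normally wrapped by a dataset
+        # such as MultiImageMixDataset so that get_indexes() can draw the
+        # three extra images, e.g.
+        #     dict(type='Mosaic', img_scale=(640, 640), pad_val=114)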
+ + log_img_scale(img_scale, skip_square=True) + self.img_scale = img_scale + self.center_ratio_range = center_ratio_range + self.min_bbox_size = min_bbox_size + self.bbox_clip_border = bbox_clip_border + self.skip_filter = skip_filter + self.pad_val = pad_val + self.prob = prob + + def __call__(self, results): + """Call function to make a mosaic of image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with mosaic transformed. + """ + + if random.uniform(0, 1) > self.prob: + return results + + results = self._mosaic_transform(results) + return results + + def get_indexes(self, dataset): + """Call function to collect indexes. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indexes. + """ + + indexes = [random.randint(0, len(dataset)) for _ in range(3)] + return indexes + + def _mosaic_transform(self, results): + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + mosaic_labels = [] + mosaic_bboxes = [] + if len(results['img'].shape) == 3: + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3), + self.pad_val, + dtype=results['img'].dtype) + else: + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.pad_val, + dtype=results['img'].dtype) + + # mosaic center x, y + center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + center_position = (center_x, center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + results_patch = copy.deepcopy(results) + else: + results_patch = copy.deepcopy(results['mix_results'][i - 1]) + + img_i = results_patch['img'] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + img_i = dbnet_cv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + # adjust coordinate + gt_bboxes_i = results_patch['gt_bboxes'] + gt_labels_i = results_patch['gt_labels'] + + if gt_bboxes_i.shape[0] > 0: + padw = x1_p - x1_c + padh = y1_p - y1_c + gt_bboxes_i[:, 0::2] = \ + scale_ratio_i * gt_bboxes_i[:, 0::2] + padw + gt_bboxes_i[:, 1::2] = \ + scale_ratio_i * gt_bboxes_i[:, 1::2] + padh + + mosaic_bboxes.append(gt_bboxes_i) + mosaic_labels.append(gt_labels_i) + + if len(mosaic_labels) > 0: + mosaic_bboxes = np.concatenate(mosaic_bboxes, 0) + mosaic_labels = np.concatenate(mosaic_labels, 0) + + if self.bbox_clip_border: + mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0, + 2 * self.img_scale[1]) + mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0, + 2 * self.img_scale[0]) + + if not self.skip_filter: + mosaic_bboxes, mosaic_labels = \ + self._filter_box_candidates(mosaic_bboxes, mosaic_labels) + + # remove outside bboxes + inside_inds = find_inside_bboxes(mosaic_bboxes, 2 * self.img_scale[0], + 2 * self.img_scale[1]) + mosaic_bboxes = mosaic_bboxes[inside_inds] + mosaic_labels = mosaic_labels[inside_inds] + + results['img'] = mosaic_img + 
results['img_shape'] = mosaic_img.shape + results['gt_bboxes'] = mosaic_bboxes + results['gt_labels'] = mosaic_labels + + return results + + def _mosaic_combine(self, loc, center_position_xy, img_shape_wh): + """Calculate global coordinate of mosaic image and local coordinate of + cropped sub-image. + + Args: + loc (str): Index for the sub-image, loc in ('top_left', + 'top_right', 'bottom_left', 'bottom_right'). + center_position_xy (Sequence[float]): Mixing center for 4 images, + (x, y). + img_shape_wh (Sequence[int]): Width and height of sub-image + + Returns: + tuple[tuple[float]]: Corresponding coordinate of pasting and + cropping + - paste_coord (tuple): paste corner coordinate in mosaic image. + - crop_coord (tuple): crop corner coordinate in mosaic image. + """ + assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right') + if loc == 'top_left': + # index0 to top left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + center_position_xy[0], \ + center_position_xy[1] + crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - ( + y2 - y1), img_shape_wh[0], img_shape_wh[1] + + elif loc == 'top_right': + # index1 to top right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + center_position_xy[1] + crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( + img_shape_wh[0], x2 - x1), img_shape_wh[1] + + elif loc == 'bottom_left': + # index2 to bottom left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + center_position_xy[1], \ + center_position_xy[0], \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( + y2 - y1, img_shape_wh[1]) + + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + center_position_xy[1], \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = 0, 0, min(img_shape_wh[0], + x2 - x1), min(y2 - y1, img_shape_wh[1]) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + def _filter_box_candidates(self, bboxes, labels): + """Filter out bboxes too small after Mosaic.""" + bbox_w = bboxes[:, 2] - bboxes[:, 0] + bbox_h = bboxes[:, 3] - bboxes[:, 1] + valid_inds = (bbox_w > self.min_bbox_size) & \ + (bbox_h > self.min_bbox_size) + valid_inds = np.nonzero(valid_inds)[0] + return bboxes[valid_inds], labels[valid_inds] + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'img_scale={self.img_scale}, ' + repr_str += f'center_ratio_range={self.center_ratio_range}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'min_bbox_size={self.min_bbox_size}, ' + repr_str += f'skip_filter={self.skip_filter})' + return repr_str + + +@PIPELINES.register_module() +class MixUp: + """MixUp data augmentation. + + .. code:: text + + mixup transform + +------------------------------+ + | mixup image | | + | +--------|--------+ | + | | | | | + |---------------+ | | + | | | | + | | image | | + | | | | + | | | | + | |-----------------+ | + | pad | + +------------------------------+ + + The mixup transform steps are as follows: + + 1. Another random image is picked by dataset and embedded in + the top left patch(after padding and resizing) + 2. 
The target of mixup transform is the weighted average of mixup + image and origin image. + + Args: + img_scale (Sequence[int]): Image output size after mixup pipeline. + The shape order should be (height, width). Default: (640, 640). + ratio_range (Sequence[float]): Scale ratio of mixup image. + Default: (0.5, 1.5). + flip_ratio (float): Horizontal flip ratio of mixup image. + Default: 0.5. + pad_val (int): Pad value. Default: 114. + max_iters (int): The maximum number of iterations. If the number of + iterations is greater than `max_iters`, but gt_bbox is still + empty, then the iteration is terminated. Default: 15. + min_bbox_size (float): Width and height threshold to filter bboxes. + If the height or width of a box is smaller than this value, it + will be removed. Default: 5. + min_area_ratio (float): Threshold of area ratio between + original bboxes and wrapped bboxes. If smaller than this value, + the box will be removed. Default: 0.2. + max_aspect_ratio (float): Aspect ratio of width and height + threshold to filter bboxes. If max(h/w, w/h) larger than this + value, the box will be removed. Default: 20. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + skip_filter (bool): Whether to skip filtering rules. If it + is True, the filter rule will not be applied, and the + `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio` + is invalid. Default to True. + """ + + def __init__(self, + img_scale=(640, 640), + ratio_range=(0.5, 1.5), + flip_ratio=0.5, + pad_val=114, + max_iters=15, + min_bbox_size=5, + min_area_ratio=0.2, + max_aspect_ratio=20, + bbox_clip_border=True, + skip_filter=True): + assert isinstance(img_scale, tuple) + log_img_scale(img_scale, skip_square=True) + self.dynamic_scale = img_scale + self.ratio_range = ratio_range + self.flip_ratio = flip_ratio + self.pad_val = pad_val + self.max_iters = max_iters + self.min_bbox_size = min_bbox_size + self.min_area_ratio = min_area_ratio + self.max_aspect_ratio = max_aspect_ratio + self.bbox_clip_border = bbox_clip_border + self.skip_filter = skip_filter + + def __call__(self, results): + """Call function to make a mixup of image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with mixup transformed. + """ + + results = self._mixup_transform(results) + return results + + def get_indexes(self, dataset): + """Call function to collect indexes. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indexes. + """ + + for i in range(self.max_iters): + index = random.randint(0, len(dataset)) + gt_bboxes_i = dataset.get_ann_info(index)['bboxes'] + if len(gt_bboxes_i) != 0: + break + + return index + + def _mixup_transform(self, results): + """MixUp transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + assert len( + results['mix_results']) == 1, 'MixUp only support 2 images now !' 
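The weighted average mentioned in the MixUp docstring is a plain 0.5/0.5 pixel blend once both images share the same shape, with the annotations simply concatenated. A minimal sketch under those assumptions (not the exact code path of this patch, which also resizes, jitters and crops the retrieved image):

import numpy as np

def simple_mixup(img_a, boxes_a, labels_a, img_b, boxes_b, labels_b):
    # Blend two equally sized images 50/50 and merge their annotations.
    assert img_a.shape == img_b.shape
    mixed = (0.5 * img_a.astype(np.float32) +
             0.5 * img_b.astype(np.float32)).astype(np.uint8)
    boxes = np.concatenate([boxes_a, boxes_b], axis=0)
    labels = np.concatenate([labels_a, labels_b], axis=0)
    return mixed, boxes, labels

a = np.zeros((4, 4, 3), dtype=np.uint8)
b = np.full((4, 4, 3), 200, dtype=np.uint8)
mixed, _, _ = simple_mixup(a, np.zeros((0, 4)), np.zeros((0,)),
                           b, np.zeros((0, 4)), np.zeros((0,)))
print(mixed[0, 0])  # [100 100 100]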
+ + if results['mix_results'][0]['gt_bboxes'].shape[0] == 0: + # empty bbox + return results + + retrieve_results = results['mix_results'][0] + retrieve_img = retrieve_results['img'] + + jit_factor = random.uniform(*self.ratio_range) + is_filp = random.uniform(0, 1) > self.flip_ratio + + if len(retrieve_img.shape) == 3: + out_img = np.ones( + (self.dynamic_scale[0], self.dynamic_scale[1], 3), + dtype=retrieve_img.dtype) * self.pad_val + else: + out_img = np.ones( + self.dynamic_scale, dtype=retrieve_img.dtype) * self.pad_val + + # 1. keep_ratio resize + scale_ratio = min(self.dynamic_scale[0] / retrieve_img.shape[0], + self.dynamic_scale[1] / retrieve_img.shape[1]) + retrieve_img = dbnet_cv.imresize( + retrieve_img, (int(retrieve_img.shape[1] * scale_ratio), + int(retrieve_img.shape[0] * scale_ratio))) + + # 2. paste + out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img + + # 3. scale jit + scale_ratio *= jit_factor + out_img = dbnet_cv.imresize(out_img, (int(out_img.shape[1] * jit_factor), + int(out_img.shape[0] * jit_factor))) + + # 4. flip + if is_filp: + out_img = out_img[:, ::-1, :] + + # 5. random crop + ori_img = results['img'] + origin_h, origin_w = out_img.shape[:2] + target_h, target_w = ori_img.shape[:2] + padded_img = np.zeros( + (max(origin_h, target_h), max(origin_w, + target_w), 3)).astype(np.uint8) + padded_img[:origin_h, :origin_w] = out_img + + x_offset, y_offset = 0, 0 + if padded_img.shape[0] > target_h: + y_offset = random.randint(0, padded_img.shape[0] - target_h) + if padded_img.shape[1] > target_w: + x_offset = random.randint(0, padded_img.shape[1] - target_w) + padded_cropped_img = padded_img[y_offset:y_offset + target_h, + x_offset:x_offset + target_w] + + # 6. adjust bbox + retrieve_gt_bboxes = retrieve_results['gt_bboxes'] + retrieve_gt_bboxes[:, 0::2] = retrieve_gt_bboxes[:, 0::2] * scale_ratio + retrieve_gt_bboxes[:, 1::2] = retrieve_gt_bboxes[:, 1::2] * scale_ratio + if self.bbox_clip_border: + retrieve_gt_bboxes[:, 0::2] = np.clip(retrieve_gt_bboxes[:, 0::2], + 0, origin_w) + retrieve_gt_bboxes[:, 1::2] = np.clip(retrieve_gt_bboxes[:, 1::2], + 0, origin_h) + + if is_filp: + retrieve_gt_bboxes[:, 0::2] = ( + origin_w - retrieve_gt_bboxes[:, 0::2][:, ::-1]) + + # 7. filter + cp_retrieve_gt_bboxes = retrieve_gt_bboxes.copy() + cp_retrieve_gt_bboxes[:, 0::2] = \ + cp_retrieve_gt_bboxes[:, 0::2] - x_offset + cp_retrieve_gt_bboxes[:, 1::2] = \ + cp_retrieve_gt_bboxes[:, 1::2] - y_offset + if self.bbox_clip_border: + cp_retrieve_gt_bboxes[:, 0::2] = np.clip( + cp_retrieve_gt_bboxes[:, 0::2], 0, target_w) + cp_retrieve_gt_bboxes[:, 1::2] = np.clip( + cp_retrieve_gt_bboxes[:, 1::2], 0, target_h) + + # 8. 
mix up + ori_img = ori_img.astype(np.float32) + mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32) + + retrieve_gt_labels = retrieve_results['gt_labels'] + if not self.skip_filter: + keep_list = self._filter_box_candidates(retrieve_gt_bboxes.T, + cp_retrieve_gt_bboxes.T) + + retrieve_gt_labels = retrieve_gt_labels[keep_list] + cp_retrieve_gt_bboxes = cp_retrieve_gt_bboxes[keep_list] + + mixup_gt_bboxes = np.concatenate( + (results['gt_bboxes'], cp_retrieve_gt_bboxes), axis=0) + mixup_gt_labels = np.concatenate( + (results['gt_labels'], retrieve_gt_labels), axis=0) + + # remove outside bbox + inside_inds = find_inside_bboxes(mixup_gt_bboxes, target_h, target_w) + mixup_gt_bboxes = mixup_gt_bboxes[inside_inds] + mixup_gt_labels = mixup_gt_labels[inside_inds] + + results['img'] = mixup_img.astype(np.uint8) + results['img_shape'] = mixup_img.shape + results['gt_bboxes'] = mixup_gt_bboxes + results['gt_labels'] = mixup_gt_labels + + return results + + def _filter_box_candidates(self, bbox1, bbox2): + """Compute candidate boxes which include following 5 things: + + bbox1 before augment, bbox2 after augment, min_bbox_size (pixels), + min_area_ratio, max_aspect_ratio. + """ + + w1, h1 = bbox1[2] - bbox1[0], bbox1[3] - bbox1[1] + w2, h2 = bbox2[2] - bbox2[0], bbox2[3] - bbox2[1] + ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) + return ((w2 > self.min_bbox_size) + & (h2 > self.min_bbox_size) + & (w2 * h2 / (w1 * h1 + 1e-16) > self.min_area_ratio) + & (ar < self.max_aspect_ratio)) + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'dynamic_scale={self.dynamic_scale}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'flip_ratio={self.flip_ratio}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'max_iters={self.max_iters}, ' + repr_str += f'min_bbox_size={self.min_bbox_size}, ' + repr_str += f'min_area_ratio={self.min_area_ratio}, ' + repr_str += f'max_aspect_ratio={self.max_aspect_ratio}, ' + repr_str += f'skip_filter={self.skip_filter})' + return repr_str + + +@PIPELINES.register_module() +class RandomAffine: + """Random affine transform data augmentation. + + This operation randomly generates affine transform matrix which including + rotation, translation, shear and scaling transforms. + + Args: + max_rotate_degree (float): Maximum degrees of rotation transform. + Default: 10. + max_translate_ratio (float): Maximum ratio of translation. + Default: 0.1. + scaling_ratio_range (tuple[float]): Min and max ratio of + scaling transform. Default: (0.5, 1.5). + max_shear_degree (float): Maximum degrees of shear + transform. Default: 2. + border (tuple[int]): Distance from height and width sides of input + image to adjust output shape. Only used in mosaic dataset. + Default: (0, 0). + border_val (tuple[int]): Border padding values of 3 channels. + Default: (114, 114, 114). + min_bbox_size (float): Width and height threshold to filter bboxes. + If the height or width of a box is smaller than this value, it + will be removed. Default: 2. + min_area_ratio (float): Threshold of area ratio between + original bboxes and wrapped bboxes. If smaller than this value, + the box will be removed. Default: 0.2. + max_aspect_ratio (float): Aspect ratio of width and height + threshold to filter bboxes. If max(h/w, w/h) larger than this + value, the box will be removed. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. 
In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + skip_filter (bool): Whether to skip filtering rules. If it + is True, the filter rule will not be applied, and the + `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio` + is invalid. Default to True. + """ + + def __init__(self, + max_rotate_degree=10.0, + max_translate_ratio=0.1, + scaling_ratio_range=(0.5, 1.5), + max_shear_degree=2.0, + border=(0, 0), + border_val=(114, 114, 114), + min_bbox_size=2, + min_area_ratio=0.2, + max_aspect_ratio=20, + bbox_clip_border=True, + skip_filter=True): + assert 0 <= max_translate_ratio <= 1 + assert scaling_ratio_range[0] <= scaling_ratio_range[1] + assert scaling_ratio_range[0] > 0 + self.max_rotate_degree = max_rotate_degree + self.max_translate_ratio = max_translate_ratio + self.scaling_ratio_range = scaling_ratio_range + self.max_shear_degree = max_shear_degree + self.border = border + self.border_val = border_val + self.min_bbox_size = min_bbox_size + self.min_area_ratio = min_area_ratio + self.max_aspect_ratio = max_aspect_ratio + self.bbox_clip_border = bbox_clip_border + self.skip_filter = skip_filter + + def __call__(self, results): + img = results['img'] + height = img.shape[0] + self.border[0] * 2 + width = img.shape[1] + self.border[1] * 2 + + # Rotation + rotation_degree = random.uniform(-self.max_rotate_degree, + self.max_rotate_degree) + rotation_matrix = self._get_rotation_matrix(rotation_degree) + + # Scaling + scaling_ratio = random.uniform(self.scaling_ratio_range[0], + self.scaling_ratio_range[1]) + scaling_matrix = self._get_scaling_matrix(scaling_ratio) + + # Shear + x_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + y_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + shear_matrix = self._get_shear_matrix(x_degree, y_degree) + + # Translation + trans_x = random.uniform(-self.max_translate_ratio, + self.max_translate_ratio) * width + trans_y = random.uniform(-self.max_translate_ratio, + self.max_translate_ratio) * height + translate_matrix = self._get_translation_matrix(trans_x, trans_y) + + warp_matrix = ( + translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix) + + img = cv2.warpPerspective( + img, + warp_matrix, + dsize=(width, height), + borderValue=self.border_val) + results['img'] = img + results['img_shape'] = img.shape + + for key in results.get('bbox_fields', []): + bboxes = results[key] + num_bboxes = len(bboxes) + if num_bboxes: + # homogeneous coordinates + xs = bboxes[:, [0, 0, 2, 2]].reshape(num_bboxes * 4) + ys = bboxes[:, [1, 3, 3, 1]].reshape(num_bboxes * 4) + ones = np.ones_like(xs) + points = np.vstack([xs, ys, ones]) + + warp_points = warp_matrix @ points + warp_points = warp_points[:2] / warp_points[2] + xs = warp_points[0].reshape(num_bboxes, 4) + ys = warp_points[1].reshape(num_bboxes, 4) + + warp_bboxes = np.vstack( + (xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T + + if self.bbox_clip_border: + warp_bboxes[:, [0, 2]] = \ + warp_bboxes[:, [0, 2]].clip(0, width) + warp_bboxes[:, [1, 3]] = \ + warp_bboxes[:, [1, 3]].clip(0, height) + + # remove outside bbox + valid_index = find_inside_bboxes(warp_bboxes, height, width) + if not self.skip_filter: + # filter bboxes + filter_index = self.filter_gt_bboxes( + bboxes * scaling_ratio, warp_bboxes) + valid_index = valid_index & filter_index + + results[key] = warp_bboxes[valid_index] + if key in ['gt_bboxes']: + if 
'gt_labels' in results: + results['gt_labels'] = results['gt_labels'][ + valid_index] + + if 'gt_masks' in results: + raise NotImplementedError( + 'RandomAffine only supports bbox.') + return results + + def filter_gt_bboxes(self, origin_bboxes, wrapped_bboxes): + origin_w = origin_bboxes[:, 2] - origin_bboxes[:, 0] + origin_h = origin_bboxes[:, 3] - origin_bboxes[:, 1] + wrapped_w = wrapped_bboxes[:, 2] - wrapped_bboxes[:, 0] + wrapped_h = wrapped_bboxes[:, 3] - wrapped_bboxes[:, 1] + aspect_ratio = np.maximum(wrapped_w / (wrapped_h + 1e-16), + wrapped_h / (wrapped_w + 1e-16)) + + wh_valid_idx = (wrapped_w > self.min_bbox_size) & \ + (wrapped_h > self.min_bbox_size) + area_valid_idx = wrapped_w * wrapped_h / (origin_w * origin_h + + 1e-16) > self.min_area_ratio + aspect_ratio_valid_idx = aspect_ratio < self.max_aspect_ratio + return wh_valid_idx & area_valid_idx & aspect_ratio_valid_idx + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(max_rotate_degree={self.max_rotate_degree}, ' + repr_str += f'max_translate_ratio={self.max_translate_ratio}, ' + repr_str += f'scaling_ratio={self.scaling_ratio_range}, ' + repr_str += f'max_shear_degree={self.max_shear_degree}, ' + repr_str += f'border={self.border}, ' + repr_str += f'border_val={self.border_val}, ' + repr_str += f'min_bbox_size={self.min_bbox_size}, ' + repr_str += f'min_area_ratio={self.min_area_ratio}, ' + repr_str += f'max_aspect_ratio={self.max_aspect_ratio}, ' + repr_str += f'skip_filter={self.skip_filter})' + return repr_str + + @staticmethod + def _get_rotation_matrix(rotate_degrees): + radian = math.radians(rotate_degrees) + rotation_matrix = np.array( + [[np.cos(radian), -np.sin(radian), 0.], + [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]], + dtype=np.float32) + return rotation_matrix + + @staticmethod + def _get_scaling_matrix(scale_ratio): + scaling_matrix = np.array( + [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]], + dtype=np.float32) + return scaling_matrix + + @staticmethod + def _get_share_matrix(scale_ratio): + scaling_matrix = np.array( + [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]], + dtype=np.float32) + return scaling_matrix + + @staticmethod + def _get_shear_matrix(x_shear_degrees, y_shear_degrees): + x_radian = math.radians(x_shear_degrees) + y_radian = math.radians(y_shear_degrees) + shear_matrix = np.array([[1, np.tan(x_radian), 0.], + [np.tan(y_radian), 1, 0.], [0., 0., 1.]], + dtype=np.float32) + return shear_matrix + + @staticmethod + def _get_translation_matrix(x, y): + translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]], + dtype=np.float32) + return translation_matrix + + +@PIPELINES.register_module() +class YOLOXHSVRandomAug: + """Apply HSV augmentation to image sequentially. It is referenced from + https://github.com/Megvii- + BaseDetection/YOLOX/blob/main/yolox/data/data_augment.py#L21. + + Args: + hue_delta (int): delta of hue. Default: 5. + saturation_delta (int): delta of saturation. Default: 30. + value_delta (int): delat of value. Default: 30. 
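To make the matrix composition in `RandomAffine.__call__` above concrete: each component (rotation, scaling, shear, translation) is a 3x3 homogeneous matrix, their product is applied to box corners as column vectors, and the warped corners are re-enclosed by an axis-aligned box. A standalone sketch of that idea (helper names are illustrative, not part of this patch):

import math
import numpy as np

def rotation_matrix(deg):
    r = math.radians(deg)
    return np.array([[math.cos(r), -math.sin(r), 0.],
                     [math.sin(r), math.cos(r), 0.],
                     [0., 0., 1.]], dtype=np.float32)

def translation_matrix(tx, ty):
    return np.array([[1., 0., tx], [0., 1., ty], [0., 0., 1.]], dtype=np.float32)

def warp_box(bbox, warp):
    # Warp the four corners of an (x1, y1, x2, y2) box, then take min/max.
    x1, y1, x2, y2 = bbox
    corners = np.array([[x1, x1, x2, x2], [y1, y2, y2, y1], [1., 1., 1., 1.]])
    warped = warp @ corners
    warped = warped[:2] / warped[2]
    return (warped[0].min(), warped[1].min(), warped[0].max(), warped[1].max())

warp = translation_matrix(5, -3) @ rotation_matrix(10)
print(warp_box((0, 0, 10, 20), warp))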
+ """ + + def __init__(self, hue_delta=5, saturation_delta=30, value_delta=30): + self.hue_delta = hue_delta + self.saturation_delta = saturation_delta + self.value_delta = value_delta + + def __call__(self, results): + img = results['img'] + hsv_gains = np.random.uniform(-1, 1, 3) * [ + self.hue_delta, self.saturation_delta, self.value_delta + ] + # random selection of h, s, v + hsv_gains *= np.random.randint(0, 2, 3) + # prevent overflow + hsv_gains = hsv_gains.astype(np.int16) + img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16) + + img_hsv[..., 0] = (img_hsv[..., 0] + hsv_gains[0]) % 180 + img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_gains[1], 0, 255) + img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_gains[2], 0, 255) + cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(hue_delta={self.hue_delta}, ' + repr_str += f'saturation_delta={self.saturation_delta}, ' + repr_str += f'value_delta={self.value_delta})' + return repr_str + + +@PIPELINES.register_module() +class CopyPaste: + """Simple Copy-Paste is a Strong Data Augmentation Method for Instance + Segmentation The simple copy-paste transform steps are as follows: + + 1. The destination image is already resized with aspect ratio kept, + cropped and padded. + 2. Randomly select a source image, which is also already resized + with aspect ratio kept, cropped and padded in a similar way + as the destination image. + 3. Randomly select some objects from the source image. + 4. Paste these source objects to the destination image directly, + due to the source and destination image have the same size. + 5. Update object masks of the destination image, for some origin objects + may be occluded. + 6. Generate bboxes from the updated destination masks and + filter some objects which are totally occluded, and adjust bboxes + which are partly occluded. + 7. Append selected source bboxes, masks, and labels. + + Args: + max_num_pasted (int): The maximum number of pasted objects. + Default: 100. + bbox_occluded_thr (int): The threshold of occluded bbox. + Default: 10. + mask_occluded_thr (int): The threshold of occluded mask. + Default: 300. + selected (bool): Whether select objects or not. If select is False, + all objects of the source image will be pasted to the + destination image. + Default: True. + """ + + def __init__( + self, + max_num_pasted=100, + bbox_occluded_thr=10, + mask_occluded_thr=300, + selected=True, + ): + self.max_num_pasted = max_num_pasted + self.bbox_occluded_thr = bbox_occluded_thr + self.mask_occluded_thr = mask_occluded_thr + self.selected = selected + + def get_indexes(self, dataset): + """Call function to collect indexes.s. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + Returns: + list: Indexes. + """ + return random.randint(0, len(dataset)) + + def __call__(self, results): + """Call function to make a copy-paste of image. + + Args: + results (dict): Result dict. + Returns: + dict: Result dict with copy-paste transformed. 
+ """ + + assert 'mix_results' in results + num_images = len(results['mix_results']) + assert num_images == 1, \ + f'CopyPaste only supports processing 2 images, got {num_images}' + if self.selected: + selected_results = self._select_object(results['mix_results'][0]) + else: + selected_results = results['mix_results'][0] + return self._copy_paste(results, selected_results) + + def _select_object(self, results): + """Select some objects from the source results.""" + bboxes = results['gt_bboxes'] + labels = results['gt_labels'] + masks = results['gt_masks'] + max_num_pasted = min(bboxes.shape[0] + 1, self.max_num_pasted) + num_pasted = np.random.randint(0, max_num_pasted) + selected_inds = np.random.choice( + bboxes.shape[0], size=num_pasted, replace=False) + + selected_bboxes = bboxes[selected_inds] + selected_labels = labels[selected_inds] + selected_masks = masks[selected_inds] + + results['gt_bboxes'] = selected_bboxes + results['gt_labels'] = selected_labels + results['gt_masks'] = selected_masks + return results + + def _copy_paste(self, dst_results, src_results): + """CopyPaste transform function. + + Args: + dst_results (dict): Result dict of the destination image. + src_results (dict): Result dict of the source image. + Returns: + dict: Updated result dict. + """ + dst_img = dst_results['img'] + dst_bboxes = dst_results['gt_bboxes'] + dst_labels = dst_results['gt_labels'] + dst_masks = dst_results['gt_masks'] + + src_img = src_results['img'] + src_bboxes = src_results['gt_bboxes'] + src_labels = src_results['gt_labels'] + src_masks = src_results['gt_masks'] + + if len(src_bboxes) == 0: + return dst_results + + # update masks and generate bboxes from updated masks + composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0) + updated_dst_masks = self.get_updated_masks(dst_masks, composed_mask) + updated_dst_bboxes = updated_dst_masks.get_bboxes() + assert len(updated_dst_bboxes) == len(updated_dst_masks) + + # filter totally occluded objects + bboxes_inds = np.all( + np.abs( + (updated_dst_bboxes - dst_bboxes)) <= self.bbox_occluded_thr, + axis=-1) + masks_inds = updated_dst_masks.masks.sum( + axis=(1, 2)) > self.mask_occluded_thr + valid_inds = bboxes_inds | masks_inds + + # Paste source objects to destination image directly + img = dst_img * (1 - composed_mask[..., np.newaxis] + ) + src_img * composed_mask[..., np.newaxis] + bboxes = np.concatenate([updated_dst_bboxes[valid_inds], src_bboxes]) + labels = np.concatenate([dst_labels[valid_inds], src_labels]) + masks = np.concatenate( + [updated_dst_masks.masks[valid_inds], src_masks.masks]) + + dst_results['img'] = img + dst_results['gt_bboxes'] = bboxes + dst_results['gt_labels'] = labels + dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1], + masks.shape[2]) + + return dst_results + + def get_updated_masks(self, masks, composed_mask): + assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \ + 'Cannot compare two arrays of different size' + masks.masks = np.where(composed_mask, 0, masks.masks) + return masks + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'max_num_pasted={self.max_num_pasted}, ' + repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, ' + repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, ' + repr_str += f'selected={self.selected}, ' + return repr_str diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/__init__.py new file mode 100755 index 00000000..d526159f --- /dev/null +++ 
b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/__init__.py @@ -0,0 +1,10 @@ +# # Copyright (c) OpenMMLab. All rights reserved. +from .class_aware_sampler import ClassAwareSampler +from .distributed_sampler import DistributedSampler +from .group_sampler import DistributedGroupSampler, GroupSampler +from .infinite_sampler import InfiniteBatchSampler, InfiniteGroupBatchSampler + +__all__ = [ + 'DistributedSampler', 'DistributedGroupSampler', 'GroupSampler', + 'InfiniteGroupBatchSampler', 'InfiniteBatchSampler', 'ClassAwareSampler' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/class_aware_sampler.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/class_aware_sampler.py new file mode 100755 index 00000000..91ee88d7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/class_aware_sampler.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from dbnet_cv.runner import get_dist_info +from torch.utils.data import Sampler + +from dbnet_det.core.utils import sync_random_seed + + +class ClassAwareSampler(Sampler): + r"""Sampler that restricts data loading to the label of the dataset. + + A class-aware sampling strategy to effectively tackle the + non-uniform class distribution. The length of the training data is + consistent with source data. Simple improvements based on `Relay + Backpropagation for Effective Learning of Deep Convolutional + Neural Networks `_ + + The implementation logic is referred to + https://github.com/Sense-X/TSD/blob/master/dbnet_det/datasets/samplers/distributed_classaware_sampler.py + + Args: + dataset: Dataset used for sampling. + samples_per_gpu (int): When model is :obj:`DistributedDataParallel`, + it is the number of training samples on each GPU. + When model is :obj:`DataParallel`, it is + `num_gpus * samples_per_gpu`. + Default : 1. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + seed (int, optional): random seed used to shuffle the sampler if + ``shuffle=True``. This number should be identical across all + processes in the distributed group. Default: 0. + num_sample_class (int): The number of samples taken from each + per-label list. Default: 1 + """ + + def __init__(self, + dataset, + samples_per_gpu=1, + num_replicas=None, + rank=None, + seed=0, + num_sample_class=1): + _rank, _num_replicas = get_dist_info() + if num_replicas is None: + num_replicas = _num_replicas + if rank is None: + rank = _rank + + self.dataset = dataset + self.num_replicas = num_replicas + self.samples_per_gpu = samples_per_gpu + self.rank = rank + self.epoch = 0 + # Must be the same across all workers. 
If None, will use a + # random seed shared among workers + # (require synchronization among all workers) + self.seed = sync_random_seed(seed) + + # The number of samples taken from each per-label list + assert num_sample_class > 0 and isinstance(num_sample_class, int) + self.num_sample_class = num_sample_class + # Get per-label image list from dataset + assert hasattr(dataset, 'get_cat2imgs'), \ + 'dataset must have `get_cat2imgs` function' + self.cat_dict = dataset.get_cat2imgs() + + self.num_samples = int( + math.ceil( + len(self.dataset) * 1.0 / self.num_replicas / + self.samples_per_gpu)) * self.samples_per_gpu + self.total_size = self.num_samples * self.num_replicas + + # get number of images containing each category + self.num_cat_imgs = [len(x) for x in self.cat_dict.values()] + # filter labels without images + self.valid_cat_inds = [ + i for i, length in enumerate(self.num_cat_imgs) if length != 0 + ] + self.num_classes = len(self.valid_cat_inds) + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch + self.seed) + + # initialize label list + label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g) + # initialize each per-label image list + data_iter_dict = dict() + for i in self.valid_cat_inds: + data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g) + + def gen_cat_img_inds(cls_list, data_dict, num_sample_cls): + """Traverse the categories and extract `num_sample_cls` image + indexes of the corresponding categories one by one.""" + id_indices = [] + for _ in range(len(cls_list)): + cls_idx = next(cls_list) + for _ in range(num_sample_cls): + id = next(data_dict[cls_idx]) + id_indices.append(id) + return id_indices + + # deterministically shuffle based on epoch + num_bins = int( + math.ceil(self.total_size * 1.0 / self.num_classes / + self.num_sample_class)) + indices = [] + for i in range(num_bins): + indices += gen_cat_img_inds(label_iter_list, data_iter_dict, + self.num_sample_class) + + # fix extra samples to make it evenly divisible + if len(indices) >= self.total_size: + indices = indices[:self.total_size] + else: + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class RandomCycleIter: + """Shuffle the list and do it again after the list have traversed. + + The implementation logic is referred to + https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py + + Example: + >>> label_list = [0, 1, 2, 4, 5] + >>> g = torch.Generator() + >>> g.manual_seed(0) + >>> label_iter_list = RandomCycleIter(label_list, generator=g) + >>> index = next(label_iter_list) + Args: + data (list or ndarray): The data that needs to be shuffled. + generator: An torch.Generator object, which is used in setting the seed + for generating random numbers. 
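The sampling loop above (`gen_cat_img_inds`) cycles through the valid categories and draws `num_sample_class` image indices from each per-category list in turn, which is what balances rare classes. A simplified, self-contained sketch of that round-robin idea (it shuffles each pool only once, unlike `RandomCycleIter`; names are illustrative):

import itertools
import random

def class_aware_indices(cat_to_imgs, num_total, num_sample_class=1, seed=0):
    # Draw indices by cycling over categories, a few images per visit.
    rng = random.Random(seed)
    cats = [c for c, imgs in cat_to_imgs.items() if imgs]
    pools = {c: itertools.cycle(rng.sample(cat_to_imgs[c], len(cat_to_imgs[c])))
             for c in cats}
    out = []
    for cat in itertools.cycle(cats):
        for _ in range(num_sample_class):
            out.append(next(pools[cat]))
        if len(out) >= num_total:
            return out[:num_total]

print(class_aware_indices({0: [1, 2], 1: [3], 2: [4, 5, 6]}, num_total=8))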
+ """ # noqa: W605 + + def __init__(self, data, generator=None): + self.data = data + self.length = len(data) + self.index = torch.randperm(self.length, generator=generator).numpy() + self.i = 0 + self.generator = generator + + def __iter__(self): + return self + + def __len__(self): + return len(self.data) + + def __next__(self): + if self.i == self.length: + self.index = torch.randperm( + self.length, generator=self.generator).numpy() + self.i = 0 + idx = self.data[self.index[self.i]] + self.i += 1 + return idx diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/distributed_sampler.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/distributed_sampler.py new file mode 100755 index 00000000..b7e98df4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/distributed_sampler.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from torch.utils.data import DistributedSampler as _DistributedSampler + +from dbnet_det.core.utils import sync_random_seed +from dbnet_det.utils import get_device + + +class DistributedSampler(_DistributedSampler): + + def __init__(self, + dataset, + num_replicas=None, + rank=None, + shuffle=True, + seed=0): + super().__init__( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. + device = get_device() + self.seed = sync_random_seed(seed, device) + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + g = torch.Generator() + # When :attr:`shuffle=True`, this ensures all replicas + # use a different random ordering for each epoch. + # Otherwise, the next iteration of this sampler will + # yield the same ordering. + g.manual_seed(self.epoch + self.seed) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + # in case that indices is shorter than half of total_size + indices = (indices * + math.ceil(self.total_size / len(indices)))[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/group_sampler.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/group_sampler.py new file mode 100755 index 00000000..4fc2b0de --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/group_sampler.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math + +import numpy as np +import torch +from dbnet_cv.runner import get_dist_info +from torch.utils.data import Sampler + + +class GroupSampler(Sampler): + + def __init__(self, dataset, samples_per_gpu=1): + assert hasattr(dataset, 'flag') + self.dataset = dataset + self.samples_per_gpu = samples_per_gpu + self.flag = dataset.flag.astype(np.int64) + self.group_sizes = np.bincount(self.flag) + self.num_samples = 0 + for i, size in enumerate(self.group_sizes): + self.num_samples += int(np.ceil( + size / self.samples_per_gpu)) * self.samples_per_gpu + + def __iter__(self): + indices = [] + for i, size in enumerate(self.group_sizes): + if size == 0: + continue + indice = np.where(self.flag == i)[0] + assert len(indice) == size + np.random.shuffle(indice) + num_extra = int(np.ceil(size / self.samples_per_gpu) + ) * self.samples_per_gpu - len(indice) + indice = np.concatenate( + [indice, np.random.choice(indice, num_extra)]) + indices.append(indice) + indices = np.concatenate(indices) + indices = [ + indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu] + for i in np.random.permutation( + range(len(indices) // self.samples_per_gpu)) + ] + indices = np.concatenate(indices) + indices = indices.astype(np.int64).tolist() + assert len(indices) == self.num_samples + return iter(indices) + + def __len__(self): + return self.num_samples + + +class DistributedGroupSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + seed (int, optional): random seed used to shuffle the sampler if + ``shuffle=True``. This number should be identical across all + processes in the distributed group. Default: 0. + """ + + def __init__(self, + dataset, + samples_per_gpu=1, + num_replicas=None, + rank=None, + seed=0): + _rank, _num_replicas = get_dist_info() + if num_replicas is None: + num_replicas = _num_replicas + if rank is None: + rank = _rank + self.dataset = dataset + self.samples_per_gpu = samples_per_gpu + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.seed = seed if seed is not None else 0 + + assert hasattr(self.dataset, 'flag') + self.flag = self.dataset.flag + self.group_sizes = np.bincount(self.flag) + + self.num_samples = 0 + for i, j in enumerate(self.group_sizes): + self.num_samples += int( + math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu / + self.num_replicas)) * self.samples_per_gpu + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch + self.seed) + + indices = [] + for i, size in enumerate(self.group_sizes): + if size > 0: + indice = np.where(self.flag == i)[0] + assert len(indice) == size + # add .numpy() to avoid bug when selecting indice in parrots. + # TODO: check whether torch.randperm() can be replaced by + # numpy.random.permutation(). 
+ indice = indice[list( + torch.randperm(int(size), generator=g).numpy())].tolist() + extra = int( + math.ceil( + size * 1.0 / self.samples_per_gpu / self.num_replicas) + ) * self.samples_per_gpu * self.num_replicas - len(indice) + # pad indice + tmp = indice.copy() + for _ in range(extra // size): + indice.extend(tmp) + indice.extend(tmp[:extra % size]) + indices.extend(indice) + + assert len(indices) == self.total_size + + indices = [ + indices[j] for i in list( + torch.randperm( + len(indices) // self.samples_per_gpu, generator=g)) + for j in range(i * self.samples_per_gpu, (i + 1) * + self.samples_per_gpu) + ] + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/infinite_sampler.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/infinite_sampler.py new file mode 100755 index 00000000..6bb743fd --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/samplers/infinite_sampler.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools + +import numpy as np +import torch +from dbnet_cv.runner import get_dist_info +from torch.utils.data.sampler import Sampler + +from dbnet_det.core.utils import sync_random_seed + + +class InfiniteGroupBatchSampler(Sampler): + """Similar to `BatchSampler` warping a `GroupSampler. It is designed for + iteration-based runners like `IterBasedRunner` and yields a mini-batch + indices each time, all indices in a batch should be in the same group. + + The implementation logic is referred to + https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py + + Args: + dataset (object): The dataset. + batch_size (int): When model is :obj:`DistributedDataParallel`, + it is the number of training samples on each GPU. + When model is :obj:`DataParallel`, it is + `num_gpus * samples_per_gpu`. + Default : 1. + world_size (int, optional): Number of processes participating in + distributed training. Default: None. + rank (int, optional): Rank of current process. Default: None. + seed (int): Random seed. Default: 0. + shuffle (bool): Whether shuffle the indices of a dummy `epoch`, it + should be noted that `shuffle` can not guarantee that you can + generate sequential indices because it need to ensure + that all indices in a batch is in a group. Default: True. + """ # noqa: W605 + + def __init__(self, + dataset, + batch_size=1, + world_size=None, + rank=None, + seed=0, + shuffle=True): + _rank, _world_size = get_dist_info() + if world_size is None: + world_size = _world_size + if rank is None: + rank = _rank + self.rank = rank + self.world_size = world_size + self.dataset = dataset + self.batch_size = batch_size + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. 
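For the per-group padding performed in `DistributedGroupSampler.__iter__` above: each aspect-ratio group repeats its own indices until its length is a multiple of `samples_per_gpu * num_replicas`, so every GPU batch stays inside one group. A small sketch of that rule (illustrative only):

import math

def pad_group(indice, samples_per_gpu, num_replicas):
    # Repeat a group's indices until it divides evenly into per-GPU batches.
    unit = samples_per_gpu * num_replicas
    extra = math.ceil(len(indice) / unit) * unit - len(indice)
    tmp = list(indice)
    padded = list(indice)
    for _ in range(extra // len(indice)):
        padded.extend(tmp)
    padded.extend(tmp[:extra % len(indice)])
    return padded

print(len(pad_group(list(range(10)), samples_per_gpu=4, num_replicas=2)))  # 16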
+ self.seed = sync_random_seed(seed) + self.shuffle = shuffle + + assert hasattr(self.dataset, 'flag') + self.flag = self.dataset.flag + self.group_sizes = np.bincount(self.flag) + # buffer used to save indices of each group + self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))} + + self.size = len(dataset) + self.indices = self._indices_of_rank() + + def _infinite_indices(self): + """Infinitely yield a sequence of indices.""" + g = torch.Generator() + g.manual_seed(self.seed) + while True: + if self.shuffle: + yield from torch.randperm(self.size, generator=g).tolist() + + else: + yield from torch.arange(self.size).tolist() + + def _indices_of_rank(self): + """Slice the infinite indices by rank.""" + yield from itertools.islice(self._infinite_indices(), self.rank, None, + self.world_size) + + def __iter__(self): + # once batch size is reached, yield the indices + for idx in self.indices: + flag = self.flag[idx] + group_buffer = self.buffer_per_group[flag] + group_buffer.append(idx) + if len(group_buffer) == self.batch_size: + yield group_buffer[:] + del group_buffer[:] + + def __len__(self): + """Length of base dataset.""" + return self.size + + def set_epoch(self, epoch): + """Not supported in `IterationBased` runner.""" + raise NotImplementedError + + +class InfiniteBatchSampler(Sampler): + """Similar to `BatchSampler` warping a `DistributedSampler. It is designed + iteration-based runners like `IterBasedRunner` and yields a mini-batch + indices each time. + + The implementation logic is referred to + https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py + + Args: + dataset (object): The dataset. + batch_size (int): When model is :obj:`DistributedDataParallel`, + it is the number of training samples on each GPU, + When model is :obj:`DataParallel`, it is + `num_gpus * samples_per_gpu`. + Default : 1. + world_size (int, optional): Number of processes participating in + distributed training. Default: None. + rank (int, optional): Rank of current process. Default: None. + seed (int): Random seed. Default: 0. + shuffle (bool): Whether shuffle the dataset or not. Default: True. + """ # noqa: W605 + + def __init__(self, + dataset, + batch_size=1, + world_size=None, + rank=None, + seed=0, + shuffle=True): + _rank, _world_size = get_dist_info() + if world_size is None: + world_size = _world_size + if rank is None: + rank = _rank + self.rank = rank + self.world_size = world_size + self.dataset = dataset + self.batch_size = batch_size + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. 
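The `__iter__` above batches indices on the fly: each index is appended to the buffer of its group flag, and a batch is emitted whenever one buffer reaches `batch_size`, so every batch stays within one group. A simplified generator showing the same behaviour (illustrative, not part of this patch):

def group_batches(indices, flags, batch_size):
    # Yield batches whose members all share the same group flag.
    buffers = {}
    for idx in indices:
        buf = buffers.setdefault(flags[idx], [])
        buf.append(idx)
        if len(buf) == batch_size:
            yield buf[:]
            del buf[:]

flags = [0, 1, 0, 1, 0, 0]  # e.g. 0 = landscape, 1 = portrait
print(list(group_batches(range(6), flags, batch_size=2)))
# [[0, 2], [1, 3], [4, 5]]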
+ self.seed = sync_random_seed(seed) + self.shuffle = shuffle + self.size = len(dataset) + self.indices = self._indices_of_rank() + + def _infinite_indices(self): + """Infinitely yield a sequence of indices.""" + g = torch.Generator() + g.manual_seed(self.seed) + while True: + if self.shuffle: + yield from torch.randperm(self.size, generator=g).tolist() + + else: + yield from torch.arange(self.size).tolist() + + def _indices_of_rank(self): + """Slice the infinite indices by rank.""" + yield from itertools.islice(self._infinite_indices(), self.rank, None, + self.world_size) + + def __iter__(self): + # once batch size is reached, yield the indices + batch_buffer = [] + for idx in self.indices: + batch_buffer.append(idx) + if len(batch_buffer) == self.batch_size: + yield batch_buffer + batch_buffer = [] + + def __len__(self): + """Length of base dataset.""" + return self.size + + def set_epoch(self, epoch): + """Not supported in `IterationBased` runner.""" + raise NotImplementedError diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/datasets/utils.py b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/utils.py new file mode 100755 index 00000000..5220aba4 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/datasets/utils.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings + +# from dbnet_cv.cnn import VGG +from dbnet_cv.runner.hooks import HOOKS, Hook + +from dbnet_det.datasets.builder import PIPELINES +from dbnet_det.datasets.pipelines import (LoadAnnotations, LoadImageFromFile, + LoadPanopticAnnotations) +# from dbnet_det.models.dense_heads import GARPNHead, RPNHead +# from dbnet_det.models.roi_heads.mask_heads import FusedSemanticHead + + +def replace_ImageToTensor(pipelines): + """Replace the ImageToTensor transform in a data pipeline to + DefaultFormatBundle, which is normally useful in batch inference. + + Args: + pipelines (list[dict]): Data pipeline configs. + + Returns: + list: The new pipeline list with all ImageToTensor replaced by + DefaultFormatBundle. + + Examples: + >>> pipelines = [ + ... dict(type='LoadImageFromFile'), + ... dict( + ... type='MultiScaleFlipAug', + ... img_scale=(1333, 800), + ... flip=False, + ... transforms=[ + ... dict(type='Resize', keep_ratio=True), + ... dict(type='RandomFlip'), + ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]), + ... dict(type='Pad', size_divisor=32), + ... dict(type='ImageToTensor', keys=['img']), + ... dict(type='Collect', keys=['img']), + ... ]) + ... ] + >>> expected_pipelines = [ + ... dict(type='LoadImageFromFile'), + ... dict( + ... type='MultiScaleFlipAug', + ... img_scale=(1333, 800), + ... flip=False, + ... transforms=[ + ... dict(type='Resize', keep_ratio=True), + ... dict(type='RandomFlip'), + ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]), + ... dict(type='Pad', size_divisor=32), + ... dict(type='DefaultFormatBundle'), + ... dict(type='Collect', keys=['img']), + ... ]) + ... ] + >>> assert expected_pipelines == replace_ImageToTensor(pipelines) + """ + pipelines = copy.deepcopy(pipelines) + for i, pipeline in enumerate(pipelines): + if pipeline['type'] == 'MultiScaleFlipAug': + assert 'transforms' in pipeline + pipeline['transforms'] = replace_ImageToTensor( + pipeline['transforms']) + elif pipeline['type'] == 'ImageToTensor': + warnings.warn( + '"ImageToTensor" pipeline is replaced by ' + '"DefaultFormatBundle" for batch inference. 
It is ' + 'recommended to manually replace it in the test ' + 'data pipeline in your config file.', UserWarning) + pipelines[i] = {'type': 'DefaultFormatBundle'} + return pipelines + + +def get_loading_pipeline(pipeline): + """Only keep loading image and annotations related configuration. + + Args: + pipeline (list[dict]): Data pipeline configs. + + Returns: + list[dict]: The new pipeline list with only keep + loading image and annotations related configuration. + + Examples: + >>> pipelines = [ + ... dict(type='LoadImageFromFile'), + ... dict(type='LoadAnnotations', with_bbox=True), + ... dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + ... dict(type='RandomFlip', flip_ratio=0.5), + ... dict(type='Normalize', **img_norm_cfg), + ... dict(type='Pad', size_divisor=32), + ... dict(type='DefaultFormatBundle'), + ... dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) + ... ] + >>> expected_pipelines = [ + ... dict(type='LoadImageFromFile'), + ... dict(type='LoadAnnotations', with_bbox=True) + ... ] + >>> assert expected_pipelines ==\ + ... get_loading_pipeline(pipelines) + """ + loading_pipeline_cfg = [] + for cfg in pipeline: + obj_cls = PIPELINES.get(cfg['type']) + # TODO:use more elegant way to distinguish loading modules + if obj_cls is not None and obj_cls in (LoadImageFromFile, + LoadAnnotations, + LoadPanopticAnnotations): + loading_pipeline_cfg.append(cfg) + assert len(loading_pipeline_cfg) == 2, \ + 'The data pipeline in your config file must include ' \ + 'loading image and annotations related pipeline.' + return loading_pipeline_cfg + + +@HOOKS.register_module() +class NumClassCheckHook(Hook): + + def _check_head(self, runner): + """Check whether the `num_classes` in head matches the length of + `CLASSES` in `dataset`. + + Args: + runner (obj:`EpochBasedRunner`): Epoch based Runner. + """ + model = runner.model + dataset = runner.data_loader.dataset + if dataset.CLASSES is None: + runner.logger.warning( + f'Please set `CLASSES` ' + f'in the {dataset.__class__.__name__} and' + f'check if it is consistent with the `num_classes` ' + f'of head') + else: + assert type(dataset.CLASSES) is not str, \ + (f'`CLASSES` in {dataset.__class__.__name__}' + f'should be a tuple of str.' + f'Add comma if number of classes is 1 as ' + f'CLASSES = ({dataset.CLASSES},)') + for name, module in model.named_modules(): + if hasattr(module, 'num_classes') and not isinstance( + module, (RPNHead, FusedSemanticHead, GARPNHead)): + assert module.num_classes == len(dataset.CLASSES), \ + (f'The `num_classes` ({module.num_classes}) in ' + f'{module.__class__.__name__} of ' + f'{model.__class__.__name__} does not matches ' + f'the length of `CLASSES` ' + f'{len(dataset.CLASSES)}) in ' + f'{dataset.__class__.__name__}') + + def before_train_epoch(self, runner): + """Check whether the training dataset is compatible with head. + + Args: + runner (obj:`EpochBasedRunner`): Epoch based Runner. + """ + self._check_head(runner) + + def before_val_epoch(self, runner): + """Check whether the dataset in val epoch is compatible with head. + + Args: + runner (obj:`EpochBasedRunner`): Epoch based Runner. + """ + self._check_head(runner) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/__init__.py new file mode 100755 index 00000000..d16681ec --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
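One concrete illustration of the `NumClassCheckHook` assertion above: a single-class dataset must define `CLASSES` as a one-element tuple rather than a bare string, otherwise `len(CLASSES)` counts characters instead of classes (illustrative snippet, not part of this patch):

CLASSES_WRONG = 'text'     # len(...) == 4, misread as four classes
CLASSES_RIGHT = ('text',)  # len(...) == 1
print(len(CLASSES_WRONG), len(CLASSES_RIGHT))  # 4 1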
+from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS, + ROI_EXTRACTORS, SHARED_HEADS, build_backbone, + build_detector, build_head, build_loss, build_neck, + build_roi_extractor, build_shared_head) +# from .dense_heads import * # noqa: F401,F403 +from .detectors import * # noqa: F401,F403 +# from .losses import * # noqa: F401,F403 +# from .necks import * # noqa: F401,F403 +# from .plugins import * # noqa: F401,F403 +# from .roi_heads import * # noqa: F401,F403 +# from .seg_heads import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES', + 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor', + 'build_shared_head', 'build_head', 'build_loss', 'build_detector' +] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/__init__.py new file mode 100755 index 00000000..bb083fc7 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# from .csp_darknet import CSPDarknet +# from .darknet import Darknet +# from .detectors_resnet import DetectoRS_ResNet +# from .detectors_resnext import DetectoRS_ResNeXt +# from .efficientnet import EfficientNet +# from .hourglass import HourglassNet +# from .hrnet import HRNet +from .mobilenet_v3 import MobileNetV3 +# from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2 +# from .regnet import RegNet +# from .res2net import Res2Net +# from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1d +# from .resnext import ResNeXt +# from .ssd_vgg import SSDVGG +# from .swin import SwinTransformer +# from .trident_resnet import TridentResNet + +# __all__ = [ +# 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', +# 'MobileNetV2', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet', +# 'DetectoRS_ResNeXt', 'Darknet', 'ResNeSt', 'TridentResNet', 'CSPDarknet', +# 'SwinTransformer', 'PyramidVisionTransformer', +# 'PyramidVisionTransformerV2', 'EfficientNet' +# ] +__all__ = [ + 'ResNet', + 'MobileNetV3' +] \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/mobilenet_v3.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/mobilenet_v3.py new file mode 100755 index 00000000..e6b73a39 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/mobilenet_v3.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import dbnet_cv +from dbnet_cv.cnn import ConvModule +from dbnet_cv.cnn.bricks import Conv2dAdaptivePadding +from dbnet_cv.runner import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from ..utils.inverted_residual import InvertedResidual + + +@BACKBONES.register_module() +class MobileNetV3(BaseModule): + """MobileNetV3 backbone. + This backbone is the improved implementation of `Searching for MobileNetV3 + `_. + Args: + arch (str): Architecture of mobilnetv3, from {'small', 'large'}. + Default: 'small'. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (tuple[int]): Output from which layer. + Default: (0, 1, 12). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. 
+ norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 + [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 + [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(0, 1, 8, 12), + frozen_stages=-1, + reduction_factor=1, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super(MobileNetV3, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert arch in self.arch_settings + assert isinstance(reduction_factor, int) and reduction_factor > 0 + assert dbnet_cv.is_tuple_of(out_indices, int) + for index in out_indices: + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])+2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])+2}). 
' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.reduction_factor = reduction_factor + self.norm_eval = norm_eval + self.with_cp = with_cp + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + + # build the first layer (layer0) + in_channels = 16 + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + + if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ + i >= 8: + mid_channels = mid_channels // self.reduction_factor + out_channels = out_channels // self.reduction_factor + + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=(in_channels != mid_channels), + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = 'layer{}'.format(i + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # build the last layer + # block5 layer12 os=32 for small model + # block6 layer16 os=32 for large model + layer = ConvModule( + in_channels=in_channels, + out_channels=576 if self.arch == 'small' else 960, + kernel_size=1, + stride=1, + dilation=4, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = 'layer{}'.format(len(layer_setting) + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # next, convert backbone MobileNetV3 to a semantic segmentation version + if self.arch == 'small': + self.layer4.depthwise_conv.conv.stride = (1, 1) + self.layer9.depthwise_conv.conv.stride = (1, 1) + for i in range(4, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 9: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + else: + self.layer7.depthwise_conv.conv.stride = (1, 1) + self.layer13.depthwise_conv.conv.stride = (1, 1) + for i in range(7, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 13: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, 
layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + return outs + + def _freeze_stages(self): + for i in range(self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV3, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/resnet.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/resnet.py new file mode 100755 index 00000000..6406cd08 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/backbones/resnet.py @@ -0,0 +1,672 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from dbnet_cv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from dbnet_cv.runner import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from ..utils import ResLayer + + +class BasicBlock(BaseModule): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. 
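+        For example (illustrative): with ``stride=2``, ``style='pytorch'``
+        yields ``conv1_stride=1, conv2_stride=2`` (the 3x3 conv downsamples),
+        while ``style='caffe'`` yields ``conv1_stride=2, conv2_stride=1``
+        (the first 1x1 conv downsamples), as assigned a few lines below.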
+ """ + super(Bottleneck, self).__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. 
+ """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + out = x + for name in plugin_names: + out = getattr(self, name)(out) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Example: + >>> from dbnet_det.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=None, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + self.zero_init_residual = zero_init_residual + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + 
stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + init_cfg=block_init_cfg) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """Make plugins for ResNet ``stage_idx`` th stage. + + Currently we support to insert ``context_block``, + ``empirical_attention_block``, ``nonlocal_block`` into the backbone + like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be: + + Examples: + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose ``stage_idx=0``, the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->conv3->yyy->zzz1->zzz2 + + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. 
+ stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@BACKBONES.register_module() +class ResNetV1d(ResNet): + r"""ResNetV1d variant described in `Bag of Tricks + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. 
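+
+    Example (an illustrative sketch; the direct import path below is an
+    assumption, mirroring the ``ResNet`` example above):
+
+        >>> from dbnet_det.models.backbones.resnet import ResNetV1d
+        >>> import torch
+        >>> self = ResNetV1d(depth=50)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)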
+ """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/builder.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/builder.py new file mode 100755 index 00000000..924493a0 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/builder.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from dbnet_cv.cnn import MODELS as DBNET_CV_MODELS +from dbnet_cv.utils import Registry + +MODELS = Registry('models', parent=DBNET_CV_MODELS) + +BACKBONES = MODELS +NECKS = MODELS +ROI_EXTRACTORS = MODELS +SHARED_HEADS = MODELS +HEADS = MODELS +LOSSES = MODELS +DETECTORS = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_roi_extractor(cfg): + """Build roi extractor.""" + return ROI_EXTRACTORS.build(cfg) + + +def build_shared_head(cfg): + """Build shared head.""" + return SHARED_HEADS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_detector(cfg, train_cfg=None, test_cfg=None): + """Build detector.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return DETECTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/__init__.py new file mode 100755 index 00000000..9ad675e9 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/__init__.py @@ -0,0 +1,2 @@ +from .single_stage import SingleStageDetector +from .base import BaseDetector \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/base.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/base.py new file mode 100755 index 00000000..fa74c58e --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/base.py @@ -0,0 +1,360 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import dbnet_cv +import numpy as np +import torch +import torch.distributed as dist +from dbnet_cv.runner import BaseModule, auto_fp16 + +from dbnet_det.core.visualization import imshow_det_bboxes + + +class BaseDetector(BaseModule, metaclass=ABCMeta): + """Base class for detectors.""" + + def __init__(self, init_cfg=None): + super(BaseDetector, self).__init__(init_cfg) + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the detector has a neck""" + return hasattr(self, 'neck') and self.neck is not None + + # TODO: these properties need to be carefully handled + # for both single stage & two stage detectors + @property + def with_shared_head(self): + """bool: whether the detector has a shared head in the RoI Head""" + return hasattr(self, 'roi_head') and self.roi_head.with_shared_head + + @property + def with_bbox(self): + """bool: whether the detector has a bbox head""" + return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox) + or (hasattr(self, 'bbox_head') and self.bbox_head is not None)) + + @property + def with_mask(self): + """bool: whether the detector has a mask head""" + return ((hasattr(self, 'roi_head') and self.roi_head.with_mask) + or (hasattr(self, 'mask_head') and self.mask_head is not None)) + + @abstractmethod + def extract_feat(self, imgs): + """Extract features from images.""" + pass + + def extract_feats(self, imgs): + """Extract features from multiple images. + + Args: + imgs (list[torch.Tensor]): A list of images. The images are + augmented from the same image but in different ways. + + Returns: + list[torch.Tensor]: Features of different images + """ + assert isinstance(imgs, list) + return [self.extract_feat(img) for img in imgs] + + def forward_train(self, imgs, img_metas, **kwargs): + """ + Args: + img (Tensor): of shape (N, C, H, W) encoding input images. + Typically these should be mean centered and std scaled. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys, see + :class:`dbnet_det.datasets.pipelines.Collect`. + kwargs (keyword arguments): Specific to concrete implementation. + """ + # NOTE the batched image size information may be useful, e.g. + # in DETR, this is needed for the construction of masks, which is + # then used for the transformer_head. 
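+        # e.g. (illustrative): for a padded batch ``imgs`` of shape
+        # (2, 3, 640, 640), ``batch_input_shape`` below is (640, 640) and is
+        # written into every ``img_meta``.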
+ batch_input_shape = tuple(imgs[0].size()[-2:]) + for img_meta in img_metas: + img_meta['batch_input_shape'] = batch_input_shape + + async def async_simple_test(self, img, img_metas, **kwargs): + raise NotImplementedError + + @abstractmethod + def simple_test(self, img, img_metas, **kwargs): + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Test function with test time augmentation.""" + pass + + async def aforward_test(self, *, img, img_metas, **kwargs): + for var, name in [(img, 'img'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError(f'{name} must be a list, but got {type(var)}') + + num_augs = len(img) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(img)}) ' + f'!= num of image metas ({len(img_metas)})') + # TODO: remove the restriction of samples_per_gpu == 1 when prepared + samples_per_gpu = img[0].size(0) + assert samples_per_gpu == 1 + + if num_augs == 1: + return await self.async_simple_test(img[0], img_metas[0], **kwargs) + else: + raise NotImplementedError + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError(f'{name} must be a list, but got {type(var)}') + + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(imgs)}) ' + f'!= num of image meta ({len(img_metas)})') + + # NOTE the batched image size information may be useful, e.g. + # in DETR, this is needed for the construction of masks, which is + # then used for the transformer_head. + for img, img_meta in zip(imgs, img_metas): + batch_size = len(img_meta) + for img_id in range(batch_size): + img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:]) + + if num_augs == 1: + # proposals (List[List[Tensor]]): the outer list indicates + # test-time augs (multiscale, flip, etc.) and the inner list + # indicates images in a batch. + # The Tensor should have a shape Px4, where P is the number of + # proposals. + if 'proposals' in kwargs: + kwargs['proposals'] = kwargs['proposals'][0] + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + assert imgs[0].size(0) == 1, 'aug test does not support ' \ + 'inference with batch size ' \ + f'{imgs[0].size(0)}' + # TODO: support test augmentation for predefined proposals + assert 'proposals' not in kwargs + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=('img', )) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. 
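+
+        For example (illustrative): a training call looks roughly like
+        ``model(img=Tensor[N, C, H, W], img_metas=[dict, ...],
+        return_loss=True)``, whereas a test call with a single augmentation
+        looks like ``model(img=[Tensor[N, C, H, W]], img_metas=[[dict, ...]],
+        return_loss=False)``, one outer list element per test-time
+        augmentation.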
+ """ + if torch.onnx.is_in_onnx_export(): + assert len(img_metas) == 1 + return self.onnx_export(img[0], img_metas[0]) + + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def _parse_losses(self, losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \ + which may be a weighted sum of all losses, log_vars contains \ + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + # If the loss_vars has different length, GPUs will wait infinitely + if dist.is_available() and dist.is_initialized(): + log_var_length = torch.tensor(len(log_vars), device=loss.device) + dist.all_reduce(log_var_length) + message = (f'rank {dist.get_rank()}' + + f' len(log_vars): {len(log_vars)}' + ' keys: ' + + ','.join(log_vars.keys())) + assert log_var_length == len(log_vars) * dist.get_world_size(), \ + 'loss log variables are different across GPUs!\n' + message + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data, optimizer): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \ + ``num_samples``. + + - ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + - ``log_vars`` contains all the variables to be sent to the + logger. + - ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def val_step(self, data, optimizer=None): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. 
+ """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def show_result(self, + img, + result, + score_thr=0.3, + bbox_color=(72, 101, 241), + text_color=(72, 101, 241), + mask_color=None, + thickness=2, + font_size=13, + win_name='', + show=False, + wait_time=0, + out_file=None): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + result (Tensor or tuple): The results to draw over `img` + bbox_result or (bbox_result, segm_result). + score_thr (float, optional): Minimum score of bboxes to be shown. + Default: 0.3. + bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines. + The tuple of color should be in BGR order. Default: 'green' + text_color (str or tuple(int) or :obj:`Color`):Color of texts. + The tuple of color should be in BGR order. Default: 'green' + mask_color (None or str or tuple(int) or :obj:`Color`): + Color of masks. The tuple of color should be in BGR order. + Default: None + thickness (int): Thickness of lines. Default: 2 + font_size (int): Font size of texts. Default: 13 + win_name (str): The window name. Default: '' + wait_time (float): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The filename to write the image. + Default: None. + + Returns: + img (Tensor): Only if not `show` or `out_file` + """ + img = dbnet_cv.imread(img) + img = img.copy() + if isinstance(result, tuple): + bbox_result, segm_result = result + if isinstance(segm_result, tuple): + segm_result = segm_result[0] # ms rcnn + else: + bbox_result, segm_result = result, None + bboxes = np.vstack(bbox_result) + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) + # draw segmentation masks + segms = None + if segm_result is not None and len(labels) > 0: # non empty + segms = dbnet_cv.concat_list(segm_result) + if isinstance(segms[0], torch.Tensor): + segms = torch.stack(segms, dim=0).detach().cpu().numpy() + else: + segms = np.stack(segms, axis=0) + # if out_file specified, do not show image in window + if out_file is not None: + show = False + # draw bounding boxes + img = imshow_det_bboxes( + img, + bboxes, + labels, + segms, + class_names=self.CLASSES, + score_thr=score_thr, + bbox_color=bbox_color, + text_color=text_color, + mask_color=mask_color, + thickness=thickness, + font_size=font_size, + win_name=win_name, + show=show, + wait_time=wait_time, + out_file=out_file) + + if not (show or out_file): + return img + + def onnx_export(self, img, img_metas): + raise NotImplementedError(f'{self.__class__.__name__} does ' + f'not support ONNX EXPORT') diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/single_stage.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/single_stage.py new file mode 100755 index 00000000..ac9cfd78 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/detectors/single_stage.py @@ -0,0 +1,171 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch + +from dbnet_det.core import bbox2result +from ..builder import DETECTORS, build_backbone, build_head, build_neck +from .base import BaseDetector + + +@DETECTORS.register_module() +class SingleStageDetector(BaseDetector): + """Base class for single-stage detectors. 
+ + Single-stage detectors directly and densely predict bounding boxes on the + output features of the backbone+neck. + """ + + def __init__(self, + backbone, + neck=None, + bbox_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(SingleStageDetector, self).__init__(init_cfg) + if pretrained: + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + backbone.pretrained = pretrained + self.backbone = build_backbone(backbone) + if neck is not None: + self.neck = build_neck(neck) + bbox_head.update(train_cfg=train_cfg) + bbox_head.update(test_cfg=test_cfg) + self.bbox_head = build_head(bbox_head) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def extract_feat(self, img): + """Directly extract features from the backbone+neck.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def forward_dummy(self, img): + """Used for computing network flops. + + See `dbnet_detection/tools/analysis_tools/get_flops.py` + """ + x = self.extract_feat(img) + outs = self.bbox_head(x) + return outs + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`dbnet_det.datasets.pipelines.Collect`. + gt_bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): Class indices corresponding to each box + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + super(SingleStageDetector, self).forward_train(img, img_metas) + x = self.extract_feat(img) + losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, + gt_labels, gt_bboxes_ignore) + return losses + + def simple_test(self, img, img_metas, rescale=False): + """Test function without test-time augmentation. + + Args: + img (torch.Tensor): Images with shape (N, C, H, W). + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + feat = self.extract_feat(img) + results_list = self.bbox_head.simple_test( + feat, img_metas, rescale=rescale) + bbox_results = [ + bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) + for det_bboxes, det_labels in results_list + ] + return bbox_results + + def aug_test(self, imgs, img_metas, rescale=False): + """Test function with test time augmentation. + + Args: + imgs (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. 
+ + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + assert hasattr(self.bbox_head, 'aug_test'), \ + f'{self.bbox_head.__class__.__name__}' \ + ' does not support test-time augmentation' + + feats = self.extract_feats(imgs) + results_list = self.bbox_head.aug_test( + feats, img_metas, rescale=rescale) + bbox_results = [ + bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) + for det_bboxes, det_labels in results_list + ] + return bbox_results + + def onnx_export(self, img, img_metas, with_nms=True): + """Test function without test time augmentation. + + Args: + img (torch.Tensor): input images. + img_metas (list[dict]): List of image information. + + Returns: + tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] + and class labels of shape [N, num_det]. + """ + x = self.extract_feat(img) + outs = self.bbox_head(x) + # get origin input shape to support onnx dynamic shape + + # get shape as tensor + img_shape = torch._shape_as_tensor(img)[2:] + img_metas[0]['img_shape_for_onnx'] = img_shape + # get pad input shape to support onnx dynamic shape for exporting + # `CornerNet` and `CentripetalNet`, which 'pad_shape' is used + # for inference + img_metas[0]['pad_shape_for_onnx'] = img_shape + + if len(outs) == 2: + # add dummy score_factor + outs = (*outs, None) + # TODO Can we change to `get_bboxes` when `onnx_export` fail + det_bboxes, det_labels = self.bbox_head.onnx_export( + *outs, img_metas, with_nms=with_nms) + + return det_bboxes, det_labels diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/__init__.py new file mode 100755 index 00000000..d154cdee --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
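+# ---------------------------------------------------------------------------
+# Illustrative note (not part of the original file): for the
+# ``SingleStageDetector`` defined above, ``simple_test``/``aug_test`` return,
+# per image, a list with one ``np.ndarray`` of shape (k, 5) per class, each
+# row being ``[x1, y1, x2, y2, score]``. A minimal sketch of reading it,
+# assuming ``results = detector.simple_test(img, img_metas)``:
+#
+#     per_image = results[0]      # list, one entry per class
+#     cls0_dets = per_image[0]    # (k, 5) array for class 0
+#     boxes, scores = cls0_dets[:, :4], cls0_dets[:, 4]
+# ---------------------------------------------------------------------------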
+# from .brick_wrappers import AdaptiveAvgPool2d, adaptive_avg_pool2d +# from .builder import build_linear_layer, build_transformer +# from .ckpt_convert import pvt_convert +# from .conv_upsample import ConvUpsample +# from .csp_layer import CSPLayer +# from .gaussian_target import gaussian_radius, gen_gaussian_target +from .inverted_residual import InvertedResidual +from .make_divisible import make_divisible +# from .misc import interpolate_as, sigmoid_geometric_mean +# from .normed_predictor import NormedConv2d, NormedLinear +# from .panoptic_gt_processing import preprocess_panoptic_gt +# from .point_sample import (get_uncertain_point_coords_with_randomness, +# get_uncertainty) +# from .positional_encoding import (LearnedPositionalEncoding, +# SinePositionalEncoding) +from .res_layer import ResLayer, SimplifiedBasicBlock +from .se_layer import SELayer +# from .transformer import (DetrTransformerDecoder, DetrTransformerDecoderLayer, +# DynamicConv, PatchEmbed, Transformer, nchw_to_nlc, +# nlc_to_nchw) + +# __all__ = [ +# 'ResLayer', 'gaussian_radius', 'gen_gaussian_target', +# 'DetrTransformerDecoderLayer', 'DetrTransformerDecoder', 'Transformer', +# 'build_transformer', 'build_linear_layer', 'SinePositionalEncoding', +# 'LearnedPositionalEncoding', 'DynamicConv', 'SimplifiedBasicBlock', +# 'NormedLinear', 'NormedConv2d', 'make_divisible', 'InvertedResidual', +# 'SELayer', 'interpolate_as', 'ConvUpsample', 'CSPLayer', +# 'adaptive_avg_pool2d', 'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc', +# 'nlc_to_nchw', 'pvt_convert', 'sigmoid_geometric_mean', +# 'preprocess_panoptic_gt', 'DyReLU', +# 'get_uncertain_point_coords_with_randomness', 'get_uncertainty' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/inverted_residual.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/inverted_residual.py new file mode 100755 index 00000000..c6a2e642 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/inverted_residual.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from dbnet_cv.cnn import ConvModule +from dbnet_cv.cnn.bricks import DropPath +from dbnet_cv.runner import BaseModule + +from .se_layer import SELayer + + +class InvertedResidual(BaseModule): + """Inverted Residual Block. + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. + Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + Returns: + Tensor: The output tensor. 
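+
+    Example (an illustrative sketch)::
+
+        >>> import torch
+        >>> block = InvertedResidual(16, 16, mid_channels=64)
+        >>> block.with_res_shortcut  # stride == 1 and channels match
+        True
+        >>> block(torch.rand(1, 16, 32, 32)).shape
+        torch.Size([1, 16, 32, 32])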
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/make_divisible.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/make_divisible.py new file mode 100755 index 00000000..5bbc0a6c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/make_divisible.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + This function rounds the channel number to the nearest value that can be + divisible by the divisor. It is taken from the original tf repo. It ensures + that all layers have a channel number that is divisible by divisor. It can + be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel number to + the original channel number. Default: 0.9. + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). 
+ if new_value < min_ratio * value: + new_value += divisor + return new_value \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/res_layer.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/res_layer.py new file mode 100755 index 00000000..715bd398 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/res_layer.py @@ -0,0 +1,190 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dbnet_cv.cnn import build_conv_layer, build_norm_layer +from dbnet_cv.runner import BaseModule, Sequential +from torch import nn as nn + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Default: True + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + downsample_first=True, + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if downsample_first: + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + inplanes = planes * block.expansion + for _ in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + + else: # downsample_first=False is for HourglassModule + for _ in range(num_blocks - 1): + layers.append( + block( + inplanes=inplanes, + planes=inplanes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + super(ResLayer, self).__init__(*layers) + + +class SimplifiedBasicBlock(BaseModule): + """Simplified version of original basic residual block. This is used in + `SCNet `_. + + - Norm layer is now optional + - Last ReLU in forward function is removed + """ + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_fg=None): + super(SimplifiedBasicBlock, self).__init__(init_fg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert not with_cp, 'Not implemented yet.' 
+ self.with_norm = norm_cfg is not None + with_bias = True if norm_cfg is None else False + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=with_bias) + if self.with_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, planes, postfix=1) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=with_bias) + if self.with_norm: + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, planes, postfix=2) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) if self.with_norm else None + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) if self.with_norm else None + + def forward(self, x): + """Forward function.""" + + identity = x + + out = self.conv1(x) + if self.with_norm: + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + if self.with_norm: + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/se_layer.py b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/se_layer.py new file mode 100755 index 00000000..6aea608a --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/models/utils/se_layer.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dbnet_cv +import torch +import torch.nn as nn +from dbnet_cv.cnn import ConvModule +from dbnet_cv.runner import BaseModule + + +class SELayer(BaseModule): + """Squeeze-and-Excitation Module. + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Default: 16. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Default: (dict(type='ReLU'), dict(type='Sigmoid')) + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + init_cfg=None): + super(SELayer, self).__init__(init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert dbnet_cv.is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out + + +class DyReLU(BaseModule): + """Dynamic ReLU (DyReLU) module. + See `Dynamic ReLU `_ for details. + Current implementation is specialized for task-aware attention in DyHead. + HSigmoid arguments in default act_cfg follow DyHead official code. + https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py + Args: + channels (int): The input (and output) channels of DyReLU module. + ratio (int): Squeeze ratio in Squeeze-and-Excitation-like module, + the intermediate channel will be ``int(channels/ratio)``. + Default: 4. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, + divisor=6.0)) + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + ratio=4, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0)), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert dbnet_cv.is_tuple_of(act_cfg, dict) + self.channels = channels + self.expansion = 4 # for a1, b1, a2, b2 + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels * self.expansion, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + """Forward function.""" + coeffs = self.global_avgpool(x) + coeffs = self.conv1(coeffs) + coeffs = self.conv2(coeffs) - 0.5 # value range: [-0.5, 0.5] + a1, b1, a2, b2 = torch.split(coeffs, self.channels, dim=1) + a1 = a1 * 2.0 + 1.0 # [-1.0, 1.0] + 1.0 + a2 = a2 * 2.0 # [-1.0, 1.0] + out = torch.max(x * a1 + b1, x * a2 + b2) + return out \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/utils/__init__.py b/cv/ocr/dbnet/pytorch/dbnet_det/utils/__init__.py new file mode 100755 index 00000000..d0c3a821 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/utils/__init__.py @@ -0,0 +1,18 @@ +# # Copyright (c) OpenMMLab. All rights reserved. 
+# from .collect_env import collect_env +# from .compat_config import compat_cfg +from .logger import get_caller_name, get_root_logger, log_img_scale +# from .memory import AvoidCUDAOOM, AvoidOOM +# from .misc import find_latest_checkpoint, update_data_root +# from .replace_cfg_vals import replace_cfg_vals +# from .setup_env import setup_multi_processes +# from .split_batch import split_batch +from .util_distribution import build_ddp, build_dp, get_device + + +# __all__ = [ +# 'get_root_logger', 'collect_env', 'find_latest_checkpoint', +# 'update_data_root', 'setup_multi_processes', 'get_caller_name', +# 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp', +# 'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM' +# ] diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/utils/inverted_residual.py b/cv/ocr/dbnet/pytorch/dbnet_det/utils/inverted_residual.py new file mode 100755 index 00000000..4accfce1 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/utils/inverted_residual.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dbnet_cv.cnn import ConvModule +from torch import nn +from torch.utils import checkpoint as cp + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + **kwargs): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' 
+ self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InvertedResidualV3(nn.Module): + """Inverted Residual Block for MobileNetV3. + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + Returns: + Tensor: The output tensor. 
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super(InvertedResidualV3, self).__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=dict( + type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out \ No newline at end of file diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/utils/logger.py b/cv/ocr/dbnet/pytorch/dbnet_det/utils/logger.py new file mode 100755 index 00000000..29f57e7c --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/utils/logger.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import logging + +from dbnet_cv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get root logger. + + Args: + log_file (str, optional): File path of log. Defaults to None. + log_level (int, optional): The level of logger. + Defaults to logging.INFO. + + Returns: + :obj:`logging.Logger`: The obtained logger + """ + logger = get_logger(name='dbnet_det', log_file=log_file, log_level=log_level) + + return logger + + +def get_caller_name(): + """Get name of caller method.""" + # this_func_frame = inspect.stack()[0][0] # i.e., get_caller_name + # callee_frame = inspect.stack()[1][0] # e.g., log_img_scale + caller_frame = inspect.stack()[2][0] # e.g., caller of log_img_scale + caller_method = caller_frame.f_code.co_name + try: + caller_class = caller_frame.f_locals['self'].__class__.__name__ + return f'{caller_class}.{caller_method}' + except KeyError: # caller is a function + return caller_method + + +def log_img_scale(img_scale, shape_order='hw', skip_square=False): + """Log image size. + + Args: + img_scale (tuple): Image size to be logged. + shape_order (str, optional): The order of image shape. + 'hw' for (height, width) and 'wh' for (width, height). + Defaults to 'hw'. + skip_square (bool, optional): Whether to skip logging for square + img_scale. Defaults to False. + + Returns: + bool: Whether to have done logging. 
+ """ + if shape_order == 'hw': + height, width = img_scale + elif shape_order == 'wh': + width, height = img_scale + else: + raise ValueError(f'Invalid shape_order {shape_order}.') + + if skip_square and (height == width): + return False + + logger = get_root_logger() + caller = get_caller_name() + logger.info(f'image shape: height={height}, width={width} in {caller}') + + return True diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/utils/profiling.py b/cv/ocr/dbnet/pytorch/dbnet_det/utils/profiling.py new file mode 100755 index 00000000..2f53f456 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/utils/profiling.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import contextlib +import sys +import time + +import torch + +if sys.version_info >= (3, 7): + + @contextlib.contextmanager + def profile_time(trace_name, + name, + enabled=True, + stream=None, + end_stream=None): + """Print time spent by CPU and GPU. + + Useful as a temporary context manager to find sweet spots of code + suitable for async implementation. + """ + if (not enabled) or not torch.cuda.is_available(): + yield + return + stream = stream if stream else torch.cuda.current_stream() + end_stream = end_stream if end_stream else stream + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + stream.record_event(start) + try: + cpu_start = time.monotonic() + yield + finally: + cpu_end = time.monotonic() + end_stream.record_event(end) + end.synchronize() + cpu_time = (cpu_end - cpu_start) * 1000 + gpu_time = start.elapsed_time(end) + msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms ' + msg += f'gpu_time {gpu_time:.2f} ms stream {stream}' + print(msg, end_stream) diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/utils/util_distribution.py b/cv/ocr/dbnet/pytorch/dbnet_det/utils/util_distribution.py new file mode 100755 index 00000000..a1c00dc8 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/utils/util_distribution.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from dbnet_cv.parallel import MMDataParallel, MMDistributedDataParallel + +dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel} + +ddp_factory = {'cuda': MMDistributedDataParallel} + + +def build_dp(model, device='cuda', dim=0, *args, **kwargs): + """build DataParallel module by device type. + + if device is cuda, return a MMDataParallel model; if device is mlu, + return a MLUDataParallel model. + + Args: + model (:class:`nn.Module`): model to be parallelized. + device (str): device type, cuda, cpu or mlu. Defaults to cuda. + dim (int): Dimension used to scatter the data. Defaults to 0. + + Returns: + nn.Module: the model to be parallelized. + """ + if device == 'cuda': + model = model.cuda() + elif device == 'mlu': + from dbnet_cv.device.mlu import MLUDataParallel + dp_factory['mlu'] = MLUDataParallel + model = model.mlu() + + return dp_factory[device](model, dim=dim, *args, **kwargs) + + +def build_ddp(model, device='cuda', *args, **kwargs): + """Build DistributedDataParallel module by device type. + + If device is cuda, return a MMDistributedDataParallel model; + if device is mlu, return a MLUDistributedDataParallel model. + + Args: + model (:class:`nn.Module`): module to be parallelized. + device (str): device type, mlu or cuda. + + Returns: + :class:`nn.Module`: the module to be parallelized + + References: + .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. 
+ DistributedDataParallel.html + """ + assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.' + if device == 'cuda': + model = model.cuda() + elif device == 'mlu': + from dbnet_cv.device.mlu import MLUDistributedDataParallel + ddp_factory['mlu'] = MLUDistributedDataParallel + model = model.mlu() + + return ddp_factory[device](model, *args, **kwargs) + + +def is_mlu_available(): + """Returns a bool indicating if MLU is currently available.""" + return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() + + +def get_device(): + """Returns an available device, cpu, cuda or mlu.""" + is_device_available = { + 'cuda': torch.cuda.is_available(), + 'mlu': is_mlu_available() + } + device_list = [k for k, v in is_device_available.items() if v] + return device_list[0] if len(device_list) == 1 else 'cpu' diff --git a/cv/ocr/dbnet/pytorch/dbnet_det/version.py b/cv/ocr/dbnet/pytorch/dbnet_det/version.py new file mode 100755 index 00000000..56e9b075 --- /dev/null +++ b/cv/ocr/dbnet/pytorch/dbnet_det/version.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +__version__ = '2.25.0' +short_version = __version__ + + +def parse_version_info(version_str): + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) -- Gitee
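For reference, a minimal usage sketch of the helpers in util_distribution.py and version.py above (illustrative only, not part of the patched files; it assumes the dbnet_det package added by this patch is importable, and the Conv2d model is only a stand-in for a detector built elsewhere):

import torch.nn as nn

from dbnet_det.utils import build_dp, get_device
from dbnet_det.version import parse_version_info, version_info

# parse_version_info() turns the version string into a comparable tuple.
assert parse_version_info('2.25.0') == (2, 25, 0)
assert parse_version_info('2.25.0rc1') == (2, 25, 0, 'rc1')
assert version_info == (2, 25, 0)

model = nn.Conv2d(3, 16, 3)   # placeholder model for illustration
device = get_device()         # 'cuda' or 'mlu' when exactly one is available, else 'cpu'
if device != 'cpu':
    # Wraps the model in the DataParallel class matching the device
    # (MMDataParallel for cuda, MLUDataParallel for mlu).
    model = build_dp(model, device=device, dim=0)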