diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/internimage_det.patch b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/internimage_det.patch
new file mode 100755
index 0000000000000000000000000000000000000000..c4afc2dac3b6da09039bff79251a38e04ce9cc74
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/internimage_det.patch
@@ -0,0 +1,229 @@
+diff --git a/detection/configs/_base_/datasets/coco_instance.py b/detection/configs/_base_/datasets/coco_instance.py
+index 91461aa..d667dee 100644
+--- a/detection/configs/_base_/datasets/coco_instance.py
++++ b/detection/configs/_base_/datasets/coco_instance.py
+@@ -17,15 +17,15 @@ test_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(
+         type='MultiScaleFlipAug',
+-        img_scale=(1333, 800),
+-        flip=False,
++        scales=(1333, 800),  # adapted for mmdet v3
++        allow_flip=False,
+         transforms=[
+-            dict(type='Resize', keep_ratio=True),
+-            dict(type='RandomFlip'),
++            dict(type='Resize', keep_ratio=True, scale=(1333, 800)),
+             dict(type='Normalize', **img_norm_cfg),
+             dict(type='Pad', size_divisor=32),
+-            dict(type='ImageToTensor', keys=['img']),
+-            dict(type='Collect', keys=['img']),
++            dict(type='mmdet.PackDetInputs',  # adapted for mmdet v3
++                 meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
++                            'scale_factor'))
+         ])
+ ]
+ data = dict(
+@@ -43,7 +43,20 @@ data = dict(
+         pipeline=test_pipeline),
+     test=dict(
+         type=dataset_type,
+-        ann_file=data_root + 'annotations/instances_val2017.json',
+-        img_prefix=data_root + 'val2017/',
++        data_root=data_root,
++        ann_file='annotations/instances_val2017.json',
++        data_prefix=dict(img='val2017/',),
+         pipeline=test_pipeline))
+ evaluation = dict(metric=['bbox', 'segm'], classwise=True)
++test_dataloader = dict(  # adapted for mmdet v3
++    dataset=dict(
++        type=dataset_type,
++        ann_file='annotations/instances_val2017.json',
++        data_prefix=dict(img='val2017/',),
++        pipeline=test_pipeline
++    ),
++    batch_size=1,
++    num_workers=1,
++    sampler=dict(type='DefaultSampler', shuffle=False),
++    drop_last=False,
++)
+diff --git a/detection/configs/coco/cascade_internimage_xl_fpn_3x_coco.py b/detection/configs/coco/cascade_internimage_xl_fpn_3x_coco.py
+index 37dd4c3..f519591 100644
+--- a/detection/configs/coco/cascade_internimage_xl_fpn_3x_coco.py
++++ b/detection/configs/coco/cascade_internimage_xl_fpn_3x_coco.py
+@@ -145,7 +145,6 @@ optimizer = dict(
+                        depths=[5, 5, 24, 5], offset_lr_scale=0.01))
+ optimizer_config = dict(grad_clip=None)
+ # fp16 = dict(loss_scale=dict(init_scale=512))
+-evaluation = dict(save_best='auto')
+ checkpoint_config = dict(
+     interval=1,
+     max_keep_ckpts=3,
+diff --git a/detection/mmcv_custom/__init__.py b/detection/mmcv_custom/__init__.py
+index b066460..5c5235f 100644
+--- a/detection/mmcv_custom/__init__.py
++++ b/detection/mmcv_custom/__init__.py
+@@ -5,7 +5,6 @@
+ # --------------------------------------------------------
+ 
+ # -*- coding: utf-8 -*-
+-from .custom_layer_decay_optimizer_constructor import \
+-    CustomLayerDecayOptimizerConstructor
+ 
+-__all__ = ['CustomLayerDecayOptimizerConstructor']
++
++
+diff --git a/detection/mmdet_custom/__init__.py b/detection/mmdet_custom/__init__.py
+index eebe73e..56bfa61 100644
+--- a/detection/mmdet_custom/__init__.py
++++ b/detection/mmdet_custom/__init__.py
+@@ -4,5 +4,4 @@
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+ 
+-from .datasets import *
+ from .models import *  # noqa: F401,F403
+diff --git a/detection/mmdet_custom/models/__init__.py b/detection/mmdet_custom/models/__init__.py
+index 59d47d5..1526e4d 100644
+--- a/detection/mmdet_custom/models/__init__.py
++++ b/detection/mmdet_custom/models/__init__.py
+@@ -5,6 +5,3 @@
+ # --------------------------------------------------------
+ 
+ from .backbones import *  # noqa: F401,F403
+-from .dense_heads import *  # noqa: F401,F403
+-from .detectors import *  # noqa: F401,F403
+-from .utils import *  # noqa: F401,F403
+diff --git a/detection/mmdet_custom/models/backbones/intern_image.py b/detection/mmdet_custom/models/backbones/intern_image.py
+index 6deb823..c850086 100644
+--- a/detection/mmdet_custom/models/backbones/intern_image.py
++++ b/detection/mmdet_custom/models/backbones/intern_image.py
+@@ -10,12 +10,16 @@ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+-from mmcv.cnn import constant_init, trunc_normal_init
+-from mmcv.runner import _load_checkpoint
+-from mmdet.models.builder import BACKBONES
+-from mmdet.utils import get_root_logger
++
+ from ops_dcnv3 import modules as dcnv3
+-from timm.models.layers import DropPath, trunc_normal_
++
++from timm.models.layers import trunc_normal_, DropPath
++from mmengine.runner.checkpoint import _load_checkpoint
++from mmengine.model import constant_init, trunc_normal_init
++from mmengine.registry import MODELS
++BACKBONES = MODELS
++HEADS = MODELS
++LOSSES = MODELS
+ 
+ 
+ class to_channels_first(nn.Module):
+@@ -588,15 +592,16 @@ class InternImage(nn.Module):
+         self.init_cfg = init_cfg
+         self.out_indices = out_indices
+         self.level2_post_norm_block_ids = level2_post_norm_block_ids
+-        logger = get_root_logger()
+-        logger.info(f'using core type: {core_op}')
+-        logger.info(f'using activation layer: {act_layer}')
+-        logger.info(f'using main norm layer: {norm_layer}')
+-        logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
+-        logger.info(f'level2_post_norm: {level2_post_norm}')
+-        logger.info(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
+-        logger.info(f'res_post_norm: {res_post_norm}')
+-        logger.info(f'use_dcn_v4_op: {use_dcn_v4_op}')
++        # get_root_logger is no longer available in mmdet v3
++        # logger = get_root_logger()
++        # logger.info(f'using core type: {core_op}')
++        # logger.info(f'using activation layer: {act_layer}')
++        # logger.info(f'using main norm layer: {norm_layer}')
++        # logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
++        # logger.info(f'level2_post_norm: {level2_post_norm}')
++        # logger.info(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
++        # logger.info(f'res_post_norm: {res_post_norm}')
++        # logger.info(f'use_dcn_v4_op: {use_dcn_v4_op}')
+ 
+         in_chans = 3
+         self.patch_embed = StemLayer(in_chans=in_chans,
+@@ -644,9 +649,8 @@ class InternImage(nn.Module):
+             self.apply(self._init_deform_weights)
+ 
+     def init_weights(self):
+-        logger = get_root_logger()
+         if self.init_cfg is None:
+-            logger.warn(f'No pre-trained weights for '
++            print(f'No pre-trained weights for '
+                         f'{self.__class__.__name__}, '
+                         f'training start from scratch')
+             for m in self.modules():
+@@ -660,7 +664,6 @@ class InternImage(nn.Module):
+                                                   f'`init_cfg` in ' \
+                                                   f'{self.__class__.__name__} '
+             ckpt = _load_checkpoint(self.init_cfg.checkpoint,
+-                                    logger=logger,
+                                     map_location='cpu')
+             if 'state_dict' in ckpt:
+                 _state_dict = ckpt['state_dict']
+@@ -682,7 +685,7 @@ class InternImage(nn.Module):
+ 
+         # load state_dict
+         meg = self.load_state_dict(state_dict, False)
+-        logger.info(meg)
++        print(meg)
+ 
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+diff --git a/detection/ops_dcnv3/functions/__init__.py b/detection/ops_dcnv3/functions/__init__.py
+index 0634879..fdbb29e 100644
+--- a/detection/ops_dcnv3/functions/__init__.py
++++ b/detection/ops_dcnv3/functions/__init__.py
+@@ -4,4 +4,4 @@
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+ 
+-from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch
++from .dcnv3_func import dcnv3_core_pytorch
+diff --git a/detection/ops_dcnv3/functions/dcnv3_func.py b/detection/ops_dcnv3/functions/dcnv3_func.py
+index 198e70c..e586cf7 100644
+--- a/detection/ops_dcnv3/functions/dcnv3_func.py
++++ b/detection/ops_dcnv3/functions/dcnv3_func.py
+@@ -6,7 +6,6 @@
+ 
+ from __future__ import absolute_import, division, print_function
+ 
+-import DCNv3
+ import torch
+ import torch.nn.functional as F
+ from torch.autograd import Function
+diff --git a/detection/ops_dcnv3/modules/__init__.py b/detection/ops_dcnv3/modules/__init__.py
+index acb73a8..2287382 100644
+--- a/detection/ops_dcnv3/modules/__init__.py
++++ b/detection/ops_dcnv3/modules/__init__.py
+@@ -4,4 +4,4 @@
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+ 
+-from .dcnv3 import DCNv3, DCNv3_pytorch
++from .dcnv3 import DCNv3_pytorch as DCNv3
+diff --git a/detection/ops_dcnv3/modules/dcnv3.py b/detection/ops_dcnv3/modules/dcnv3.py
+index 788b211..4026212 100644
+--- a/detection/ops_dcnv3/modules/dcnv3.py
++++ b/detection/ops_dcnv3/modules/dcnv3.py
+@@ -13,7 +13,7 @@ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn.init import constant_, xavier_uniform_
+ 
+-from ..functions import DCNv3Function, dcnv3_core_pytorch
++from ..functions import dcnv3_core_pytorch
+ 
+ try:
+     from DCNv4.functions import DCNv4Function
+@@ -107,6 +107,7 @@ class DCNv3_pytorch(nn.Module):
+                  offset_scale=1.0,
+                  act_layer='GELU',
+                  norm_layer='LN',
++                 use_dcn_v4_op=None,
+                  center_feature_scale=False):
+         """
+         DCNv3 Module
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/mmdet.patch b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/mmdet.patch
new file mode 100755
index 0000000000000000000000000000000000000000..e712fe648e9059b27bc8f5d8535a6f2b5c684577
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/mmdet.patch
@@ -0,0 +1,283 @@
+diff --git a/mmdet/datasets/transforms/formatting.py b/mmdet/datasets/transforms/formatting.py
+index 05263807..77a2d59b 100644
+--- a/mmdet/datasets/transforms/formatting.py
++++ b/mmdet/datasets/transforms/formatting.py
+@@ -136,7 +136,7 @@ class PackDetInputs(BaseTransform):
+             if key in results:
+                 img_meta[key] = results[key]
+         data_sample.set_metainfo(img_meta)
+-        packed_results['data_samples'] = data_sample
++        packed_results['data_sample'] = data_sample
+ 
+         return packed_results
+ 
+diff --git a/mmdet/models/dense_heads/base_dense_head.py b/mmdet/models/dense_heads/base_dense_head.py
+index d0a4469e..33ac469d 100644
+--- a/mmdet/models/dense_heads/base_dense_head.py
++++ b/mmdet/models/dense_heads/base_dense_head.py
+@@ -189,7 +189,7 @@ class BaseDenseHead(BaseModule, metaclass=ABCMeta):
+                 after the post process.
+         """
+         batch_img_metas = [
+-            data_samples.metainfo for data_samples in batch_data_samples
++            data_samples['metainfo'] for data_samples in batch_data_samples
+         ]
+ 
+         outs = self(x)
+diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py
+index 6b544009..4530cdc1 100644
+--- a/mmdet/models/dense_heads/rpn_head.py
++++ b/mmdet/models/dense_heads/rpn_head.py
+@@ -17,6 +17,55 @@ from mmdet.structures.bbox import (cat_boxes, empty_box_as, get_box_tensor,
+ from mmdet.utils import InstanceList, MultiConfig, OptInstanceList
+ from .anchor_head import AnchorHead
+ 
++class BatchNMSOp(torch.autograd.Function):  # use customized OP because ONNX doesn't have a corresponding OP
++    @staticmethod
++    def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
++        """
++        boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
++        scores (torch.Tensor): scores in shape (batch, N, C).
++        return:
++            nmsed_boxes: (1, N, 4)
++            nmsed_scores: (1, N)
++            nmsed_classes: (1, N)
++            nmsed_num: (1,)
++        """
++
++        # Phony implementation for onnx export
++        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
++        nmsed_scores = scores[:, :max_total_size, 0]
++        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
++        nmsed_num = torch.Tensor([max_total_size])
++
++        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
++
++    @staticmethod
++    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
++        nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('Ascend::BatchMultiClassNMS',
++            bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
++            max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
++        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
++
++def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
++    """
++    boxes (torch.Tensor): boxes in shape (N, 4).
++    scores (torch.Tensor): scores in shape (N, ).
++    """
++
++    if bboxes.dtype == torch.float32:
++        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
++        scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
++    else:
++        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
++        scores = scores.reshape(1, scores.shape[0].numpy(), -1)
++
++    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
++        score_threshold, iou_threshold, max_size_per_class, max_total_size)  # max_total_size num_bbox
++    nmsed_boxes = nmsed_boxes.float()
++    nmsed_scores = nmsed_scores.float()
++    nmsed_classes = nmsed_classes.long()
++    dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
++    labels = nmsed_classes.reshape((max_total_size, ))
++    return dets, labels
+ 
+ @MODELS.register_module()
+ class RPNHead(AnchorHead):
+@@ -203,11 +252,14 @@
+ 
+             scores = torch.squeeze(scores)
+             if 0 < nms_pre < scores.shape[0]:
++                if torch.onnx.is_in_onnx_export():
++                    scores, topk_inds = scores.topk(cfg.nms_pre)
+                 # sort is faster than topk
+                 # _, topk_inds = scores.topk(cfg.nms_pre)
+-                ranked_scores, rank_inds = scores.sort(descending=True)
+-                topk_inds = rank_inds[:nms_pre]
+-                scores = ranked_scores[:nms_pre]
++                else:
++                    ranked_scores, rank_inds = scores.sort(descending=True)
++                    topk_inds = rank_inds[:nms_pre]
++                    scores = ranked_scores[:nms_pre]
+                 bbox_pred = bbox_pred[topk_inds, :]
+                 priors = priors[topk_inds]
+ 
+@@ -281,16 +333,23 @@
+ 
+         if results.bboxes.numel() > 0:
+             bboxes = get_box_tensor(results.bboxes)
+-            det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
++            if torch.onnx.is_in_onnx_export():
++                dets, labels = batch_nms_op(bboxes, results.scores, 0.0, cfg.nms.get("iou_threshold"), cfg.get("max_per_img", 1000), cfg.get("max_per_img", 1000))
++                results = results[:cfg.get("max_per_img", 1000)]
++                results.bboxes = dets[:, :4]
++                results.scores = dets[:, -1]
++                results.labels = labels
++            else:
++                det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
+                                                 results.level_ids, cfg.nms)
+-            results = results[keep_idxs]
+-            # some nms would reweight the score, such as softnms
+-            results.scores = det_bboxes[:, -1]
+-            results = results[:cfg.max_per_img]
+-            # TODO: This would unreasonably show the 0th class label
+-            #  in visualization
+-            results.labels = results.scores.new_zeros(
+-                len(results), dtype=torch.long)
++                results = results[keep_idxs]
++                # some nms would reweight the score, such as softnms
++                results.scores = det_bboxes[:, -1]
++                results = results[:cfg.max_per_img]
++                # TODO: This would unreasonably show the 0th class label
++                #  in visualization
++                results.labels = results.scores.new_zeros(
++                    len(results), dtype=torch.long)
+             del results.level_ids
+         else:
+             # To avoid some potential error
+diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
+index 1a193b0c..780a06ba 100644
+--- a/mmdet/models/detectors/base.py
++++ b/mmdet/models/detectors/base.py
+@@ -57,6 +57,7 @@ class BaseDetector(BaseModel, metaclass=ABCMeta):
+ 
+     def forward(self,
+                 inputs: torch.Tensor,
++                img_shape: torch.Tensor,
+                 data_samples: OptSampleList = None,
+                 mode: str = 'tensor') -> ForwardResults:
+         """The unified entry for a forward process in both training and test.
+@@ -93,7 +94,7 @@ class BaseDetector(BaseModel, metaclass=ABCMeta):
+         elif mode == 'predict':
+             return self.predict(inputs, data_samples)
+         elif mode == 'tensor':
+-            return self._forward(inputs, data_samples)
++            return self._forward(inputs, [dict(metainfo=dict(img_shape=img_shape))])
+         else:
+             raise RuntimeError(f'Invalid mode "{mode}". '
+                                'Only supports loss, predict and tensor mode')
+diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
+index 81db6711..9e065d0f 100644
+--- a/mmdet/models/roi_heads/cascade_roi_head.py
++++ b/mmdet/models/roi_heads/cascade_roi_head.py
+@@ -539,7 +539,7 @@ class CascadeRoIHead(BaseRoIHead):
+         """
+         results = ()
+         batch_img_metas = [
+-            data_samples.metainfo for data_samples in batch_data_samples
++            data_samples['metainfo'] for data_samples in batch_data_samples
+         ]
+         proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
+         num_proposals_per_img = tuple(len(p) for p in proposals)
+@@ -549,6 +549,7 @@ class CascadeRoIHead(BaseRoIHead):
+             rois, cls_scores, bbox_preds = self._refine_roi(
+                 x, rois, batch_img_metas, num_proposals_per_img)
+             results = results + (cls_scores, bbox_preds)
++            return cls_scores[0], bbox_preds[0], x, rois
+         # mask head
+         if self.with_mask:
+             aug_masks = []
+diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
+index 59229e0b..9d1faf15 100644
+--- a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
++++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
+@@ -8,6 +8,28 @@ from mmdet.registry import MODELS
+ from mmdet.utils import ConfigType, OptMultiConfig
+ from .base_roi_extractor import BaseRoIExtractor
+ 
++class RoiExtractor(torch.autograd.Function):  # use customized OP to improve precision
++    @staticmethod
++    def forward(self, f0, f1, f2, f3, rois, aligned=1, finest_scale=56, pooled_height=7, pooled_width=7,
++                pool_mode='avg', roi_scale_factor=0, sample_num=0, spatial_scale=[0.25, 0.125, 0.0625, 0.03125]):
++        """
++        feats (torch.Tensor): feats in shape (batch, 256, H, W).
++        rois (torch.Tensor): rois in shape (k, 5).
++        return:
++            roi_feats (torch.Tensor): (k, 256, pooled_width, pooled_width)
++        """
++
++        # phony implementation for shape inference
++        k = rois.size()[0]
++        roi_feats = torch.ones(k, 256, pooled_height, pooled_width)
++        return roi_feats
++
++    @staticmethod
++    def symbolic(g, f0, f1, f2, f3, rois):
++        # TODO: support tensor list type for feats
++        roi_feats = g.op('ascend::RoiExtractor', f0, f1, f2, f3, rois, aligned_i=1, finest_scale_i=56, pooled_height_i=7, pooled_width_i=7,
++                         pool_mode_s='avg', roi_scale_factor_i=0, sample_num_i=0, spatial_scale_f=[0.25, 0.125, 0.0625, 0.03125], outputs=1)
++        return roi_feats
+ 
+ @MODELS.register_module()
+ class SingleRoIExtractor(BaseRoIExtractor):
+@@ -82,8 +104,12 @@ class SingleRoIExtractor(BaseRoIExtractor):
+         rois = rois.type_as(feats[0])
+         out_size = self.roi_layers[0].output_size
+         num_levels = len(feats)
+-        roi_feats = feats[0].new_zeros(
+-            rois.size(0), self.out_channels, *out_size)
++        if torch.onnx.is_in_onnx_export():
++            roi_feats = RoiExtractor.apply(feats[0], feats[1], feats[2], feats[3], rois)
++            return roi_feats
++        else:
++            roi_feats = feats[0].new_zeros(
++                rois.size(0), self.out_channels, *out_size)
+ 
+         # TODO: remove this when parrots supports
+         if torch.__version__ == 'parrots':
+diff --git a/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py b/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
+index c2b60b5e..117418f6 100644
+--- a/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
++++ b/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
+@@ -473,7 +473,7 @@ def onnx_delta2bbox(rois: Tensor,
+     if clip_border and max_shape is not None:
+         # clip bboxes with dynamic `min` and `max` for onnx
+         if torch.onnx.is_in_onnx_export():
+-            from mmdet.core.export import dynamic_clip_for_onnx
++            # from mmdet.core.export import dynamic_clip_for_onnx (no longer supported in mmdet v3)
+             x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
+             bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
+             return bboxes
+@@ -577,3 +577,41 @@ def delta2bbox_glip(rois: Tensor,
+     bboxes[..., 1::2].clamp_(min=0, max=max_shape[0] - 1)  # Note
+     bboxes = bboxes.reshape(num_bboxes, -1)
+     return bboxes
++
++def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
++    """Clip boxes dynamically for onnx.
++
++    Since torch.clamp cannot have dynamic `min` and `max`, we scale the
++    boxes by 1/max_shape and clamp in the range [0, 1].
++
++    Args:
++        x1 (Tensor): The x1 for bounding boxes.
++        y1 (Tensor): The y1 for bounding boxes.
++        x2 (Tensor): The x2 for bounding boxes.
++        y2 (Tensor): The y2 for bounding boxes.
++        max_shape (Tensor or torch.Size): The (H,W) of original image.
++    Returns:
++        tuple(Tensor): The clipped x1, y1, x2, y2.
++    """
++    max_shape = torch.Tensor(max_shape)
++    assert isinstance(
++        max_shape,
++        torch.Tensor), '`max_shape` should be tensor of (h,w) for onnx'
++    # scale by 1/max_shape
++    x1 = x1 / max_shape[1]
++    y1 = y1 / max_shape[0]
++    x2 = x2 / max_shape[1]
++    y2 = y2 / max_shape[0]
++
++    # clamp [0, 1]
++    x1 = torch.clamp(x1, 0, 1)
++    y1 = torch.clamp(y1, 0, 1)
++    x2 = torch.clamp(x2, 0, 1)
++    y2 = torch.clamp(y2, 0, 1)
++
++    # scale back
++    x1 = x1 * max_shape[1]
++    y1 = y1 * max_shape[0]
++    x2 = x2 * max_shape[1]
++    y2 = y2 * max_shape[0]
++    return x1, y1, x2, y2
+\ No newline at end of file
\ No newline at end of file
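Both custom ops introduced by mmdet.patch (BatchNMSOp, exported as Ascend::BatchMultiClassNMS, and RoiExtractor, exported as ascend::RoiExtractor) follow the same pattern: a torch.autograd.Function whose forward is a phony, shape-only implementation used during tracing, and whose symbolic emits a node in a custom, non-ONNX domain that the Ascend ATC toolchain resolves after export. Below is a minimal self-contained sketch of that pattern for reference; it is not part of either patch, and the toy op, the 100-box cap, the thresholds, and the file name are illustrative assumptions.

    import torch

    class ToyBatchNMS(torch.autograd.Function):
        @staticmethod
        def forward(ctx, boxes, scores):
            # Phony eager implementation: only the output shapes matter,
            # since this path runs solely while the graph is being traced.
            return boxes[:, :100, 0, :], scores[:, :100, 0]

        @staticmethod
        def symbolic(g, boxes, scores):
            # Emit a node in a custom domain; ONNX itself never validates
            # it, and the backend toolchain lowers it after export.
            return g.op('Ascend::BatchMultiClassNMS', boxes, scores,
                        score_threshold_f=0.0, iou_threshold_f=0.5,
                        max_size_per_class_i=100, max_total_size_i=100,
                        outputs=2)

    class Wrapper(torch.nn.Module):
        def forward(self, boxes, scores):
            return ToyBatchNMS.apply(boxes, scores)

    torch.onnx.export(
        Wrapper(),
        (torch.rand(1, 1000, 1, 4), torch.rand(1, 1000, 1)),
        'toy_nms.onnx',
        input_names=['boxes', 'scores'],
        output_names=['nmsed_boxes', 'nmsed_scores'],
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)

The key constraint, mirrored by batch_nms_op in the patch, is that forward must return the same number of outputs that symbolic declares via outputs=..., with shapes the downstream code can rely on (here, fixed by the max_total_size-style cap rather than by data-dependent NMS results).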