From ebce3b8d219fc54fa46cd12c11a7045f5a96bd83 Mon Sep 17 00:00:00 2001
From: Mingkai Chan
Date: Mon, 18 Nov 2024 18:15:46 +0800
Subject: [PATCH] Refactor multi-scale deformable attention function: rename
 and replace deprecated function, update documentation

---
 docs/api/README.md | 17 +-
 .../bevformer/modules/decoder.py | 2 +-
 .../modules/spatial_cross_attention.py | 3 +-
 .../modules/temporal_self_attention.py | 3 +-
 .../bevformer/modules/decoder.py | 4 +-
 .../modules/occ_temporal_attention.py | 18 +-
 .../modules/spatial_cross_attention.py | 17 +-
 .../modules/temporal_self_attention.py | 18 +-
 .../modules/spatial_cross_attention.py | 30 +-
 .../modules/cross_view_hybrid_attention.py | 2 +-
 .../modules/image_cross_attention.py | 2 +-
 .../modules/cross_view_hybrid_attention.py | 2 +-
 .../modules/image_cross_attention.py | 2 +-
 .../mmcv_need/multi_scale_deform_attn.py | 2 +-
 .../motion_deformable_attn.py | 2 +-
 .../mmdet3d_plugin/uniad/modules/decoder.py | 2 +-
 .../uniad/modules/spatial_cross_attention.py | 2 +-
 .../uniad/modules/temporal_self_attention.py | 2 +-
 mx_driving/_C/__init__.pyi | 424 ++++++++++++++++++
 mx_driving/fused/__init__.py | 13 +-
 .../ops/csrc/MultiScaleDeformableAttn.cpp | 109 +++++
 .../csrc/MultiScaleDeformableAttnFunction.cpp | 165 -------
 mx_driving/fused/ops/csrc/functions.h | 8 +-
 mx_driving/fused/ops/csrc/pybind.cpp | 7 +-
 .../fused/ops/multi_scale_deformable_attn.py | 61 +++
 ...pu_multi_scale_deformable_attn_function.py | 35 --
 mx_driving/fused/ops/onnx/wrapper_onnx_ops.py | 2 +-
 pyproject.toml | 26 ++
 ...st_multi_scale_deformable_attn_function.py | 2 +-
 29 files changed, 708 insertions(+), 274 deletions(-)
 create mode 100644 mx_driving/_C/__init__.pyi
 create mode 100644 mx_driving/fused/ops/csrc/MultiScaleDeformableAttn.cpp
 delete mode 100644 mx_driving/fused/ops/csrc/MultiScaleDeformableAttnFunction.cpp
 create mode 100644 mx_driving/fused/ops/multi_scale_deformable_attn.py
 delete mode 100644 mx_driving/fused/ops/npu_multi_scale_deformable_attn_function.py
 create mode 100644 pyproject.toml

diff --git a/docs/api/README.md b/docs/api/README.md
index 43fc0776..7956f4e0 100644
--- a/docs/api/README.md
+++ b/docs/api/README.md
@@ -641,29 +641,30 @@ output = border_align(features.npu(), rois.npu(), pooled_size)
 ```
 # Fused Operators
-## npu_multi_scale_deformable_attn_function
+## multi_scale_deformable_attn (MultiScaleDeformableAttnFunction.apply)
 ### Prototype
 ```python
-mx_driving.fused.npu_multi_scale_deformable_attn_function(Tensor value, Tensor shape, Tensor offset, Tensor locations, Tensor weight) -> Tensor
+mx_driving.fused.multi_scale_deformable_attn(Tensor value, Tensor value_spatial_shapes, Tensor value_level_start_index, Tensor sampling_locations, Tensor attention_weights) -> Tensor
 ```
 ### Description
 Multi-scale deformable attention, which fuses feature maps sampled from multiple views.
 ### Parameters
 - `value(Tensor)`: feature tensor of dtype `float32, float16` with shape `[bs, num_keys, num_heads, embed_dims]`, where `bs` is the batch size, `num_keys` is the feature map size, `num_heads` is the number of attention heads, and `embed_dims` is the feature dimension; `embed_dims` must be a multiple of 8.
-- `shape(Tensor)`: feature map shapes of dtype `int32, int64` with shape `[num_levels, 2]`, where `num_levels` is the number of feature maps and the `2` holds `H, W`.
-- `offset(Tensor)`: level start offset tensor of dtype `int32, int64` with shape `[num_levels]`.
-- `locations(Tensor)`: sampling location tensor of dtype `float32, float16` with shape `[bs, num_queries, num_heads, num_levels, num_points, 2]`, where `bs` is the batch size, `num_queries` is the number of queries, `num_heads` is the number of attention heads, `num_levels` is the number of feature maps, `num_points` is the number of sampling points, and the `2` holds `y, x`.
-- `weight(Tensor)`: weight tensor of dtype `float32, float16` with shape `[bs, num_queries, num_heads, num_levels, num_points]`, where `bs` is the batch size, `num_queries` is the number of queries, `num_heads` is the number of attention heads, `num_levels` is the number of feature maps, and `num_points` is the number of sampling points.
+- `value_spatial_shapes(Tensor)`: feature map shapes of dtype `int32, int64` with shape `[num_levels, 2]`, where `num_levels` is the number of feature maps and the `2` holds `H, W`.
+- `value_level_start_index(Tensor)`: level start offset tensor of dtype `int32, int64` with shape `[num_levels]`.
+- `sampling_locations(Tensor)`: sampling location tensor of dtype `float32, float16` with shape `[bs, num_queries, num_heads, num_levels, num_points, 2]`, where `bs` is the batch size, `num_queries` is the number of queries, `num_heads` is the number of attention heads, `num_levels` is the number of feature maps, `num_points` is the number of sampling points, and the `2` holds `y, x`.
+- `attention_weights(Tensor)`: weight tensor of dtype `float32, float16` with shape `[bs, num_queries, num_heads, num_levels, num_points]`, where `bs` is the batch size, `num_queries` is the number of queries, `num_heads` is the number of attention heads, `num_levels` is the number of feature maps, and `num_points` is the number of sampling points.
 ### Returns
 - `output(Tensor)`: the fused feature tensor of dtype `float32, float16` with shape `[bs, num_queries, num_heads*embed_dims]`.
 ### Supported Products
 - Atlas A2 training series
 ### Constraints
-- The values of `locations` must lie in `[0, 1]`.
+- The values of `sampling_locations` must lie in `[0, 1]`.
+- The current version only supports `num_levels` ≤ 8, `num_heads` ≤ 8, `embed_dims` == 16 or 32, and `num_points` == 1 or an even number.
 ### Example
 ```python
 import torch, torch_npu
-from mx_driving.fused import npu_multi_scale_deformable_attn_function
+from mx_driving.fused import multi_scale_deformable_attn
 
 bs, num_levels, num_heads, num_points, num_queries, embed_dims = 1, 1, 4, 8, 16, 32
 shapes = torch.as_tensor([(100, 100)], dtype=torch.long)
@@ -674,7 +675,7 @@ sampling_locations = torch.ones(bs, num_queries, num_heads, num_levels, num_poin
 attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points) + 1e-5
 level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
 
-out = npu_multi_scale_deformable_attn_function(value.npu(), shapes.npu(), level_start_index.npu(), sampling_locations.npu(), attention_weights.npu())
+out = multi_scale_deformable_attn(value.npu(), shapes.npu(), level_start_index.npu(), sampling_locations.npu(), attention_weights.npu())
 ```
 
 ## npu_max_pool2d
diff --git a/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/decoder.py b/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/decoder.py
index 5adece07..80703d52 100644
--- a/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/decoder.py
+++ b/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/decoder.py
@@ -318,7 +318,7 @@ class CustomMSDeformableAttention(BaseModule):
                 f' 2 or 4, but get {reference_points.shape[-1]} instead.')
 
         if torch.cuda.is_available() and value.is_cuda:
-            output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index,
+            output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
                                                                sampling_locations, attention_weights)
         else:
             output = multi_scale_deformable_attn_pytorch(
diff --git a/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py b/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
index 0211414d..9129dd03 100644
--- a/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
+++ b/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
@@ -395,8 +395,7 @@ class MSDeformableAttention3D(BaseModule):
         #  attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points
 
         if torch.cuda.is_available() and value.is_cuda:
-
-            output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index,
+            output = mx_driving.fused.multi_scale_deformable_attn(value,
spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( diff --git a/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py b/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py index 30d0a5fe..fcdfb04c 100644 --- a/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py +++ b/model_examples/BEVFormer/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py @@ -235,8 +235,7 @@ class TemporalSelfAttention(BaseModule): f'Last dim of reference_points must be' f' 2 or 4, but get {reference_points.shape[-1]} instead.') if torch.cuda.is_available() and value.is_cuda: - - output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: diff --git a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/decoder.py b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/decoder.py index 03a74e57..3b164a26 100644 --- a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/decoder.py +++ b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/decoder.py @@ -23,7 +23,7 @@ from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning, to_2tuple) -from mx_driving.fused import npu_multi_scale_deformable_attn_function +from mx_driving.fused import multi_scale_deformable_attn def inverse_sigmoid(x, eps=1e-5): @@ -318,7 +318,7 @@ class CustomMSDeformableAttention(BaseModule): f'Last dim of reference_points must be' f' 2 or 4, but get {reference_points.shape[-1]} instead.') if torch.cuda.is_available() and value.is_cuda: - output = npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( diff --git a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py index ef4cd9e7..4fee8283 100644 --- a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py +++ b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py @@ -4,20 +4,24 @@ # Modified by Zhiqi Li # --------------------------------------------- -from projects.mmdet3d_plugin.models.utils.bricks import run_time -from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32 -from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch +import math import warnings + import torch import torch.nn as nn -from mmcv.cnn import xavier_init, constant_init +from mmcv.cnn import constant_init, xavier_init from mmcv.cnn.bricks.registry import ATTENTION -import math +from mmcv.ops.multi_scale_deform_attn import \ + multi_scale_deformable_attn_pytorch from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning, to_2tuple) +from projects.mmdet3d_plugin.models.utils.bricks import run_time + +from mx_driving.fused import multi_scale_deformable_attn 
-from mx_driving.fused import npu_multi_scale_deformable_attn_function +from .multi_scale_deformable_attn_function import \ + MultiScaleDeformableAttnFunction_fp32 @ATTENTION.register_module() @@ -241,7 +245,7 @@ class OccTemporalAttention(BaseModule): sampling_locations = sampling_locations.contiguous() if torch.cuda.is_available() and value.is_cuda: - output = npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: diff --git a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py index 26db4596..9ea28957 100644 --- a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py +++ b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py @@ -6,21 +6,22 @@ # Modified by Zhexu Liu # --------------------------------------------- -import warnings import math +import warnings + import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch -from mmcv.cnn import xavier_init, constant_init -from mmcv.cnn.bricks.registry import (ATTENTION, - TRANSFORMER_LAYER, - TRANSFORMER_LAYER_SEQUENCE) +from mmcv.cnn import constant_init, xavier_init +from mmcv.cnn.bricks.registry import ATTENTION, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE from mmcv.cnn.bricks.transformer import build_attention +from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch from mmcv.runner import force_fp32 from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from projects.mmdet3d_plugin.models.utils.bricks import run_time -from mx_driving.fused import npu_multi_scale_deformable_attn_function + +from mx_driving.fused import multi_scale_deformable_attn + indexes_global = None max_len_global = None @@ -394,7 +395,7 @@ class MSDeformableAttention3D(BaseModule): f' 2 or 4, but get {reference_points.shape[-1]} instead.') if torch.cuda.is_available() and value.is_cuda: - output = npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( diff --git a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py index 0ba7a278..2f44d72b 100644 --- a/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py +++ b/model_examples/PanoOcc/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py @@ -4,20 +4,24 @@ # Modified by Zhiqi Li # --------------------------------------------- -from projects.mmdet3d_plugin.models.utils.bricks import run_time -from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32 -from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch +import math import warnings + import torch import torch.nn as nn -from mmcv.cnn import xavier_init, constant_init +from mmcv.cnn import constant_init, xavier_init from mmcv.cnn.bricks.registry import ATTENTION -import math +from mmcv.ops.multi_scale_deform_attn import \ + 
multi_scale_deformable_attn_pytorch from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning, to_2tuple) +from projects.mmdet3d_plugin.models.utils.bricks import run_time + +from mx_driving.fused import multi_scale_deformable_attn -from mx_driving.fused import npu_multi_scale_deformable_attn_function +from .multi_scale_deformable_attn_function import \ + MultiScaleDeformableAttnFunction_fp32 @ATTENTION.register_module() @@ -236,7 +240,7 @@ class TemporalSelfAttention(BaseModule): f'Last dim of reference_points must be' f' 2 or 4, but get {reference_points.shape[-1]} instead.') if torch.cuda.is_available() and value.is_cuda: - output = npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: diff --git a/model_examples/SurroundOcc/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py b/model_examples/SurroundOcc/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py index 05b9076d..1a432806 100644 --- a/model_examples/SurroundOcc/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py +++ b/model_examples/SurroundOcc/projects/mmdet3d_plugin/surroundocc/modules/spatial_cross_attention.py @@ -18,28 +18,32 @@ # Modified by Zhiqi Li # --------------------------------------------- -from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch +import math import warnings + import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import xavier_init, constant_init -from mmcv.cnn.bricks.registry import (ATTENTION, - TRANSFORMER_LAYER, - TRANSFORMER_LAYER_SEQUENCE) +from mmcv.cnn import constant_init, xavier_init +from mmcv.cnn.bricks.registry import ATTENTION, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE from mmcv.cnn.bricks.transformer import build_attention -import math -from mmcv.runner import force_fp32, auto_fp16 - +from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch +from mmcv.runner import auto_fp16, force_fp32 from mmcv.runner.base_module import BaseModule, ModuleList, Sequential - from mmcv.utils import ext_loader -from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \ - MultiScaleDeformableAttnFunction_fp16 from projects.mmdet3d_plugin.models.utils.bricks import run_time + +from mx_driving.fused import multi_scale_deformable_attn + +from .multi_scale_deformable_attn_function import ( + MultiScaleDeformableAttnFunction_fp16, + MultiScaleDeformableAttnFunction_fp32, +) + + ext_module = ext_loader.load_ext( '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) -from mx_driving.fused import npu_multi_scale_deformable_attn_function + @ATTENTION.register_module() class SpatialCrossAttention(BaseModule): @@ -400,7 +404,7 @@ class MSDeformableAttention3D(BaseModule): MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 else: MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 - output = npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( value, spatial_shapes, sampling_locations, attention_weights) diff --git 
a/model_examples/TPVFormer/tpvformer04/modules/cross_view_hybrid_attention.py b/model_examples/TPVFormer/tpvformer04/modules/cross_view_hybrid_attention.py index 4c0ad9b6..4b212b8e 100644 --- a/model_examples/TPVFormer/tpvformer04/modules/cross_view_hybrid_attention.py +++ b/model_examples/TPVFormer/tpvformer04/modules/cross_view_hybrid_attention.py @@ -227,7 +227,7 @@ class TPVCrossViewHybridAttention(BaseModule): f'Last dim of reference_points must be' f' 2 or 4, but get {reference_points.shape[-1]} instead.') - output = mx_driving.fused.npu_multi_scale_deformable_attn_function( + output = mx_driving.fused.multi_scale_deformable_attn( value, spatial_shapes, level_start_index, sampling_locations, attention_weights) # output shape (bs*num_tpv_queue, num_query, embed_dims) diff --git a/model_examples/TPVFormer/tpvformer04/modules/image_cross_attention.py b/model_examples/TPVFormer/tpvformer04/modules/image_cross_attention.py index 398460ee..66bc7bae 100644 --- a/model_examples/TPVFormer/tpvformer04/modules/image_cross_attention.py +++ b/model_examples/TPVFormer/tpvformer04/modules/image_cross_attention.py @@ -440,7 +440,7 @@ class TPVMSDeformableAttention3D(BaseModule): # sampling_locations.shape: bs, num_query, num_heads, num_levels, num_all_points, 2 # attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points - output = mx_driving.fused.npu_multi_scale_deformable_attn_function( + output = mx_driving.fused.multi_scale_deformable_attn( value, spatial_shapes, level_start_index, sampling_locations, attention_weights) output = self.reshape_output(output, query_lens) diff --git a/model_examples/TPVFormer/tpvformer10/modules/cross_view_hybrid_attention.py b/model_examples/TPVFormer/tpvformer10/modules/cross_view_hybrid_attention.py index dc61a4c1..bcb37916 100644 --- a/model_examples/TPVFormer/tpvformer10/modules/cross_view_hybrid_attention.py +++ b/model_examples/TPVFormer/tpvformer10/modules/cross_view_hybrid_attention.py @@ -200,7 +200,7 @@ class TPVCrossViewHybridAttention(BaseModule): f' 2, but get {reference_points.shape[-1]} instead.') - output = mx_driving.fused.npu_multi_scale_deformable_attn_function( + output = mx_driving.fused.multi_scale_deformable_attn( value, spatial_shapes, level_start_index, sampling_locations, attention_weights) outputs = self.reshape_output(output, query_lens) diff --git a/model_examples/TPVFormer/tpvformer10/modules/image_cross_attention.py b/model_examples/TPVFormer/tpvformer10/modules/image_cross_attention.py index 49d62bbc..82242208 100644 --- a/model_examples/TPVFormer/tpvformer10/modules/image_cross_attention.py +++ b/model_examples/TPVFormer/tpvformer10/modules/image_cross_attention.py @@ -422,7 +422,7 @@ class TPVMSDeformableAttention3D(BaseModule): # sampling_locations.shape: bs, num_query, num_heads, num_levels, num_all_points, 2 # attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points - output = mx_driving.fused.npu_multi_scale_deformable_attn_function( + output = mx_driving.fused.multi_scale_deformable_attn( value, spatial_shapes, level_start_index, sampling_locations, attention_weights) output = self.reshape_output(output, query_lens) diff --git a/model_examples/UniAD/mmcv_need/multi_scale_deform_attn.py b/model_examples/UniAD/mmcv_need/multi_scale_deform_attn.py index 243d0efb..0107208f 100644 --- a/model_examples/UniAD/mmcv_need/multi_scale_deform_attn.py +++ b/model_examples/UniAD/mmcv_need/multi_scale_deform_attn.py @@ -365,7 +365,7 @@ class MultiScaleDeformableAttention(BaseModule): if 
((IS_CUDA_AVAILABLE and value.is_cuda) or (IS_MLU_AVAILABLE and value.is_mlu) or (IS_NPU_AVAILABLE and value.device.type == 'npu')): - output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( diff --git a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py index 5efadfc5..8c71956c 100644 --- a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py +++ b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/dense_heads/motion_head_plugin/motion_deformable_attn.py @@ -454,7 +454,7 @@ class MotionDeformableAttention(BaseModule): f' 2 or 4, but get {reference_trajs.shape[-1]} instead.') if torch.cuda.is_available() and value.is_cuda: - output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( diff --git a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/decoder.py b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/decoder.py index 8af5c01f..bfd31b79 100644 --- a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/decoder.py +++ b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/decoder.py @@ -325,7 +325,7 @@ class CustomMSDeformableAttention(BaseModule): f' 2 or 4, but get {reference_points.shape[-1]} instead.') if torch.cuda.is_available() and value.is_cuda: - output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( diff --git a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py index 4eb3d1ec..616b9054 100644 --- a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py +++ b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/spatial_cross_attention.py @@ -384,7 +384,7 @@ class MSDeformableAttention3D(BaseModule): # if torch.cuda.is_available() and value.is_cuda: - output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: output = multi_scale_deformable_attn_pytorch( diff --git a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py index 44b63fec..620391cf 100644 --- a/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py +++ b/model_examples/UniAD/projects/mmdet3d_plugin/uniad/modules/temporal_self_attention.py @@ -238,7 +238,7 @@ class TemporalSelfAttention(BaseModule): f' 2 or 4, but get {reference_points.shape[-1]} instead.') if 
torch.cuda.is_available() and value.is_cuda: - output = mx_driving.fused.npu_multi_scale_deformable_attn_function(value, spatial_shapes, level_start_index, + output = mx_driving.fused.multi_scale_deformable_attn(value, spatial_shapes, level_start_index, sampling_locations, attention_weights) else: diff --git a/mx_driving/_C/__init__.pyi b/mx_driving/_C/__init__.pyi new file mode 100644 index 00000000..5a3866a1 --- /dev/null +++ b/mx_driving/_C/__init__.pyi @@ -0,0 +1,424 @@ +from typing import List, Optional, Tuple + +import torch + +def _init_op_api_so_path(so_path: str) -> None: ... +def knn( + xyz: torch.Tensor, center_xyz: torch.Tensor, k: int, is_from_knn: bool +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_three_interpolate( + b: int, c: int, m: int, n: int, points: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor +) -> torch.Tensor: ... +def npu_three_interpolate_backward( + b: int, c: int, n: int, m: int, grad_out: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor +) -> torch.Tensor: ... +def scatter_max_with_argmax_v2( + updates: torch.Tensor, indices: torch.Tensor, out: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_scatter_max_backward( + x: torch.Tensor, segment_ids: torch.Tensor, num_segments: torch.Tensor +) -> torch.Tensor: ... +def npu_scatter(self: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor, dim: int) -> torch.Tensor: ... +def npu_scatter_mean_grad( + grad_out: torch.Tensor, index: torch.Tensor, count: torch.Tensor, dim: int +) -> torch.Tensor: ... +def npu_scatter_mean( + src: torch.Tensor, + index: torch.Tensor, + out: Optional[torch.Tensor] = None, + dim: Optional[int] = None, + dim_size: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_sort_pairs( + keys_in: torch.Tensor, values_in: torch.Tensor, dim: int, descending: bool +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_hypot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... +def npu_hypot_grad( + x: torch.Tensor, y: torch.Tensor, out: torch.Tensor, out_grad: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def assign_score_withk( + points: torch.Tensor, + centers: torch.Tensor, + scores: torch.Tensor, + knn_idx: torch.Tensor, + output: torch.Tensor, + B: int, + N: int, + npoint: int, + M: int, + K: int, + out_dim: int, + aggregate: int, +) -> None: ... +def npu_max_pool2d(x: torch.Tensor, kernel_size: int, stride: int, padding: int) -> torch.Tensor: ... +def multi_scale_deformable_attn( + value: torch.Tensor, + value_spatial_shapes: torch.Tensor, + value_level_start_index: torch.Tensor, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, +) -> torch.Tensor: ... +def multi_scale_deformable_attn_backward( + value: torch.Tensor, + shape: torch.Tensor, + level_start_index: torch.Tensor, + location_trans: torch.Tensor, + attn_weight_trans: torch.Tensor, + grad_output: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def npu_add_relu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... +def npu_add_relu_grad(self: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor: ... +def fused_bias_leaky_relu(x: torch.Tensor, bias: torch.Tensor, negative_slop: float, scale: float) -> torch.Tensor: ... +def deformable_aggregation( + mc_ms_feat: torch.Tensor, + spatial_shape: torch.Tensor, + scale_start_index: torch.Tensor, + sampling_location: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: ... 
+def deformable_aggregation_grad( + mc_ms_feat: torch.Tensor, + spatial_shape: torch.Tensor, + scale_start_index: torch.Tensor, + sampling_location: torch.Tensor, + weights: torch.Tensor, + grad_output: torch.Tensor, + grad_mc_ms_feat: torch.Tensor, + grad_sampling_location: torch.Tensor, + grad_weights: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def deformable_conv2d( + input: torch.Tensor, + offset: torch.Tensor, + weight: torch.Tensor, + kernel_size: Tuple[int, int], + stride: Tuple[int, int], + padding: Tuple[int, int], + dilation: Tuple[int, int], + groups: int, + deformable_groups: int, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def modulated_deformable_conv2d( + input: torch.Tensor, + offset: torch.Tensor, + mask: torch.Tensor, + weight: torch.Tensor, + bias_opt: Optional[torch.Tensor], + kernel_size: Tuple[int, int], + stride: Tuple[int, int], + padding: Tuple[int, int], + dilation: Tuple[int, int], + groups: int, + deformable_groups: int, + with_bias: int, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def deformable_conv2d_backward( + input: torch.Tensor, + weight: torch.Tensor, + offset: torch.Tensor, + offset_output: torch.Tensor, + grad_y: torch.Tensor, + kernel_size: Tuple[int, int], + stride: Tuple[int, int], + padding: Tuple[int, int], + dilation: Tuple[int, int], + groups: int, + deformable_groups: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def modulated_deformable_conv2d_backward( + input: torch.Tensor, + offset: torch.Tensor, + mask: torch.Tensor, + weight: torch.Tensor, + bias_opt: Optional[torch.Tensor], + offset_output: torch.Tensor, + grad_y: torch.Tensor, + kernel_size: Tuple[int, int], + stride: Tuple[int, int], + padding: Tuple[int, int], + dilation: Tuple[int, int], + groups: int, + deformable_groups: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def npu_subm_sparse_conv3d( + feature: torch.Tensor, + indices: torch.Tensor, + weight: torch.Tensor, + kernel_size: Tuple[int, int, int], + out_channel: int, + outSpatialShape: Tuple[int, int, int], + batch_size: int, + temp: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def multi_to_sparse( + out_features: torch.Tensor, + unique_indices_offset: torch.Tensor, + sorted_idx_to_former_indices: torch.Tensor, + outidx_pair: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def multi_to_sparse_v2( + features: torch.Tensor, + weight: torch.Tensor, + unique_indices_offset: torch.Tensor, + sorted_idx_to_former_indices: torch.Tensor, + outidx_pair: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_sparse_conv3d( + indices: torch.Tensor, + kernel_size: Tuple[int, int, int], + stride: Tuple[int, int, int], + padding: Tuple[int, int, int], + out_channel: int, + outSpatialShape: Tuple[int, int, int], + batch_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_sparse_inverse_conv3d( + feature: torch.Tensor, + indices: torch.Tensor, + weight: torch.Tensor, + kernel_size: Tuple[int, int, int], + stride: Tuple[int, int, int], + padding: Tuple[int, int, int], + dilation: Tuple[int, int, int], + output_padding: Tuple[int, int, int], + out_channel: int, + outSpatialShape: Tuple[int, int, int], + batch_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... 
+def npu_sparse_conv3d_grad( + indices_offset: torch.Tensor, + former_sorted_indices: torch.Tensor, + feature: torch.Tensor, + weight: torch.Tensor, + grad: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_prepare_subm_conv3d( + flattenIndices: torch.Tensor, outSpatialShape: Tuple[int, int, int], batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def nms3d_normal(boxes: torch.Tensor, nms_overlap_thresh: float) -> Tuple[torch.Tensor, torch.Tensor]: ... +def nms3d(boxes: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_rotated_overlaps(self: torch.Tensor, query_boxes: torch.Tensor, trans: bool) -> torch.Tensor: ... +def npu_rotated_iou( + boxes: torch.Tensor, + query_boxes: torch.Tensor, + trans: bool, + mode: int, + is_cross: bool, + v_threshold: float, + e_threshold: float, +) -> torch.Tensor: ... +def npu_boxes_overlap_bev(boxes_a: torch.Tensor, boxes_b: torch.Tensor) -> torch.Tensor: ... +def roi_align_rotated_v2_forward_npu( + input: torch.Tensor, + rois_map: torch.Tensor, + output: torch.Tensor, + spatial_scale: float, + sampling_ratio: int, + pooled_height: int, + pooled_width: int, + aligned: bool, + clockwise: bool, +) -> None: ... +def npu_roi_align_rotated_grad_v2( + input: torch.Tensor, + rois: torch.Tensor, + grad_output: torch.Tensor, + pooled_height: int, + pooled_width: int, + spatial_scale: float, + sampling_ratio: int, + aligned: bool, + clockwise: bool, +) -> torch.Tensor: ... +def npu_box_iou_quadri(boxes_a: torch.Tensor, boxes_b: torch.Tensor, mode_flag: int, aligned: bool) -> torch.Tensor: ... +def npu_box_iou_rotated( + boxes_a: torch.Tensor, boxes_b: torch.Tensor, mode_flag: int, aligned: bool +) -> torch.Tensor: ... +def border_align_forward_npu( + input: torch.Tensor, rois: torch.Tensor, output: torch.Tensor, pooled_size: int +) -> None: ... +def border_align_backward( + grad_out: torch.Tensor, boxes: torch.Tensor, argmax_idx: torch.Tensor, pool_size: int, height: int, width: int +) -> torch.Tensor: ... +def npu_roiaware_pool3d_forward( + rois: torch.Tensor, + pts: torch.Tensor, + pts_feature: torch.Tensor, + argmax: torch.Tensor, + pts_idx_of_voxels: torch.Tensor, + pooled_features: torch.Tensor, + mode: int, +) -> None: ... +def roiaware_pool3d_grad( + pts_idx_of_voxels: torch.Tensor, argmax: torch.Tensor, grad_out: torch.Tensor, npoints: int, pool_method: int +) -> torch.Tensor: ... +def npu_points_in_box(boxes: torch.Tensor, pts: torch.Tensor) -> torch.Tensor: ... +def npu_points_in_box_all(boxes: torch.Tensor, pts: torch.Tensor) -> torch.Tensor: ... +def npu_roipoint_pool3d_forward( + num_sampled_points: int, points: torch.Tensor, point_features: torch.Tensor, boxes3d: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def group_points( + points: torch.Tensor, idx: torch.Tensor, b: int, c: int, n: int, npoints: int, nsample: int +) -> torch.Tensor: ... +def group_points_backward( + grad_out: torch.Tensor, idx: torch.Tensor, b: int, c: int, n: int, npoints: int, nsample: int +) -> torch.Tensor: ... +def vec_pool_backward( + grad_new_features: torch.Tensor, point_cnt_of_grid: torch.Tensor, grouped_idxs: torch.Tensor, n: int, num_c_in: int +) -> torch.Tensor: ... +def point_to_voxel( + points: torch.Tensor, voxel_sizes: List[float], coor_ranges: List[float], layout: str +) -> torch.Tensor: ... +def voxel_to_point( + voxels: torch.Tensor, voxel_sizes: List[float], coor_ranges: List[float], layout: str +) -> torch.Tensor: ... 
+def unique_voxel(voxels: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def hard_voxelize( + points: torch.Tensor, voxel_sizes: List[float], coor_ranges: List[float], max_points: int, max_voxels: int +) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def npu_bev_pool( + feat: torch.Tensor, + geom_feat: torch.Tensor, + interval_lengths: torch.Tensor, + interval_starts: torch.Tensor, + b: int, + d: int, + h: int, + w: int, +) -> torch.Tensor: ... +def npu_bev_pool_backward( + grad_out: torch.Tensor, + geom_feat: torch.Tensor, + interval_lengths: torch.Tensor, + interval_starts: torch.Tensor, + b: int, + d: int, + h: int, + w: int, +) -> torch.Tensor: ... +def npu_bev_pool_v2( + depth: torch.Tensor, + feat: torch.Tensor, + ranks_depth: torch.Tensor, + ranks_feat: torch.Tensor, + ranks_bev: torch.Tensor, + interval_lengths: torch.Tensor, + interval_starts: torch.Tensor, + b: int, + d: int, + h: int, + w: int, +) -> torch.Tensor: ... +def npu_bev_pool_v2_backward( + grad_out: torch.Tensor, + depth: torch.Tensor, + feat: torch.Tensor, + ranks_depth: torch.Tensor, + ranks_feat: torch.Tensor, + ranks_bev: torch.Tensor, + interval_lengths: torch.Tensor, + interval_starts: torch.Tensor, + b: int, + d: int, + h: int, + w: int, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def furthest_point_sampling_with_dist( + points_dist: torch.Tensor, nearest_temp: torch.Tensor, num_points: int +) -> torch.Tensor: ... +def npu_dynamic_scatter( + feats: torch.Tensor, + coors: torch.Tensor, + prefix_sum_point_per_voxel: torch.Tensor, + argsort_coor: torch.Tensor, + num_voxels: int, + reduce_type: str, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def npu_dynamic_scatter_grad( + grad_point_feats: torch.Tensor, + grad_voxel_feats: torch.Tensor, + prefix_sum_point_per_voxel: torch.Tensor, + argsort_coor: torch.Tensor, + compare_mask: torch.Tensor, + reduce_type: str, +) -> None: ... +def npu_furthest_point_sampling( + point_xyz: torch.Tensor, nearset_temp: torch.Tensor, num_points: int +) -> torch.Tensor: ... +def voxel_pooling_train( + inputFeatures: torch.Tensor, + geom: torch.Tensor, + outputFeatures: torch.Tensor, + posMemo: torch.Tensor, + batchSize: int, + numPoints: int, + numChannels: int, + numVoxelX: int, + numVoxelY: int, + numVoxelZ: int, +) -> Tuple[torch.Tensor, torch.Tensor]: ... +def voxel_pool_train_backward( + grad_out: torch.Tensor, posMemo: torch.Tensor, batchSize: int, numPoints: int, numChannels: int, h: int, w: int +) -> torch.Tensor: ... +def dynamic_voxelization( + points: torch.Tensor, + coors: torch.Tensor, + grid_x: int, + grid_y: int, + grid_z: int, + voxel_x: float, + voxel_y: float, + voxel_z: float, + coors_min_x: float, + coors_min_y: float, + coorsMinZ: float, +) -> torch.Tensor: ... 
+
+__all__ = [
+    "knn",
+    "npu_three_interpolate",
+    "npu_three_interpolate_backward",
+    "scatter_max_with_argmax_v2",
+    "npu_scatter_max_backward",
+    "npu_scatter",
+    "npu_scatter_mean_grad",
+    "npu_scatter_mean",
+    "npu_sort_pairs",
+    "npu_hypot",
+    "npu_hypot_grad",
+    "assign_score_withk",
+    "npu_max_pool2d",
+    "multi_scale_deformable_attn",
+    "multi_scale_deformable_attn_backward",
+    "npu_add_relu",
+    "npu_add_relu_grad",
+    "fused_bias_leaky_relu",
+    "deformable_aggregation",
+    "deformable_aggregation_grad",
+    "deformable_conv2d",
+    "modulated_deformable_conv2d",
+    "deformable_conv2d_backward",
+    "modulated_deformable_conv2d_backward",
+    "npu_subm_sparse_conv3d",
+    "nms3d_normal",
+    "nms3d",
+    "npu_rotated_overlaps",
+    "npu_rotated_iou",
+    "npu_boxes_overlap_bev",
+    "npu_points_in_box",
+    "npu_points_in_box_all",
+    "npu_roipoint_pool3d_forward",
+    "group_points",
+    "group_points_backward",
+    "vec_pool_backward",
+    "point_to_voxel",
+    "voxel_pooling_train",
+    "voxel_pool_train_backward",
+    "dynamic_voxelization",
+    "furthest_point_sampling_with_dist",
+    "npu_dynamic_scatter",
+    "npu_dynamic_scatter_grad",
+    "npu_furthest_point_sampling",
+    "npu_bev_pool_v2",
+    "npu_bev_pool_v2_backward",
+]
diff --git a/mx_driving/fused/__init__.py b/mx_driving/fused/__init__.py
index aee3a03e..87330a8c 100644
--- a/mx_driving/fused/__init__.py
+++ b/mx_driving/fused/__init__.py
@@ -1,7 +1,10 @@
-from .ops.npu_max_pool2d import npu_max_pool2d
-from .ops.npu_add_relu import npu_add_relu
-from .ops.npu_multi_scale_deformable_attn_function import npu_multi_scale_deformable_attn_function
+from .ops.deform_conv2d import DeformConv2dFunction, deform_conv2d
 from .ops.fused_bias_leaky_relu import npu_fused_bias_leaky_relu
+from .ops.modulated_deform_conv2d import (ModulatedDeformConv2dFunction,
+                                          modulated_deform_conv2d)
+from .ops.multi_scale_deformable_attn import (
+    MultiScaleDeformableAttnFunction, multi_scale_deformable_attn,
+    npu_multi_scale_deformable_attn_function)
+from .ops.npu_add_relu import npu_add_relu
 from .ops.npu_deformable_aggregation import npu_deformable_aggregation
-from .ops.deform_conv2d import deform_conv2d, DeformConv2dFunction
-from .ops.modulated_deform_conv2d import modulated_deform_conv2d, ModulatedDeformConv2dFunction
+from .ops.npu_max_pool2d import npu_max_pool2d
diff --git a/mx_driving/fused/ops/csrc/MultiScaleDeformableAttn.cpp b/mx_driving/fused/ops/csrc/MultiScaleDeformableAttn.cpp
new file mode 100644
index 00000000..2e81bdf2
--- /dev/null
+++ b/mx_driving/fused/ops/csrc/MultiScaleDeformableAttn.cpp
@@ -0,0 +1,109 @@
+// Copyright (c) 2024 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
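+
+// Precision note, inferred from the casts in this file: the
+// aclnnMultiScaleDeformableAttn kernels compute in float32, so float16
+// inputs are promoted to float32 before kernel dispatch, and the forward
+// result and backward gradients are cast back to float16 for the caller.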
+
+#include "csrc/OpApiCommon.h"
+#include "functions.h"
+
+namespace {
+constexpr size_t BATCH_IDX = 0;
+constexpr size_t QUERY_IDX = 1;
+constexpr size_t HEAD_IDX = 2;
+constexpr size_t EMBED_IDX = 3;
+constexpr size_t LEVEL_IDX = 3;
+} // namespace
+
+at::Tensor multi_scale_deformable_attn(const at::Tensor& value, const at::Tensor& value_spatial_shapes,
+    const at::Tensor& value_level_start_index, const at::Tensor& sampling_locations,
+    const at::Tensor& attention_weights)
+{
+    TORCH_CHECK(value.scalar_type() == at::kHalf || value.scalar_type() == at::kFloat,
+        "value: float16 or float32 tensor expected but got a tensor with dtype: ", value.scalar_type());
+    TORCH_CHECK(value_spatial_shapes.scalar_type() == at::kInt,
+        "value_spatial_shapes: int32 tensor expected but got a tensor with dtype: ",
+        value_spatial_shapes.scalar_type());
+    TORCH_CHECK(value_level_start_index.scalar_type() == at::kInt,
+        "value_level_start_index: int32 tensor expected but got a tensor with dtype: ",
+        value_level_start_index.scalar_type());
+    TORCH_CHECK(sampling_locations.scalar_type() == at::kHalf || sampling_locations.scalar_type() == at::kFloat,
+        "sampling_locations: float16 or float32 tensor expected but got a tensor with dtype: ",
+        sampling_locations.scalar_type());
+    TORCH_CHECK(attention_weights.scalar_type() == at::kHalf || attention_weights.scalar_type() == at::kFloat,
+        "attention_weights: float16 or float32 tensor expected but got a tensor with dtype: ",
+        attention_weights.scalar_type());
+
+    at::SmallVector<int64_t, 3> output_size = {sampling_locations.size(BATCH_IDX),
+        sampling_locations.size(QUERY_IDX), value.size(HEAD_IDX) * value.size(EMBED_IDX)};
+    at::Tensor output = at::empty(output_size, value.options().dtype(at::kFloat));
+
+    if (ASCEND_UNLIKELY(value.scalar_type() == at::kHalf)) {
+        at::Tensor value_fp32 = value.to(at::kFloat);
+        at::Tensor sampling_locations_fp32 = sampling_locations.to(at::kFloat);
+        at::Tensor attention_weights_fp32 = attention_weights.to(at::kFloat);
+        EXEC_NPU_CMD(aclnnMultiScaleDeformableAttn, value_fp32, value_spatial_shapes, value_level_start_index,
+            sampling_locations_fp32, attention_weights_fp32, output);
+        return output.to(at::kHalf);
+    }
+
+    EXEC_NPU_CMD(aclnnMultiScaleDeformableAttn, value, value_spatial_shapes, value_level_start_index,
+        sampling_locations, attention_weights, output);
+    return output;
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> multi_scale_deformable_attn_backward(const at::Tensor& value,
+    const at::Tensor& value_spatial_shapes, const at::Tensor& value_level_start_index,
+    const at::Tensor& sampling_locations, const at::Tensor& attention_weights, const at::Tensor& grad_output)
+{
+    TORCH_CHECK(value.scalar_type() == at::kHalf || value.scalar_type() == at::kFloat,
+        "value: float16 or float32 tensor expected but got a tensor with dtype: ", value.scalar_type());
+    TORCH_CHECK(value_spatial_shapes.scalar_type() == at::kInt,
+        "value_spatial_shapes: int32 tensor expected but got a tensor with dtype: ",
+        value_spatial_shapes.scalar_type());
+    TORCH_CHECK(value_level_start_index.scalar_type() == at::kInt,
+        "value_level_start_index: int32 tensor expected but got a tensor with dtype: ",
+        value_level_start_index.scalar_type());
+    TORCH_CHECK(sampling_locations.scalar_type() == at::kHalf || sampling_locations.scalar_type() == at::kFloat,
+        "sampling_locations: float16 or float32 tensor expected but got a tensor with dtype: ",
+        sampling_locations.scalar_type());
+    TORCH_CHECK(attention_weights.scalar_type() == at::kHalf || attention_weights.scalar_type() == at::kFloat,
+        "attention_weights: float16 or float32 tensor expected but got a tensor with dtype: ",
+        attention_weights.scalar_type());
+    TORCH_CHECK(grad_output.scalar_type() == at::kHalf || grad_output.scalar_type() == at::kFloat,
+        "grad_output: float16 or float32 tensor expected but got a tensor with dtype: ", grad_output.scalar_type());
+
+    at::Tensor grad_value = at::zeros_like(value, value.options().dtype(at::kFloat));
+    at::Tensor grad_sampling_loc = at::empty_like(sampling_locations, sampling_locations.options().dtype(at::kFloat));
+    at::Tensor grad_attn_weight = at::empty_like(attention_weights, attention_weights.options().dtype(at::kFloat));
+
+    // When the number of spatial shapes does not match the number of levels in
+    // attention_weights, pre-zero the location and weight gradients.
+    if (ASCEND_UNLIKELY(value_spatial_shapes.size(0) != attention_weights.size(LEVEL_IDX))) {
+        grad_sampling_loc.zero_();
+        grad_attn_weight.zero_();
+    }
+
+    if (ASCEND_UNLIKELY(value.scalar_type() == at::kHalf)) {
+        at::Tensor value_fp32 = value.to(at::kFloat);
+        at::Tensor sampling_locations_fp32 = sampling_locations.to(at::kFloat);
+        at::Tensor attention_weights_fp32 = attention_weights.to(at::kFloat);
+        at::Tensor grad_output_fp32 = grad_output.to(at::kFloat);
+        EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnGrad, value_fp32, value_spatial_shapes, value_level_start_index,
+            sampling_locations_fp32, attention_weights_fp32, grad_output_fp32, grad_value, grad_sampling_loc,
+            grad_attn_weight);
+        return std::make_tuple(
+            grad_value.to(at::kHalf), grad_sampling_loc.to(at::kHalf), grad_attn_weight.to(at::kHalf));
+    }
+
+    EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnGrad, value, value_spatial_shapes, value_level_start_index,
+        sampling_locations, attention_weights, grad_output, grad_value, grad_sampling_loc, grad_attn_weight);
+    return std::make_tuple(grad_value, grad_sampling_loc, grad_attn_weight);
+}
diff --git a/mx_driving/fused/ops/csrc/MultiScaleDeformableAttnFunction.cpp b/mx_driving/fused/ops/csrc/MultiScaleDeformableAttnFunction.cpp
deleted file mode 100644
index 9fbfacbc..00000000
--- a/mx_driving/fused/ops/csrc/MultiScaleDeformableAttnFunction.cpp
+++ /dev/null
@@ -1,165 +0,0 @@
-// Copyright (c) 2024 Huawei Technologies Co., Ltd
-// Copyright (c) 2019, Facebook CORPORATION.
-// All rights reserved.
-//
-// Licensed under the BSD 3-Clause License (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://opensource.org/licenses/BSD-3-Clause
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "csrc/OpApiCommon.h" -#include "functions.h" - -constexpr size_t BATCH_SIZE_IDX = 0; -constexpr size_t NUM_QUERIES_IDX = 1; -constexpr size_t NUM_HEADS_IDX = 3; -constexpr size_t NUM_POINTS_IDX = 5; -constexpr size_t NUM_LEVELS_IDX = 4; -constexpr size_t EMBED_DIMS_IDX = 3; - -at::Tensor npu_multi_scale_deformable_attn_function(const at::Tensor& value, const at::Tensor& value_spatial_shapes, - const at::Tensor& value_level_start_index, const at::Tensor& sampling_locations, - const at::Tensor& attention_weights) -{ - TORCH_CHECK(value.scalar_type() == at::kHalf || value.scalar_type() == at::kFloat, - "value: float16 or float32 tensor expected but got a tensor with dtype: ", value.scalar_type()); - TORCH_CHECK(value_spatial_shapes.scalar_type() == at::kInt || value_spatial_shapes.scalar_type() == at::kLong, - "value_spatial_shapes: int32 or int64 tensor expected but got a tensor with dtype: ", - value_spatial_shapes.scalar_type()); - TORCH_CHECK(value_level_start_index.scalar_type() == at::kInt || value_level_start_index.scalar_type() == at::kLong, - "value_level_start_index: int32 or int64 tensor expected but got a tensor with dtype: ", - value_level_start_index.scalar_type()); - TORCH_CHECK(sampling_locations.scalar_type() == at::kHalf || sampling_locations.scalar_type() == at::kFloat, - "sampling_locations: float16 or float32 tensor expected but got a tensor with dtype: ", - sampling_locations.scalar_type()); - TORCH_CHECK(attention_weights.scalar_type() == at::kHalf || attention_weights.scalar_type() == at::kFloat, - "attention_weights: float16 or float32 tensor expected but got a tensor with dtype: ", - attention_weights.scalar_type()); - - auto ori_dtype = value.scalar_type(); - // construct the output tensor of the NPU - auto value_size = value.sizes(); - auto location_size = sampling_locations.sizes(); - auto embed_dims = value_size[3]; - auto output_size = {value_size[0], location_size[1], value_size[2] * embed_dims}; - - at::Tensor result = at::empty(output_size, value.options().dtype(at::kFloat)); - - at::Tensor value_cp = value.to(at::kFloat); - at::Tensor value_spatial_shapes_cp = value_spatial_shapes.to(at::kInt); - at::Tensor value_level_start_index_cp = value_level_start_index.to(at::kInt); - at::Tensor sampling_locations_cp = sampling_locations.to(at::kFloat); - at::Tensor attention_weights_cp = attention_weights.to(at::kFloat); - - EXEC_NPU_CMD(aclnnMultiScaleDeformableAttn, value_cp, value_spatial_shapes_cp, value_level_start_index_cp, - sampling_locations_cp, attention_weights_cp, result); - - return result.to(ori_dtype); -} - -std::tuple multi_scale_deformable_attn_grad_v2(const at::Tensor& value_trans, - const at::Tensor& shape, const at::Tensor& level_start_index, const at::Tensor& location_trans, - const at::Tensor& attn_weight_trans, const at::Tensor& grad_output) -{ - TORCH_CHECK(value_trans.scalar_type() == at::kHalf || value_trans.scalar_type() == at::kFloat, - "value_trans: float16 or float32 tensor expected but got a tensor with dtype: ", value_trans.scalar_type()); - TORCH_CHECK(shape.scalar_type() == at::kInt || shape.scalar_type() == at::kLong, - "spatial_shapes: int32 or int64 tensor expected but got a tensor with dtype: ", shape.scalar_type()); - TORCH_CHECK(level_start_index.scalar_type() == at::kInt || level_start_index.scalar_type() == at::kLong, - "level_start_index: int32 or int64 tensor expected but got a tensor with dtype: ", - level_start_index.scalar_type()); - TORCH_CHECK(location_trans.scalar_type() == at::kHalf || 
location_trans.scalar_type() == at::kFloat, - "sampling_locations: float16 or float32 tensor expected but got a tensor with dtype: ", - location_trans.scalar_type()); - TORCH_CHECK(attn_weight_trans.scalar_type() == at::kHalf || attn_weight_trans.scalar_type() == at::kFloat, - "attn_weight_trans: float16 or float32 tensor expected but got a tensor with dtype: ", - attn_weight_trans.scalar_type()); - TORCH_CHECK(grad_output.scalar_type() == at::kHalf || grad_output.scalar_type() == at::kFloat, - "grad_output: float16 or float32 tensor expected but got a tensor with dtype: ", grad_output.scalar_type()); - - auto ori_dtype = value_trans.scalar_type(); - auto value_trans_size = value_trans.sizes(); - auto location_trans_size = location_trans.sizes(); - auto attn_weight_trans_size = attn_weight_trans.sizes(); - auto num_heads = value_trans_size[1]; - auto embed_dims = value_trans_size[3]; - auto num_points = location_trans_size[3]; - auto num_levels = location_trans_size[2]; - auto data_total = embed_dims + num_points + num_levels; - TORCH_CHECK(data_total < 512, "data_total is over 512: embed_dims ", embed_dims, " num_points is ", num_points, - " num_level is ", num_levels, "."); - TORCH_CHECK(embed_dims % 8 == 0, "embed_dims must be a multiple of 8, but embed_dims is ", embed_dims, "."); - - at::Tensor grad_value_trans = at::zeros(value_trans_size, value_trans.options().dtype(at::kFloat)); - at::Tensor grad_location_trans = at::zeros(location_trans_size, location_trans.options().dtype(at::kFloat)); - at::Tensor grad_attn_weight_trans = - at::zeros(attn_weight_trans_size, attn_weight_trans.options().dtype(at::kFloat)); - - at::Tensor value_trans_fp = value_trans.to(at::kFloat); - at::Tensor shape_fp = shape.to(at::kInt); - at::Tensor level_start_index_fp = level_start_index.to(at::kInt); - at::Tensor sampling_locations_fp = location_trans.to(at::kFloat); - at::Tensor attn_weight_fp = attn_weight_trans.to(at::kFloat); - at::Tensor grad_output_fp = grad_output.to(at::kFloat); - EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnGradV2, value_trans_fp, shape_fp, level_start_index_fp, - sampling_locations_fp, attn_weight_fp, grad_output_fp, grad_value_trans, grad_location_trans, - grad_attn_weight_trans); - return std::make_tuple( - grad_value_trans.to(ori_dtype), grad_location_trans.to(ori_dtype), grad_attn_weight_trans.to(ori_dtype)); -} - -std::tuple multi_scale_deformable_attn_grad(const at::Tensor& value, - const at::Tensor& shape, const at::Tensor& level_start_index, const at::Tensor& location, - const at::Tensor& attn_weight, const at::Tensor& grad_output) -{ - TORCH_CHECK(value.scalar_type() == at::kHalf || value.scalar_type() == at::kFloat, - "value: float16 or float32 tensor expected but got a tensor with dtype: ", value.scalar_type()); - TORCH_CHECK(shape.scalar_type() == at::kInt || shape.scalar_type() == at::kLong, - "spatial_shapes: int32 or int64 tensor expected but got a tensor with dtype: ", shape.scalar_type()); - TORCH_CHECK(level_start_index.scalar_type() == at::kInt || level_start_index.scalar_type() == at::kLong, - "level_start_index: int32 or int64 tensor expected but got a tensor with dtype: ", - level_start_index.scalar_type()); - TORCH_CHECK(location.scalar_type() == at::kHalf || location.scalar_type() == at::kFloat, - "sampling_locations: float16 or float32 tensor expected but got a tensor with dtype: ", location.scalar_type()); - TORCH_CHECK(attn_weight.scalar_type() == at::kHalf || attn_weight.scalar_type() == at::kFloat, - "attn_weight_trans: float16 or float32 tensor expected 
but got a tensor with dtype: ", - attn_weight.scalar_type()); - TORCH_CHECK(grad_output.scalar_type() == at::kHalf || grad_output.scalar_type() == at::kFloat, - "grad_output: float16 or float32 tensor expected but got a tensor with dtype: ", grad_output.scalar_type()); - - auto ori_dtype = value.scalar_type(); - auto value_size = value.sizes(); - auto location_size = location.sizes(); - auto shape_size = shape.sizes(); - auto attn_size = attn_weight.sizes(); - auto embed_dims = value_size[EMBED_DIMS_IDX]; - - TORCH_CHECK(embed_dims % 8 == 0, "embed_dims must be a multiple of 8, but embed_dims is ", embed_dims, "."); - - at::Tensor grad_value = at::zeros(value_size, value.options().dtype(at::kFloat)); - at::Tensor grad_location, grad_attn_weight; - if (shape_size[0] != attn_size[4]) { - grad_location = at::zeros(location_size, location.options().dtype(at::kFloat)); - grad_attn_weight = at::zeros(attn_size, attn_weight.options().dtype(at::kFloat)); - } else { - grad_location = at::empty(location_size, location.options().dtype(at::kFloat)); - grad_attn_weight = at::empty(attn_size, attn_weight.options().dtype(at::kFloat)); - } - - at::Tensor value_fp = value.to(at::kFloat); - at::Tensor shape_fp = shape.to(at::kInt); - at::Tensor level_start_index_fp = level_start_index.to(at::kInt); - at::Tensor sampling_locations_fp = location.to(at::kFloat); - at::Tensor attn_weight_fp = attn_weight.to(at::kFloat); - at::Tensor grad_output_fp = grad_output.to(at::kFloat); - EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnGrad, value_fp, shape_fp, level_start_index_fp, sampling_locations_fp, - attn_weight_fp, grad_output_fp, grad_value, grad_location, grad_attn_weight); - return std::make_tuple(grad_value.to(ori_dtype), grad_location.to(ori_dtype), grad_attn_weight.to(ori_dtype)); -} diff --git a/mx_driving/fused/ops/csrc/functions.h b/mx_driving/fused/ops/csrc/functions.h index 0d3fa24a..7f52b6aa 100644 --- a/mx_driving/fused/ops/csrc/functions.h +++ b/mx_driving/fused/ops/csrc/functions.h @@ -19,13 +19,13 @@ at::Tensor npu_max_pool2d(const at::Tensor& x, int kernel_size, int stride, int padding); -at::Tensor npu_multi_scale_deformable_attn_function(const at::Tensor& value, const at::Tensor& value_spatial_shapes, +at::Tensor multi_scale_deformable_attn(const at::Tensor& value, const at::Tensor& value_spatial_shapes, const at::Tensor& value_level_start_index, const at::Tensor& sampling_locations, const at::Tensor& attention_weights); -std::tuple multi_scale_deformable_attn_grad(const at::Tensor& value, - const at::Tensor& shape, const at::Tensor& level_start_index, const at::Tensor& location_trans, - const at::Tensor& attn_weight_trans, const at::Tensor& grad_output); +std::tuple multi_scale_deformable_attn_backward(const at::Tensor& value, + const at::Tensor& value_spatial_shapes, const at::Tensor& value_level_start_index, + const at::Tensor& sampling_locations, const at::Tensor& attention_weights, const at::Tensor& grad_output); std::tuple multi_scale_deformable_attn_grad_v2(const at::Tensor& value, const at::Tensor& shape, const at::Tensor& level_start_index, const at::Tensor& location_trans, diff --git a/mx_driving/fused/ops/csrc/pybind.cpp b/mx_driving/fused/ops/csrc/pybind.cpp index d388cbfb..e082ab32 100644 --- a/mx_driving/fused/ops/csrc/pybind.cpp +++ b/mx_driving/fused/ops/csrc/pybind.cpp @@ -7,10 +7,9 @@ void init_fused(pybind11::module& m) { // nnpu_max_pool2d m.def("npu_max_pool2d", &npu_max_pool2d); - // npu_multi_scale_deformable_attn_function - m.def("npu_multi_scale_deformable_attn_function", 
diff --git a/mx_driving/fused/ops/csrc/pybind.cpp b/mx_driving/fused/ops/csrc/pybind.cpp
index d388cbfb..e082ab32 100644
--- a/mx_driving/fused/ops/csrc/pybind.cpp
+++ b/mx_driving/fused/ops/csrc/pybind.cpp
@@ -7,10 +7,9 @@ void init_fused(pybind11::module& m)
 {
     // nnpu_max_pool2d
     m.def("npu_max_pool2d", &npu_max_pool2d);
-    // npu_multi_scale_deformable_attn_function
-    m.def("npu_multi_scale_deformable_attn_function", &npu_multi_scale_deformable_attn_function);
-    m.def("multi_scale_deformable_attn_grad", &multi_scale_deformable_attn_grad);
-    m.def("multi_scale_deformable_attn_grad_v2", &multi_scale_deformable_attn_grad_v2);
+    // multi_scale_deformable_attn
+    m.def("multi_scale_deformable_attn", &multi_scale_deformable_attn);
+    m.def("multi_scale_deformable_attn_backward", &multi_scale_deformable_attn_backward);
 
     // npu_add_relu
     m.def("npu_add_relu", &npu_add_relu);
diff --git a/mx_driving/fused/ops/multi_scale_deformable_attn.py b/mx_driving/fused/ops/multi_scale_deformable_attn.py
new file mode 100644
index 00000000..d2cdda81
--- /dev/null
+++ b/mx_driving/fused/ops/multi_scale_deformable_attn.py
@@ -0,0 +1,61 @@
+"""
+Copyright (c) OpenMMLab. All rights reserved.
+Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+Modification by: Huawei Developers
+Modification date: 2024-06-04
+Modification Description:
+Modification 1. Add support for Ascend NPU
+"""
+
+import warnings
+
+import torch
+from torch.autograd.function import Function, once_differentiable
+
+import mx_driving._C
+
+
+class MultiScaleDeformableAttnFunction(Function):
+    @staticmethod
+    # pylint: disable=too-many-arguments,huawei-too-many-arguments
+    def forward(
+        ctx,
+        value: torch.Tensor,
+        value_spatial_shapes: torch.Tensor,
+        value_level_start_index: torch.Tensor,
+        sampling_locations: torch.Tensor,
+        attention_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        value_spatial_shapes = value_spatial_shapes.int()
+        value_level_start_index = value_level_start_index.int()
+        sampling_locations = sampling_locations.type_as(value)
+        attention_weights = attention_weights.type_as(value)
+
+        output = mx_driving._C.multi_scale_deformable_attn(
+            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
+        )
+        ctx.save_for_backward(
+            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
+        )
+        return output
+
+    @staticmethod
+    @once_differentiable
+    # pylint: disable=too-many-return-values
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
+        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+        grad_value, grad_sampling_loc, grad_attn_weight = mx_driving._C.multi_scale_deformable_attn_backward(
+            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output
+        )
+        return grad_value, None, None, grad_sampling_loc, grad_attn_weight
+
+
+multi_scale_deformable_attn = MultiScaleDeformableAttnFunction.apply
+
+
+def npu_multi_scale_deformable_attn_function(value, shape, offset, locations, weight):
+    warnings.warn(
+        "`npu_multi_scale_deformable_attn_function` will be deprecated in a future version. Please use `multi_scale_deformable_attn` instead.",
+        DeprecationWarning,
+    )
+    return MultiScaleDeformableAttnFunction.apply(value, shape, offset, locations, weight)
diff --git a/mx_driving/fused/ops/npu_multi_scale_deformable_attn_function.py b/mx_driving/fused/ops/npu_multi_scale_deformable_attn_function.py
deleted file mode 100644
index ffdb66c5..00000000
--- a/mx_driving/fused/ops/npu_multi_scale_deformable_attn_function.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""
-Copyright (c) OpenMMLab. All rights reserved.
-Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
-Modification by: Huawei Developers
-Modification date: 2024-06-04
-Modification Description:
-Modification 1. Add support for Ascend NPU
-"""
-
-import torch
-from torch.autograd import Function
-from torch.nn import Module
-
-import torch_npu
-import mx_driving._C
-
-
-class MultiScaleDeformableAttnFunction(Function):
-    @staticmethod
-    # 'pylint: disable=too-many-arguments,huawei-too-many-arguments
-    def forward(ctx, value, shape, offset, locations, weight):
-        result = mx_driving._C.npu_multi_scale_deformable_attn_function(value, shape, offset, locations, weight)
-        ctx.save_for_backward(value, shape, offset, locations, weight)
-        return result
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        value, shape, offset, locations, weight = ctx.saved_tensors
-        grad_value, grad_locations, grad_weight = mx_driving._C.multi_scale_deformable_attn_grad(
-            value, shape, offset, locations, weight, grad_output
-        )
-        return grad_value, None, None, grad_locations, grad_weight
-
-
-npu_multi_scale_deformable_attn_function = MultiScaleDeformableAttnFunction.apply
diff --git a/mx_driving/fused/ops/onnx/wrapper_onnx_ops.py b/mx_driving/fused/ops/onnx/wrapper_onnx_ops.py
index 04d8901e..12b6baa6 100644
--- a/mx_driving/fused/ops/onnx/wrapper_onnx_ops.py
+++ b/mx_driving/fused/ops/onnx/wrapper_onnx_ops.py
@@ -8,7 +8,7 @@ import mx_driving.fused
 class NPUMultiScaleDeformableAttnOP(torch.autograd.Function):
     @staticmethod
     def forward(ctx, *args, **kwargs):
-        return mx_driving.fused.npu_multi_scale_deformable_attn_function(*args, **kwargs)
+        return mx_driving.fused.multi_scale_deformable_attn(*args, **kwargs)
 
     @staticmethod
     # 'pylint: disable=too-many-arguments,huawei-too-many-arguments
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..8ea7351a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,26 @@
+[build-system]
+requires = [
+    "setuptools",
+    "wheel",
+    "numpy",
+    "pyyaml",
+    "cmake",
+    "typing-extensions>=4.10.0",
+]
+
+[tool.black]
+line-length = 120
+target-version = ["py38"]
+
+[tool.isort]
+src_paths = ["mx_driving"]
+extra_standard_library = ["typing_extensions"]
+skip_gitignore = true
+atomic = true
+profile = "black"
+indent = 4
+line_length = 120
+lines_after_imports = 2
+multi_line_output = 3
+include_trailing_comma = true
+combine_as_imports = true
diff --git a/tests/torch/test_multi_scale_deformable_attn_function.py b/tests/torch/test_multi_scale_deformable_attn_function.py
index 3c945da6..e0d622c3 100644
--- a/tests/torch/test_multi_scale_deformable_attn_function.py
+++ b/tests/torch/test_multi_scale_deformable_attn_function.py
@@ -138,7 +138,7 @@ class TestMultiScaleDeformableAttnFunction(TestCase):
         npu_sampling_locations = npu_inputs.sampling_locations
         npu_attention_weights = npu_inputs.attention_weights
         npu_grad_output = npu_inputs.grad_output
-        npu_output = mx_driving.fused.npu_multi_scale_deformable_attn_function(
+        npu_output = mx_driving.fused.multi_scale_deformable_attn(
             npu_value, npu_shapes, npu_offset, npu_sampling_locations, npu_attention_weights
         )
         npu_output.backward(npu_grad_output)
-- 
Gitee
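To sanity-check the renamed Python entry points end to end, a minimal smoke test might look like the following. It assumes an Ascend NPU environment with `torch_npu` available, and it assumes `mx_driving/fused/__init__.py` re-exports both `multi_scale_deformable_attn` and the deprecated `npu_multi_scale_deformable_attn_function` (the diffstat touches that file, but its hunk is not shown here). Shapes follow the documented constraints: `embed_dims` a multiple of 8, `num_points` even, locations in `[0, 1]`.

```python
import warnings

import torch
import torch_npu  # noqa: F401  # registers the "npu" device
from mx_driving.fused import multi_scale_deformable_attn, npu_multi_scale_deformable_attn_function

bs, num_heads, embed_dims, num_queries, num_levels, num_points = 1, 4, 32, 16, 1, 8
shapes = torch.tensor([[50, 50]], dtype=torch.long)
level_start_index = torch.zeros(num_levels, dtype=torch.long)
num_keys = int(shapes.prod(dim=1).sum())

value = torch.rand(bs, num_keys, num_heads, embed_dims).npu()
value.requires_grad_(True)
locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2).npu()
weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points).npu()

# New API: forward plus backward through the renamed bindings.
out = multi_scale_deformable_attn(value, shapes.npu(), level_start_index.npu(), locations, weights)
out.sum().backward()  # exercises multi_scale_deformable_attn_backward
assert out.shape == (bs, num_queries, num_heads * embed_dims)
assert value.grad is not None

# Old API: still callable, but should now emit a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    npu_multi_scale_deformable_attn_function(value, shapes.npu(), level_start_index.npu(), locations, weights)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

The same inputs can also be fed to a pure-PyTorch reference such as the `msda_reference` sketch earlier in this patch discussion to confirm numerical agreement within float tolerance.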