diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 3d1e2cf3adbcb2559e838f853a56ee13a79baa13..f3d91cc8e28a8c5120acffb91c92474a34dcb89b 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -2003,6 +2003,8 @@ custom:
     bscpp_op: True
   - func: npu_rotary_mul(Tensor self, Tensor r1, Tensor r2) -> Tensor
   - func: npu_rotary_mul_backward(Tensor grad, Tensor self, Tensor r1, Tensor r2) -> (Tensor, Tensor, Tensor)
+  - func: npu_prompt_flash_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? actual_seq_lengths=None, float scale=1., int pre_tockens=2147483647) -> Tensor
+  - func: npu_incre_flash_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? actual_seq_lengths=None, float scale=1.) -> Tensor
 
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp b/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp
index dc5cf179f20bca0866377d3695ab93a057818f37..403b00ba1bab990f59245ace40452ac77fa3702e 100644
--- a/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp
+++ b/torch_npu/csrc/aten/ops/op_api/FlashAttentionKernelNpuOpApi.cpp
@@ -402,5 +402,39 @@ std::vector<at::Tensor> NPUNativeFunctions::npu_flash_attention(
   return NPUFlashAttentionFunction::apply(query, key, value, head_num, input_layout, pse, padding_mask,
       atten_mask, scale, keep_prob, pre_tockens, next_tockens, gen_mask_parallel, sync);
 }
+
+at::Tensor NPUNativeFunctions::npu_prompt_flash_attention(
+    const at::Tensor &query, const at::Tensor &key,
+    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
+    const c10::optional<at::Tensor> &padding_mask,
+    const c10::optional<at::Tensor> &atten_mask,
+    const c10::optional<at::Tensor> &actual_seq_lengths,
+    double scale, int64_t pre_tockens)
+{
+  auto output = OpPreparation::ApplyTensor(query);
+  std::string input_layout_str = std::string(input_layout);
+  char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
+  EXEC_NPU_NO_FORMAT_CHECK_CMD(
+      aclnnPromptFlashAttention, query, key, value, padding_mask, atten_mask,
+      actual_seq_lengths, head_num, input_layout_ptr, scale, pre_tockens, output);
+  return output;  // TODO: confirm this part
+}
+
+at::Tensor NPUNativeFunctions::npu_incre_flash_attention(
+    const at::Tensor &query, const at::Tensor &key,
+    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
+    const c10::optional<at::Tensor> &padding_mask,
+    const c10::optional<at::Tensor> &atten_mask,
+    const c10::optional<at::Tensor> &actual_seq_lengths,
+    double scale)
+{
+  auto output = OpPreparation::ApplyTensor(query);
+  std::string input_layout_str = std::string(input_layout);
+  char *input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
+  EXEC_NPU_NO_FORMAT_CHECK_CMD(
+      aclnnIncreFlashAttention, query, key, value, padding_mask, atten_mask,
+      actual_seq_lengths, head_num, input_layout_ptr, scale, output);
+  return output;  // TODO: confirm this part
+}
 } // namespace native
 } // namespace at_npu
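
The patch registers the two ops in the `custom:` section of `npu_native_functions.yaml` with the schemas shown above. Below is a minimal usage sketch, not part of the patch: it assumes the ops become callable as `torch_npu.npu_prompt_flash_attention` / `torch_npu.npu_incre_flash_attention` once registered, that an Ascend NPU device is available, and that "BSH" is an accepted `input_layout` value; tensor shapes and dtypes here are illustrative only.

```python
# Hypothetical usage sketch for the two ops added by this patch.
import math
import torch
import torch_npu

batch, seq_len, head_num, head_dim = 2, 128, 8, 64
hidden = head_num * head_dim

# Assumed "BSH" layout: (batch, sequence, hidden).
q = torch.randn(batch, seq_len, hidden, dtype=torch.half).npu()
k = torch.randn(batch, seq_len, hidden, dtype=torch.half).npu()
v = torch.randn(batch, seq_len, hidden, dtype=torch.half).npu()

# Prompt (prefill) phase: the full query sequence attends to key/value.
out = torch_npu.npu_prompt_flash_attention(
    q, k, v, head_num, "BSH", scale=1.0 / math.sqrt(head_dim))

# Incremental (decode) phase: a single new query token attends to the cached key/value.
q_step = torch.randn(batch, 1, hidden, dtype=torch.half).npu()
out_step = torch_npu.npu_incre_flash_attention(
    q_step, k, v, head_num, "BSH", scale=1.0 / math.sqrt(head_dim))
```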