From 2ea56c163328a8343617e828a8648cc091f6dc89 Mon Sep 17 00:00:00 2001
From: guopeian
Date: Wed, 13 Nov 2024 14:52:02 +0800
Subject: [PATCH] Add NpuFlashAttentionScore and NpuIncreFlashAttention custom ops

---
 .../aicore/flash_attention_score_ops.cc      | 31 +++++++
 .../aicore/incre_flash_attention_ops.cc      | 32 +++++++
 tf_adapter/ops/aicore/npu_aicore_ops.cc      | 85 +++++++++++++++++++
 .../python/npu_bridge/tbe/npu_cube_ops.py    | 38 +++++++++
 .../tensorflow/npu_supported_ops.json        |  8 ++
 5 files changed, 194 insertions(+)
 create mode 100644 tf_adapter/kernels/aicore/flash_attention_score_ops.cc
 create mode 100644 tf_adapter/kernels/aicore/incre_flash_attention_ops.cc

diff --git a/tf_adapter/kernels/aicore/flash_attention_score_ops.cc b/tf_adapter/kernels/aicore/flash_attention_score_ops.cc
new file mode 100644
index 000000000..bf294d9ae
--- /dev/null
+++ b/tf_adapter/kernels/aicore/flash_attention_score_ops.cc
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tf_adapter/common/adapter_logger.h"
+
+namespace tensorflow {
+class FlashAttentionScoreOp : public OpKernel {
+ public:
+  explicit FlashAttentionScoreOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~FlashAttentionScoreOp() override = default;
+  void Compute(OpKernelContext *context) override {
+    ADP_LOG(INFO) << "FlashAttentionScoreOp Compute, num_inputs: " << context->num_inputs();
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("NpuFlashAttentionScore").Device(DEVICE_CPU), FlashAttentionScoreOp);
+}  // namespace tensorflow
diff --git a/tf_adapter/kernels/aicore/incre_flash_attention_ops.cc b/tf_adapter/kernels/aicore/incre_flash_attention_ops.cc
new file mode 100644
index 000000000..93a8e1ca8
--- /dev/null
+++ b/tf_adapter/kernels/aicore/incre_flash_attention_ops.cc
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tf_adapter/common/adapter_logger.h"
+
+namespace tensorflow {
+class IncreFlashAttentionOp : public OpKernel {
+ public:
+  explicit IncreFlashAttentionOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~IncreFlashAttentionOp() override = default;
+  void Compute(OpKernelContext *context) override {
+    ADP_LOG(INFO) << "IncreFlashAttentionOp Compute, num_inputs: " << context->num_inputs();
+  }
+  bool IsExpensive() override { return false; }
+};
+
+REGISTER_KERNEL_BUILDER(Name("NpuIncreFlashAttention").Device(DEVICE_CPU), IncreFlashAttentionOp);
+}  // namespace tensorflow
diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc
index fbe6035bd..d3f252701 100644
--- a/tf_adapter/ops/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc
@@ -265,6 +265,91 @@
       c->set_output(6, output_dw_att_shape);
       return Status::OK();
     });
+REGISTER_OP("NpuFlashAttentionScore")
+    .Input("query: T")
+    .Input("key: T")
+    .Input("value: T")
+    .Input("real_shift: real_shift_type")
+    .Input("drop_mask: drop_mask_type")
+    .Input("padding_mask: padding_mask_type")
+    .Input("atten_mask: atten_mask_type")
+    .Input("prefix: prefix_type")
+    .Input("actual_seq_qlen: actual_seq_qlen_type")
+    .Input("actual_seq_kvlen: actual_seq_kvlen_type")
+    .Input("q_start_idx: q_start_idx_type")
+    .Input("kv_start_idx: kv_start_idx_type")
+    .Output("softmax_max: float32")
+    .Output("softmax_sum: float32")
+    .Output("softmax_out: T")
+    .Output("attention_out: T")
+    .Attr("scale_value: float = 1.0")
+    .Attr("keep_prob: float = 1.0")
+    .Attr("pre_tockens: int = 2147483647")
+    .Attr("next_tockens: int = 2147483647")
+    .Attr("head_num: int")
+    .Attr("input_layout: string")
+    .Attr("inner_precise: int = 0")
+    .Attr("sparse_mode: int = 0")
+    .Attr("pse_type: int = 1")
+    .Attr("T: {float16, float32, bfloat16} = DT_FLOAT")
+    .Attr("real_shift_type: list({float16, float32, bfloat16}) >= 0")
+    .Attr("drop_mask_type: list({uint8}) >= 0")
+    .Attr("padding_mask_type: list({float16, float32, bfloat16}) >= 0")
+    .Attr("atten_mask_type: list({bool, uint8}) >= 0")
+    .Attr("prefix_type: list({int64}) >= 0")
+    .Attr("actual_seq_qlen_type: list({int64}) >= 0")
+    .Attr("actual_seq_kvlen_type: list({int64}) >= 0")
+    .Attr("q_start_idx_type: list({int64}) >= 0")
+    .Attr("kv_start_idx_type: list({int64}) >= 0")
+    .SetShapeFn([](InferenceContext *c) {
+      return Status::OK();
+    });
+
+REGISTER_OP("NpuIncreFlashAttention")
+    .Input("query: T")
+    .Input("key: N * T")
+    .Input("value: N * T")
+    .Input("pse_shift: pse_shift_type")
+    .Input("atten_mask: atten_mask_type")
+    .Input("actual_seq_lengths: actual_seq_lengths_type")
+    .Input("dequant_scale1: dequant_scale1_type")
+    .Input("quant_scale1: quant_scale1_type")
+    .Input("dequant_scale2: dequant_scale2_type")
+    .Input("quant_scale2: quant_scale2_type")
+    .Input("quant_offset2: quant_offset2_type")
+    .Input("antiquant_scale: antiquant_scale_type")
+    .Input("antiquant_offset: antiquant_offset_type")
+    .Input("block_table: block_table_type")
+    .Input("kv_padding_size: kv_padding_size_type")
+    .Output("attention_out: T")
+    .Attr("num_heads: int")
+    .Attr("scale_value: float = 1.0")
+    .Attr("input_layout: string = 'BSH'")
+    .Attr("num_key_value_heads: int = 1")
+    .Attr("block_size: int = 0")
+    .Attr("inner_precise: int = 1")
+    .Attr("T: {float16, int8, bfloat16} = DT_INT8")
+    .Attr("M: {float16, bfloat16, float32} = DT_FLOAT")
+    .Attr("S: {bool, int8, uint8} = DT_UINT8")
+    .Attr("R: {uint64, float32} = DT_FLOAT")
+    .Attr("H: {float32, bfloat16} = DT_FLOAT")
+    .Attr("N: int >= 0")
+    .Attr("pse_shift_type: list(type) >= 0")
+    .Attr("atten_mask_type: list({bool, int8, uint8}) >= 0")
+    .Attr("actual_seq_lengths_type: list({int64}) >= 0")
+    .Attr("dequant_scale1_type: list({uint64, float32}) >= 0")
+    .Attr("quant_scale1_type: list({float32}) >= 0")
+    .Attr("dequant_scale2_type: list({uint64, float32}) >= 0")
+    .Attr("quant_scale2_type: list({float32, bfloat16}) >= 0")
+    .Attr("quant_offset2_type: list({float32, bfloat16}) >= 0")
+    .Attr("antiquant_scale_type: list(type) >= 0")
+    .Attr("antiquant_offset_type: list(type) >= 0")
+    .Attr("block_table_type: list({int32}) >= 0")
+    .Attr("kv_padding_size_type: list({int64}) >= 0")
+    .SetShapeFn([](InferenceContext *c) {
+      return Status::OK();
+    });
+
 REGISTER_OP("DynamicRnn")
     .Input("x: T")
     .Input("w: T")
diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py
index 1f381db5d..6cf68100a 100644
--- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py
+++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py
@@ -20,6 +20,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import nn_ops
 from npu_bridge.helper import helper
+import tensorflow as tf
 
 gen_npu_ops = helper.get_gen_ops()
 
@@ -117,6 +118,43 @@ def deformable_conv2d(  # pylint: disable=redefined-builtin
     return op_res
 
 
+def create_optional_input_list(input_tensor):
+    input_list = []
+    if input_tensor is not None:
+        input_list.append(input_tensor)
+    return input_list
+
+
+def flash_attention_score(query, key, value, head_num, input_layout, real_shift=None, drop_mask=None, padding_mask=None,
+                          atten_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None,
+                          q_start_idx=None, kv_start_idx=None, scale_value=1.0, keep_prob=1.0,
+                          pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, sparse_mode=0,
+                          pse_type=1):
+    output = gen_npu_ops.npu_flash_attention_score(query=query, key=key, value=value,
+        real_shift=create_optional_input_list(real_shift), drop_mask=create_optional_input_list(drop_mask),
+        padding_mask=create_optional_input_list(padding_mask), atten_mask=create_optional_input_list(atten_mask),
+        prefix=create_optional_input_list(prefix), actual_seq_qlen=create_optional_input_list(actual_seq_qlen),
+        actual_seq_kvlen=create_optional_input_list(actual_seq_kvlen), q_start_idx=create_optional_input_list(q_start_idx),
+        kv_start_idx=create_optional_input_list(kv_start_idx), scale_value=scale_value, keep_prob=keep_prob,
+        pre_tockens=pre_tockens, next_tockens=next_tockens, head_num=head_num, input_layout=input_layout,
+        inner_precise=inner_precise, sparse_mode=sparse_mode, pse_type=pse_type)
+    return output
+
+def incre_flash_attention(query, key, value, num_heads, pse_shift=None, atten_mask=None, actual_seq_lengths=None,
+                          dequant_scale1=None, quant_scale1=None, dequant_scale2=None, quant_scale2=None,
+                          quant_offset2=None, antiquant_scale=None, antiquant_offset=None, block_table=None,
+                          kv_padding_size=None, scale_value=1.0, input_layout='BSH', num_key_value_heads=1,
+                          block_size=0, inner_precise=1):
+    output = gen_npu_ops.npu_incre_flash_attention(query=query, key=key, value=value,
+        pse_shift=create_optional_input_list(pse_shift), atten_mask=create_optional_input_list(atten_mask),
+        actual_seq_lengths=create_optional_input_list(actual_seq_lengths), dequant_scale1=create_optional_input_list(dequant_scale1),
+        quant_scale1=create_optional_input_list(quant_scale1), dequant_scale2=create_optional_input_list(dequant_scale2),
+        quant_scale2=create_optional_input_list(quant_scale2), quant_offset2=create_optional_input_list(quant_offset2),
+        antiquant_scale=create_optional_input_list(antiquant_scale), antiquant_offset=create_optional_input_list(antiquant_offset),
+        block_table=create_optional_input_list(block_table), kv_padding_size=create_optional_input_list(kv_padding_size),
+        num_heads=num_heads, scale_value=scale_value, input_layout=input_layout, num_key_value_heads=num_key_value_heads,
+        block_size=block_size, inner_precise=inner_precise)
+    return output
 
 @ops.RegisterGradient("DeformableOffsets")
 def deformable_offsets_grad(op, grad):
diff --git a/tf_adapter/tests/depends/support_json/framework/built-in/tensorflow/npu_supported_ops.json b/tf_adapter/tests/depends/support_json/framework/built-in/tensorflow/npu_supported_ops.json
index 969388131..1256f519e 100644
--- a/tf_adapter/tests/depends/support_json/framework/built-in/tensorflow/npu_supported_ops.json
+++ b/tf_adapter/tests/depends/support_json/framework/built-in/tensorflow/npu_supported_ops.json
@@ -23,6 +23,14 @@
     "isGray": false,
     "isHeavy": false
   },
+  "FlashAttentionScore": {
+    "isGray": false,
+    "isHeavy": false
+  },
+  "IncreFlashAttention": {
+    "isGray": false,
+    "isHeavy": false
+  },
   "Enter": {
     "isGray": false,
     "isHeavy": false
-- 
Gitee
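
A minimal usage sketch for the new Python wrappers follows. It is illustrative only: it assumes an npu_bridge build that already contains this patch, the import path is inferred from the file location (tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py), and the tensor shapes, head counts and scale values are made-up placeholders for a "BSH" layout. Optional arguments left as None are wrapped into empty lists by create_optional_input_list, which is how the "list(...) >= 0" attrs in the op definitions model absent inputs, and key/value for incre_flash_attention are passed as Python lists because the op declares them as "key: N * T" and "value: N * T".

# Usage sketch only: shapes, dtypes and attribute values below are
# illustrative assumptions, not values taken from this patch.
import tensorflow as tf
from npu_bridge.tbe import npu_cube_ops

# Assumed "BSH" layout: [batch, seq_len, head_num * head_dim], float16.
query = tf.compat.v1.placeholder(tf.float16, shape=[1, 2048, 1024], name="query")
key = tf.compat.v1.placeholder(tf.float16, shape=[1, 2048, 1024], name="key")
value = tf.compat.v1.placeholder(tf.float16, shape=[1, 2048, 1024], name="value")

# Prefill-style attention: optional inputs stay None and are forwarded as
# empty lists, matching the "list(...) >= 0" attrs of NpuFlashAttentionScore.
fa_out = npu_cube_ops.flash_attention_score(
    query, key, value, head_num=8, input_layout="BSH",
    scale_value=0.08838, keep_prob=1.0)
# Per the op registration, fa_out carries softmax_max, softmax_sum,
# softmax_out and attention_out.

# Decode-step attention: key/value are lists because the op declares
# "key: N * T" and "value: N * T" with N >= 0.
q_step = tf.compat.v1.placeholder(tf.float16, shape=[1, 1, 1024], name="q_step")
ifa_out = npu_cube_ops.incre_flash_attention(
    q_step, [key], [value], num_heads=8, input_layout="BSH",
    num_key_value_heads=8, scale_value=0.08838)

Note that both SetShapeFn callbacks in this patch return Status::OK() without calling c->set_output(), so graph construction registers no output shape information for either op; callers that rely on static shapes downstream may want the shape functions filled in before this is merged.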