From c894d481147ead3d26fa1c85bf3ccc94c5b6de14 Mon Sep 17 00:00:00 2001
From: "@ding-jing12" <@ding-jing12>
Date: Wed, 27 Nov 2024 15:51:05 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0precisonCompare=E5=AF=B9?=
 =?UTF-8?q?=E5=A4=96=E6=8E=A5=E5=8F=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tf_adapter/kernels/aicpu/npu_cpu_ops.cc             |  8 ++++++++
 tf_adapter/ops/aicpu/npu_cpu_ops.cc                 |  8 ++++++++
 tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py | 11 +++++++++++
 3 files changed, 27 insertions(+)

diff --git a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
index ff9dc6d2f..39b2a3218 100644
--- a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
@@ -221,6 +221,13 @@ class ScatterElementsV2Op : public OpKernel {
   }
 };
 
+class PrecisionCompareOp : public OpKernel {
+ public:
+  explicit PrecisionCompareOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~PrecisionCompareOp() override {}
+  void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "PrecisionCompareOp Compute"; }
+};
+
 REGISTER_KERNEL_BUILDER(Name("ScatterElementsV2").Device(DEVICE_CPU), ScatterElementsV2Op);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel);
@@ -242,6 +249,7 @@ REGISTER_KERNEL_BUILDER(Name("NonZeroWithValueShape").Device(DEVICE_CPU), NonZer
 REGISTER_KERNEL_BUILDER(Name("NonZeroWithValueShapeV2").Device(DEVICE_CPU), NonZeroWithValueShapeV2Op);
 REGISTER_KERNEL_BUILDER(Name("WarpAffineV2").Device(DEVICE_CPU), WarpAffineV2Op);
 REGISTER_KERNEL_BUILDER(Name("ResizeV2").Device(DEVICE_CPU), ResizeV2Op);
+REGISTER_KERNEL_BUILDER(Name("PrecisionCompare").Device(DEVICE_CPU), PrecisionCompareOp);
 
 class DecodeImageV3Op : public OpKernel {
  public:
diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
index 40857b449..048ea0c04 100644
--- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
@@ -627,4 +627,12 @@ REGISTER_OP("ResizeV2")
     .Output("image: dtype")
     .Attr("expand_animations: bool = true")
     .SetShapeFn(DecodeImageV3ShapeFn);
+
+REGISTER_OP("PrecisionCompare")
+    .Input("x1: T")
+    .Input("x2: T")
+    .Output("y: uint32")
+    .Attr("detect_type: int = 0")
+    .Attr("T: {float32, float16, bfloat16}")
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
 } // namespace tensorflow
diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
index b247e96a9..c49ef9fc5 100644
--- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
+++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
@@ -465,3 +465,14 @@ def host_feature_mapping_import(path):
     """ host feature mapping export. """
     result = gen_npu_cpu_ops.FeatureMappingImport(path=path)
     return result
+
+## Provides the hardware precision detection function.
+# @param golden supports float/float16/bfloat16 types
+# @param realdata supports float/float16/bfloat16 types
+# @param detect_type int32, value range: 0/1/2 (defaults to 0, the op attr default)
+# @return uint32
+def precision_compare(golden, realdata, detect_type=0):
+    """ precision compare: forwards golden, realdata and detect_type to the PrecisionCompare op. """
+    result = gen_npu_cpu_ops.PrecisionCompare(
+        x1=golden, x2=realdata, detect_type=detect_type)
+    return result
-- 
Gitee