From ce90fd4dbfd96988fb49d5171370b7444cbf5481 Mon Sep 17 00:00:00 2001 From: rayshine <1324789704@qq.com> Date: Fri, 29 Aug 2025 16:18:45 +0800 Subject: [PATCH] =?UTF-8?q?Add=20UT=20for=20kdnn:gemm=20&=20kpGather?= =?UTF-8?q?=E3=80=81SparseReshape=E3=80=81SparseSelect?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tensorflow/core/kernels/BUILD | 57 ++ .../core/kernels/embedding_fused_gather.cc | 13 +- .../kernels/embedding_fused_gather_test.cc | 200 ++++++ .../embedding_fused_sparse_reshape_test.cc | 236 +++++++ .../kernels/embedding_fused_sparse_select.cc | 4 +- .../embedding_fused_sparse_select_test.cc | 182 ++++++ tensorflow/core/kernels/matmul_op_test.cc | 618 ++++++++++++++++++ .../fused_embedding_gather_test.py | 17 +- 8 files changed, 1314 insertions(+), 13 deletions(-) create mode 100644 tensorflow/core/kernels/embedding_fused_gather_test.cc create mode 100644 tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc create mode 100644 tensorflow/core/kernels/embedding_fused_sparse_select_test.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 13d3a7641..8c30f1034 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4093,6 +4093,25 @@ tf_kernel_library( deps = MATH_DEPS, ) +tf_cc_test( + name = "embedding_fused_gather_test", + srcs = if_enable_annc([ + "embedding_fused_gather_test.cc", + ]), + deps = [ + ":ops_testutil", + ":ops_util", + ":embedding_fused_gather_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "embedding_fused_padding_op", srcs = if_enable_annc([ @@ -4119,6 +4138,25 @@ tf_kernel_library( ], ) +tf_cc_test( + name = "embedding_fused_sparse_reshape_test", + srcs = if_enable_annc([ + "embedding_fused_sparse_reshape_test.cc", + ]), + deps = [ + ":ops_testutil", + ":ops_util", + ":embedding_fused_reshape_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "embedding_fused_sparse_segment_reduce_op", srcs = if_enable_annc([ @@ -4143,6 +4181,25 @@ tf_kernel_library( deps = MATH_DEPS, ) +tf_cc_test( + name = "embedding_fused_sparse_select_test", + srcs = if_enable_annc([ + "embedding_fused_sparse_select_test.cc", + ]), + deps = [ + ":ops_testutil", + ":ops_util", + ":embedding_fused_sparse_select_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "embedding_fused_ops", deps = if_enable_annc([ diff --git a/tensorflow/core/kernels/embedding_fused_gather.cc b/tensorflow/core/kernels/embedding_fused_gather.cc index 8a3a585a4..a93912255 100644 --- a/tensorflow/core/kernels/embedding_fused_gather.cc +++ b/tensorflow/core/kernels/embedding_fused_gather.cc @@ -29,10 +29,10 @@ class KPFusedGather : public OpKernel { const Tensor& begin = context->input(2); OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2")); - OP_REQUIRES(context, data.dims() == 2, errors::Internal("indentity dims must == 2")); - OP_REQUIRES(context, data.dim_size(1) == 12, errors::Internal("indentity dim size must == [n, 12]")); + OP_REQUIRES(context, data.dims() == 2, errors::Internal("identity dims must == 2")); + OP_REQUIRES(context, data.dim_size(1) == 12, errors::Internal("identity dim size must == [n, 12]")); - VLOG(1) << "Input indentity shape: " << data.shape().DebugString(); + VLOG(1) << "Input identity shape: " << data.shape().DebugString(); VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString(); VLOG(1) << "Input slice_input: " << slice_input.SummarizeValue(1000); VLOG(1) << "Input begin value: " << begin.SummarizeValue(10); @@ -73,18 +73,17 @@ class KPFusedGather : public OpKernel { context->allocate_output( 1, TensorShape({static_cast(indices.size())}), &out_indices)); std::memcpy(out_indices->data(), indices.data(), indices.size() * sizeof(int32_t)); - OP_REQUIRES(context, data.dim_size(1) * unique_values.size() % 12 == 0, - errors::Internal("cannot reshape to [-1, 12]")); OP_REQUIRES_OK(context, context->allocate_output( - 2, TensorShape({unique_values.size(), 12}), &out_data)); + 2, TensorShape({unique_values.size(), data.dim_size(1)}), &out_data)); auto output_data = out_data->matrix(); + int64_t data_row = data.dim_size(0); int64_t cols = data.dim_size(1); for (int64_t cur_row = 0; cur_row < unique_values.size(); ++cur_row) { int64_t idx = unique_values[cur_row]; - OP_REQUIRES(context, idx < data.dim_size(0), errors::Internal("idx must < data_row")); + OP_REQUIRES(context, idx < data_row, errors::Internal("idx must < data_row")); const float* src = data_mat.data() + idx * cols; float* dst = output_data.data() + cur_row * cols; std::memcpy(dst, src, cols * sizeof(float)); diff --git a/tensorflow/core/kernels/embedding_fused_gather_test.cc b/tensorflow/core/kernels/embedding_fused_gather_test.cc new file mode 100644 index 000000000..ef93bfb3f --- /dev/null +++ b/tensorflow/core/kernels/embedding_fused_gather_test.cc @@ -0,0 +1,200 @@ +/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { +using tensorflow::AllocatorAttributes; +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::DT_INT64; +using tensorflow::int64; +using tensorflow::int32; +using tensorflow::NodeDefBuilder; +using tensorflow::OpsTestBase; +using tensorflow::Status; +using tensorflow::Tensor; +using tensorflow::TensorShape; +using tensorflow::test::ExpectClose; +using tensorflow::test::FillValues; +using tensorflow::test::AsTensor; +using tensorflow::test::ExpectTensorEqual; + +class KPFusedGatherTest : public OpsTestBase { + protected: + void RunValidCase(const TensorShape& data_shape, + const TensorShape& slice_shape, + const std::vector& begin_val, + const std::vector& slice_data, + const std::vector& data_data, + const std::vector& expected_unique, + const std::vector& expected_indices, + const std::vector& expected_output_data) { + TF_EXPECT_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + AddInputFromArray(data_shape, data_data); + AddInputFromArray(slice_shape, slice_data); + AddInputFromArray(TensorShape({2}), begin_val); + + TF_ASSERT_OK(RunOpKernel()); + + const Tensor& out_unique = *GetOutput(0); + const Tensor& out_indices = *GetOutput(1); + const Tensor& out_data = *GetOutput(2); + + // 验证输出0: unique_values + Tensor expected_unique_tensor( + allocator(), DT_INT64, + TensorShape({static_cast(expected_unique.size())}) + ); + FillValues(&expected_unique_tensor, expected_unique); + ExpectTensorEqual(expected_unique_tensor, out_unique); + + // 验证输出1: indices + Tensor expected_indices_tensor( + allocator(), DT_INT32, + TensorShape({static_cast(expected_indices.size())}) + ); + FillValues(&expected_indices_tensor, expected_indices); + ExpectTensorEqual(expected_indices_tensor, out_indices); + + // 验证输出2: out_data + Tensor expected_data_tensor(allocator(), DT_FLOAT, + TensorShape({static_cast(expected_unique.size()), 12})); + FillValues(&expected_data_tensor, expected_output_data); + ExpectClose(expected_data_tensor, out_data); // float 用 ExpectClose + } + + Status RunOpExpectFailure(const TensorShape& data_shape, + const TensorShape& slice_shape, + const std::vector& begin_val, + const std::vector& slice_data, + const std::vector& data_data) { + TF_CHECK_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + + AddInputFromArray(data_shape, data_data); + AddInputFromArray(slice_shape, slice_data); + AddInputFromArray(TensorShape({2}), begin_val); + + return RunOpKernel(); + } +}; + +// 正向测试:正常输入 +TEST_F(KPFusedGatherTest, Valid_NormalInput) { + RunValidCase( + TensorShape({2, 12}), // data shape + TensorShape({4, 3}), // slice_input shape + {0, 1}, // begin[1] = 1 → 取第1列 + {1, 1, 3, + 0, 1, 5, + 1, 0, 7, + 0, 1, 9}, // slice_input 数据 + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f}, + {1, 0}, // unique values from col=1 + {0, 0, 1, 0}, // indices mapping + {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, // data[1] + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} // data[0] + ); +} + +// 反例1:data不是2维 +TEST_F(KPFusedGatherTest, Invalid_DataDimsNot2) { + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + Status s = RunOpExpectFailure( + TensorShape({4}), // data 不是二维 + TensorShape({2, 2}), + {0, 0}, + {0, 1, 2, 3}, + data + ); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("identity dims must == 2") != std::string::npos); +} + +// 反例2:data 第二维不是12 +TEST_F(KPFusedGatherTest, Invalid_DataDimSizeNot12) { + std::vector data(2 * 10, 1.0f); + Status s = RunOpExpectFailure( + TensorShape({2, 10}), // data 第二维不是12 + TensorShape({2, 2}), + {0, 0}, + {0, 1, 2, 3}, + data + ); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("identity dim size must == [n, 12]") != std::string::npos); +} + +// 反例3:slice_input 不是2维 +TEST_F(KPFusedGatherTest, Invalid_SliceInputDimsNot2) { + std::vector data(2 * 12, 1.0f); + Status s = RunOpExpectFailure( + TensorShape({2, 12}), + TensorShape({4}), // 1D slice_input + {0, 0}, + {0, 1, 2, 3}, + data + ); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("slice_input dims must == 2") != std::string::npos); +} + +// 反例4: begin[1] 超出列范围 +TEST_F(KPFusedGatherTest, Invalid_BeginColOutOfRange) { + std::vector data(2 * 12, 1.0f); + Status s = RunOpExpectFailure( + TensorShape({2, 12}), + TensorShape({2, 2}), + {0, 2}, // begin[1] = 2,但只有 2 列 → 索引 0,1 + {0, 1, 2, 3}, + data + ); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("begin[1] must < slice_input.dim_size(1)") != std::string::npos); +} + +// 反例5: gather 索引超出 data 行数 +TEST_F(KPFusedGatherTest, Invalid_IndexOutOfRangeInData) { + std::vector data(2 * 12, 1.0f); + Status s = RunOpExpectFailure( + TensorShape({2, 12}), + TensorShape({2, 2}), + {0, 0}, + {0, 1, + 2, 3}, // 索引 2 超出 data 行数(只有 0,1) + data + ); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("idx must < data_row") != std::string::npos); +} + +} \ No newline at end of file diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc new file mode 100644 index 000000000..c96b685cb --- /dev/null +++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc @@ -0,0 +1,236 @@ +/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { +using tensorflow::AllocatorAttributes; +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::DT_INT64; +using tensorflow::int64; +using tensorflow::int32; +using tensorflow::NodeDefBuilder; +using tensorflow::OpsTestBase; +using tensorflow::Status; +using tensorflow::Tensor; +using tensorflow::TensorShape; +using tensorflow::test::FillValues; +using tensorflow::test::ExpectTensorEqual; + +class KPFusedSparseReshapeTest : public OpsTestBase { + protected: + void RunValidCase(const TensorShape& slice_shape, + const std::vector& slice_data, + const std::vector& begin_val, + const std::vector& new_shape_val, + const std::vector& pack_const_val, + const TensorShape& expected_indices_shape, + const std::vector& expected_shape_val) { + TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape") + .Input(FakeInput(DT_INT64)) // slice_input + .Input(FakeInput(DT_INT32)) // begin + .Input(FakeInput(DT_INT64)) // new_shape + .Input(FakeInput(DT_INT64)) // pack_const + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + AddInputFromArray(slice_shape, slice_data); + AddInputFromArray(TensorShape({2}), begin_val); + AddInputFromArray(TensorShape({2}), new_shape_val); + AddInputFromArray(TensorShape({1}), pack_const_val); + + TF_ASSERT_OK(RunOpKernel()); + + // 输出0: result_indices + const Tensor& out_indices = *GetOutput(0); + EXPECT_EQ(out_indices.shape(), expected_indices_shape); + + // 输出1: result_shape + const Tensor& out_shape = *GetOutput(1); + Tensor expected_shape_tensor(DT_INT64, + TensorShape({static_cast(expected_shape_val.size())})); + FillValues(&expected_shape_tensor, expected_shape_val); + ExpectTensorEqual(expected_shape_tensor, out_shape); + } + + Status RunOpExpectFailure(const TensorShape& slice_shape, + const std::vector& slice_data, + const std::vector& begin_val, + const std::vector& new_shape_val, + const std::vector& pack_const_val) { + TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape") + .Input(FakeInput(DT_INT64)) // slice_input + .Input(FakeInput(DT_INT32)) // begin + .Input(FakeInput(DT_INT64)) // new_shape + .Input(FakeInput(DT_INT64)) // pack_const + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + + AddInputFromArray(slice_shape, slice_data); + AddInputFromArray(TensorShape({2}), begin_val); + AddInputFromArray(TensorShape({static_cast(new_shape_val.size())}), new_shape_val); + AddInputFromArray(TensorShape({1}), pack_const_val); + + return RunOpKernel(); + } +}; + +// ==================== 正向测试 ==================== + +// 正常 reshape 案例 +// pack_const=2 +TEST_F(KPFusedSparseReshapeTest, Valid_NormalInput) { + RunValidCase( + TensorShape({4, 2}), // slice_input shape + {0, 1, + 1, 2, + 2, 3, + 3, 0}, // slice_input 数据 + {0, 1}, // begin = (0,1),选第1列 + {2, 4}, // new_shape = [2,4] + {2}, // pack_const = [2] + TensorShape({4, 2}), // 预期 indices 形状 + {2, 4}); // 预期 shape +} + +// pack_const = 1 +TEST_F(KPFusedSparseReshapeTest, Valid_PackConst1) { + RunValidCase( + TensorShape({1, 2}), // slice_input shape + {0, 1}, // slice_input 数据 + {0, 1}, // begin = (0,1),选第1列 + {-1, 1}, // new_shape = [-1,1] + {1}, // pack_const = [1] + TensorShape({1, 2}), // 预期 indices 形状 + {1, 1}); // 预期 shape +} + +// ==================== 反向测试 ==================== + +// 反例1:slice_input 不是二维 +TEST_F(KPFusedSparseReshapeTest, Invalid_SliceInputNot2D) { + Status s = RunOpExpectFailure( + TensorShape({4}), {0, 1, 2, 3}, + {0, 0}, + {2, 2}, + {4}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("slice_input dims must == 2") != std::string::npos); +} + +// 反例2:new_shape dim size 不是 2 +TEST_F(KPFusedSparseReshapeTest, Invalid_NewShapeNotLen2) { + Status s = RunOpExpectFailure( + TensorShape({2, 2}), {0, 1, 1, 0}, + {0, 0}, + {4, 2, 1}, // new_shape 多了1个元素 + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("new_shape dim size must == 2") != std::string::npos); +} + +// 反例3:begin[1] 超出 slice_input 列数 +TEST_F(KPFusedSparseReshapeTest, Invalid_BeginOutOfRange) { + Status s = RunOpExpectFailure( + TensorShape({2, 2}), {0, 1, 1, 0}, + {0, 2}, // 超过列数 + {2, 2}, + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("begin[1] must < slice_input.dim_size(1)") != std::string::npos); +} + +// 反例4:target shape 有多个 -1 +TEST_F(KPFusedSparseReshapeTest, Invalid_MultipleUnknownDims) { + Status s = RunOpExpectFailure( + TensorShape({2, 2}), {0, 1, 1, 0}, + {0, 1}, + {-1, -1}, // 两个 -1 + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("only one output dimension may be -1") != std::string::npos); +} + +// 反例5:reshape 推断维度时,总元素数不能整除,导致无法匹配 --> product * missing != dense_size +TEST_F(KPFusedSparseReshapeTest, Invalid_InferredShapeDoesNotMatch) { + TensorShape input_indices_shape({6, 2}); // 6 个非零元素,rank=2 + std::vector input_indices_data = { + 0, 0, + 0, 1, + 0, 2, + 1, 0, + 1, 1, + 1, 2 + }; // 对应 2x3 的 dense tensor + + std::vector begin_val = {0, 0}; // 假设的 begin 输入 + std::vector new_shape_val = {-1, 4}; // reshape 到 ?x4 + std::vector pack_const_val = {1}; + + Status s = RunOpExpectFailure( + input_indices_shape, + input_indices_data, + begin_val, + new_shape_val, + pack_const_val); + + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("Input to reshape is a SparseTensor with") != std::string::npos); +} + +// 反例6:reshape 后元素数量不匹配 --> output_shape.num_elements() != dense_size +TEST_F(KPFusedSparseReshapeTest, Invalid_SizeMismatch) { + Status s = RunOpExpectFailure( + TensorShape({2, 2}), {0, 1, 1, 0}, + {0, 1}, + {3, 3}, // 期望 9 元素,但输入 dense size = 4 + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("Input to reshape is a tensor with") != std::string::npos); +} + +// 反例7:target_shape 包含负数但不是 -1 +TEST_F(KPFusedSparseReshapeTest, Invalid_NegativeDimNotMinusOne) { + Status s = RunOpExpectFailure( + TensorShape({2, 2}), {0, 1, 1, 0}, + {0, 0}, + {2, -2}, // -2 是非法的 + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("size 1 must be non-negative, not -2") != std::string::npos) + << "Actual error: " << s.error_message(); +} + +// 反例8:target_shape 有 -1,但其他维度乘积为 0 +TEST_F(KPFusedSparseReshapeTest, Invalid_ProductZeroWithUnknownDim) { + // dense_size = 0(空 SparseTensor),target_shape = [-1, 0] + // product = 0 → 不允许 infer + Status s = RunOpExpectFailure( + TensorShape({0, 2}), {}, // 空的 slice_input + {0, 0}, + {-1, 0}, // product = 0 + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("reshape cannot infer the missing input size for an empty tensor") != std::string::npos) + << "Actual error: " << s.error_message(); +} + +} // namespace diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select.cc b/tensorflow/core/kernels/embedding_fused_sparse_select.cc index 0a653cdd0..df19fa100 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_select.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_select.cc @@ -53,9 +53,9 @@ public: VLOG(1) << "input_b shape: " << input_b.shape().DebugString(); VLOG(1) << "input_c shape: " << input_c.shape().DebugString(); OP_REQUIRES(context, input_a.NumElements() == input_b.NumElements(), - errors::InvalidArgument("Input num elements must match")); + errors::InvalidArgument("Input num elements of a and b must match")); OP_REQUIRES(context, input_a.NumElements() == input_c.NumElements(), - errors::InvalidArgument("Input num elements must match")); + errors::InvalidArgument("Input num elements of a and c must match")); auto N = input_a.NumElements(); Eigen::TensorMap> a_reshaped_tensor(a_flat.data(), N, 1); diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_select_test.cc new file mode 100644 index 000000000..a68b2e05b --- /dev/null +++ b/tensorflow/core/kernels/embedding_fused_sparse_select_test.cc @@ -0,0 +1,182 @@ +/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { +using tensorflow::AllocatorAttributes; +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::DT_INT64; +using tensorflow::int64; +using tensorflow::int32; +using tensorflow::NodeDefBuilder; +using tensorflow::OpsTestBase; +using tensorflow::Status; +using tensorflow::Tensor; +using tensorflow::TensorShape; +using tensorflow::test::ExpectClose; +using tensorflow::test::FillValues; +using tensorflow::test::AsTensor; +using tensorflow::test::ExpectTensorEqual; + +class KPFusedSparseSelectTest : public OpsTestBase { + protected: + void RunValidCase( + const TensorShape& shape, + const std::vector& a_data, + const std::vector& b_data, + const std::vector& c_data, + int32_t greater_val, + int32_t equal1_val, + int32_t equal2_val, + const std::vector& expected_y, + const std::vector& expected_w_col0) { + + TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect") + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) // greater + .Input(FakeInput(DT_INT32)) // equal1 + .Input(FakeInput(DT_INT32)) // equal2 + .Input(FakeInput(DT_INT32)) // equal3 + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + AddInputFromArray(shape, a_data); + AddInputFromArray(shape, b_data); + AddInputFromArray(shape, c_data); + AddInputFromArray(TensorShape({}), {greater_val}); // scalar + AddInputFromArray(TensorShape({}), {equal1_val}); + AddInputFromArray(TensorShape({}), {equal2_val}); + AddInputFromArray(TensorShape({}), {0}); // equal3_val (未使用) + + TF_ASSERT_OK(RunOpKernel()); + + const Tensor& out_x = *GetOutput(0); + const Tensor& out_y = *GetOutput(1); + const Tensor& out_w = *GetOutput(2); + + int32 Num_elements = expected_y.size(); + // 验证 output_x: 就是 input_a + std::vector a_data_float(a_data.begin(), a_data.end()); + ExpectTensorEqual(out_x, AsTensor(a_data_float, {Num_elements, 1})); + + // 验证 output_y + ExpectTensorEqual(out_y, AsTensor(expected_y, {Num_elements, 1})); + // 验证 output_w 第一列 + auto w_mat = out_w.matrix(); + for (int i = 0; i < w_mat.dimension(0); ++i) { + EXPECT_FLOAT_EQ(w_mat(i, 0), expected_w_col0[i]); + EXPECT_FLOAT_EQ(w_mat(i, 1), 1.0f); // 第二列必须是 1.0 + } + } + + Status RunOpExpectFailure( + const TensorShape& shape, + const std::vector& a_data, + const std::vector& b_data, + const std::vector& c_data, + int32_t greater_val, + int32_t equal1_val, + int32_t equal2_val) { + + TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect") + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + TensorShape b_shape({static_cast(b_data.size())}); + TensorShape c_shape({static_cast(c_data.size())}); + AddInputFromArray(shape, a_data); + AddInputFromArray(b_shape, b_data); + AddInputFromArray(c_shape, c_data); + AddInputFromArray(TensorShape({}), {greater_val}); + AddInputFromArray(TensorShape({}), {equal1_val}); + AddInputFromArray(TensorShape({}), {equal2_val}); + AddInputFromArray(TensorShape({}), {0}); + + return RunOpKernel(); + } +}; + +// ==================== 正向测试 ==================== +// 更多正向验证参考 fused_embedding_sparse_select_test.py +TEST_F(KPFusedSparseSelectTest, Valid_NormalInput) { + RunValidCase( + TensorShape({3}), // shape + {5, 3, 8}, // input_a + {1, 2, 1}, // input_b + {9, 8, 7}, // input_c (未使用) + 4, // greater_val + 1, // equal1_val + 3, // equal2_val + {1.0f, 0.0f, 1.0f}, // expected_y + {1.0f, 0.0f, 1.0f} // expected_w_col0 + ); +} + +TEST_F(KPFusedSparseSelectTest, Valid_2DInput) { + RunValidCase( + TensorShape({2, 2}), + {6, 3, 8, 2}, + {2, 1, 3, 4}, + {0, 0, 0, 0}, + 5, + 2, + 3, + {1.0f, 0.0f, 1.0f, 0.0f}, + {1.0f, 0.0f, 1.0f, 0.0f} + ); +} +// ==================== 反向测试 ==================== +// 反例1:input_a 与 input_b 元素数不匹配 +TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AB) { + Status s = RunOpExpectFailure( + TensorShape({3}), // a 有 3 个元素 + {1, 2, 3}, + {4, 5}, // b 有 2 个元素 → 不匹配! + {6, 7, 8}, + 0, 1, 2 + ); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("Input num elements of a and b must match") != std::string::npos); +} + +// 反例2:input_a 与 input_c 元素数不匹配 +TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AC) { + Status s = RunOpExpectFailure( + TensorShape({2}), + {1, 2}, + {3, 4}, + {5}, // c 只有 1 个元素 → 不匹配! + 0, 1, 2 + ); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("Input num elements of a and c must match") != std::string::npos); +} + +} \ No newline at end of file diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc index aa4c8efb6..6f8b11c42 100644 --- a/tensorflow/core/kernels/matmul_op_test.cc +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -27,6 +27,187 @@ limitations under the License. namespace tensorflow { +template +class MatMulOpTest : public OpsTestBase { + protected: + using MatMulGraphRunner = + std::function; + + void RunAndFetch(const tensorflow::Scope& root, const string& fetch, + Tensor* output, bool allow_gpu_device, + const NodeDef* fetch_node = nullptr) { + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + if (fetch_node) { + *graph.add_node() = *fetch_node; + } + + // We really want to make sure that graph executed exactly as we passed it + // to the session, so we disable various optimizations. + tensorflow::SessionOptions session_options; + + // Disable common runtime constant folding. + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_opt_level(OptimizerOptions::L0); + + // Disable Grappler optimizations for tests. + tensorflow::RewriterConfig* cfg = + session_options.config.mutable_graph_options() + ->mutable_rewrite_options(); + cfg->set_constant_folding(tensorflow::RewriterConfig::OFF); + cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF); + cfg->set_remapping(tensorflow::RewriterConfig::OFF); + + std::unique_ptr session( + tensorflow::NewSession(session_options)); + + std::vector available_devices; + TF_ASSERT_OK(session->ListDevices(&available_devices)) + << "Failed to get available session devices"; + + // Check if session has an available GPU device. + const bool has_gpu_device = + absl::c_any_of(available_devices, [](const DeviceAttributes& device) { + return device.device_type() == DEVICE_GPU; + }); + + // If fused computation implemented only for CPU, in this test we don't want + // to compare GPU vs CPU numbers, so place all nodes on CPU in this case. + const bool place_all_on_gpu = allow_gpu_device && has_gpu_device; + + const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0"; + for (NodeDef& mutable_node : *graph.mutable_node()) { + mutable_node.set_device(device); + } + + TF_ASSERT_OK(session->Create(graph)); + + std::vector unfused_tensors; + TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors)); + + *output = unfused_tensors[0]; + } + + void RunRefMatMul(const Tensor& a, const Tensor& b, + bool transpose_a, bool transpose_b, Tensor* out, + bool allow_gpu_device = false) { + auto lhs = a.flat().data(); + auto rhs = b.flat().data(); + + auto a_dim = a.shape().dim_sizes(); + auto b_dim = b.shape().dim_sizes(); + + int m = a_dim[0]; + int k = a_dim[1]; + int n = b_dim[1]; + + TensorShape out_shape({m, n}); + *out = Tensor(DataTypeToEnum::v(), out_shape); + auto output = out->flat().data(); + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + T sum = T(0); + for (int p = 0; p < k; ++p) { + T a_val = lhs[i * k + p]; + T b_val = rhs[p * n + j]; + sum += a_val * b_val; + } + output[i * n + j] = sum; + } + } + } + + void RunKMatMul(const Tensor& lhs_data, const Tensor& rhs_data, + bool transpose_a, bool transpose_b, Tensor* output, + bool allow_gpu_device = false) { + Scope root = tensorflow::Scope::NewRootScope(); + + ops::MatMul kmatmul = ops::MatMul( + root.WithOpName("kmatmul"), + ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data)), + ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data)), + ops::MatMul::Attrs().TransposeA(transpose_a).TransposeB(transpose_b)); + + RunAndFetch(root, "kmatmul", output, allow_gpu_device); + } + + void VerifyMatMulTensorsNear( + const Tensor& lhs, + const Tensor& rhs, + const MatMulGraphRunner& run_reference, + const MatMulGraphRunner& run_kdnn) { + + Tensor matmul; + Tensor kmatmul; + + run_reference(lhs, rhs, &matmul); + run_kdnn(lhs, rhs, &kmatmul); + + ASSERT_EQ(matmul.dtype(), kmatmul.dtype()); + ASSERT_EQ(matmul.shape(), kmatmul.shape()); + + // 数值对比(允许浮点误差) + test::ExpectClose(matmul, kmatmul, /*atol=*/1e-5); + } + + void VerifyMatMul(int m, int k, int n, bool transpose_a, bool transpose_b) { + DataType dtype = DataTypeToEnum::v(); + Tensor lhs(dtype, {m, k}); + lhs.flat() = lhs.flat().setRandom(); + Tensor rhs(dtype, {k, n}); + rhs.flat() = rhs.flat().setRandom(); + + const MatMulGraphRunner run_reference = + [&](const Tensor& a, const Tensor& b, + Tensor* out) { + RunRefMatMul(a, b, false, false, out, false); + }; + + const MatMulGraphRunner run_kdnn = + [&](const Tensor& a, const Tensor& b, + Tensor* out) { + RunKMatMul(a, b, false, false, out, false); + }; + + VerifyMatMulTensorsNear(lhs, rhs, run_reference, run_kdnn); + } + + void VerifyMatMulWithInputs( + int m, int k, int n, + const std::vector& A_data, + const std::vector& B_data, + bool transpose_a = false, + bool transpose_b = false) { + ASSERT_EQ(A_data.size(), static_cast(m * k)); + ASSERT_EQ(B_data.size(), static_cast(k * n)); + + DataType dtype = DataTypeToEnum::v(); + + Tensor lhs(dtype, {m, k}); + Tensor rhs(dtype, {k, n}); + std::copy(A_data.begin(), A_data.end(), lhs.flat().data()); + std::copy(B_data.begin(), B_data.end(), rhs.flat().data()); + + const MatMulGraphRunner run_reference = + [&](const Tensor& a, const Tensor& b, Tensor* out) { + RunRefMatMul(a, b, transpose_a, transpose_b, out, false); + }; + + const MatMulGraphRunner run_kdnn = + [&](const Tensor& a, const Tensor& b, Tensor* out) { + RunKMatMul(a, b, transpose_a, transpose_b, out, false); + }; + + VerifyMatMulTensorsNear(lhs, rhs, run_reference, run_kdnn); + } +}; + +TYPED_TEST_SUITE_P(MatMulOpTest); + template class FusedMatMulOpTest : public OpsTestBase { protected: @@ -323,6 +504,439 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1WithActivation) { } } +// -------------------------------------------------------------------------- // +// MatMul Base Random Test // +// -------------------------------------------------------------------------- // + + +// 基础维度 +TYPED_TEST_P(MatMulOpTest, MatMul_256x256x256) { + this->VerifyMatMul(256, 256, 256, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_1x256x256) { + this->VerifyMatMul(1, 256, 256, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_256x256x1) { + this->VerifyMatMul(256, 256, 1, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_1x256x1) { + this->VerifyMatMul(1, 256, 1, false, false); +} + +// -------------------------------------------------------------------------- // +// MatMul Extended Random Test // +// Cover: jdtest // +// -------------------------------------------------------------------------- // + +// === 中等规模 & KDNN 常见场景 === +TYPED_TEST_P(MatMulOpTest, MatMul_5530x104x32) { + this->VerifyMatMul(5530, 104, 32, false, false); +} +TYPED_TEST_P(MatMulOpTest, MatMul_5530x116x32) { + this->VerifyMatMul(5530, 116, 32, false, false); +} +TYPED_TEST_P(MatMulOpTest, MatMul_5530x32x16) { + this->VerifyMatMul(5530, 32, 16, false, false); +} +TYPED_TEST_P(MatMulOpTest, MatMul_5530x16x1) { + this->VerifyMatMul(5530, 16, 1, false, false); +} +TYPED_TEST_P(MatMulOpTest, MatMul_7000x104x32) { + this->VerifyMatMul(7000, 104, 32, false, false); +} +TYPED_TEST_P(MatMulOpTest, MatMul_7000x116x32) { + this->VerifyMatMul(7000, 116, 32, false, false); +} +TYPED_TEST_P(MatMulOpTest, MatMul_7000x32x16) { + this->VerifyMatMul(7000, 32, 16, false, false); +} +TYPED_TEST_P(MatMulOpTest, MatMul_7000x16x1) { + this->VerifyMatMul(7000, 16, 1, false, false); +} + +// 极小维度测试:0 维度(空矩阵) +TYPED_TEST_P(MatMulOpTest, MatMul_0x256x256) { + this->VerifyMatMul(0, 256, 256, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_256x0x256) { + this->VerifyMatMul(256, 0, 256, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_256x256x0) { + this->VerifyMatMul(256, 256, 0, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_0x0x0) { + this->VerifyMatMul(0, 0, 0, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_0x0x1) { + this->VerifyMatMul(0, 0, 1, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_0x1x0) { + this->VerifyMatMul(0, 1, 0, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_1x0x0) { + this->VerifyMatMul(1, 0, 0, false, false); +} + +// 非 2 的幂次维度(内存不对齐) +TYPED_TEST_P(MatMulOpTest, MatMul_257x257x257) { + this->VerifyMatMul(257, 257, 257, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_250x240x230) { + this->VerifyMatMul(250, 240, 230, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_123x456x789) { + this->VerifyMatMul(123, 456, 789, false, false); +} + +// 大 k 值(高计算密度) +TYPED_TEST_P(MatMulOpTest, MatMul_64x8192x64) { + this->VerifyMatMul(64, 8192, 64, false, false); +} + +// 小 k 值 +TYPED_TEST_P(MatMulOpTest, MatMul_256x1x256) { + this->VerifyMatMul(256, 1, 256, false, false); +} + +// 大 k:如 Embedding 后接 FFN +TYPED_TEST_P(MatMulOpTest, MatMul_64x4096x512) { + this->VerifyMatMul(64, 4096, 512, false, false); +} + +// 小 m, 大 n:如分类头 +TYPED_TEST_P(MatMulOpTest, MatMul_1x512x1000) { + this->VerifyMatMul(1, 512, 1000, false, false); +} + +// 大 m, 小 n:如 Batch 大但输出小 +TYPED_TEST_P(MatMulOpTest, MatMul_1024x256x1) { + this->VerifyMatMul(1024, 256, 1, false, false); +} + +// 超小维度组合(广播/Kernel 选择错误) +TYPED_TEST_P(MatMulOpTest, MatMul_1x1x1) { + this->VerifyMatMul(1, 1, 1, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_1x1x64) { + this->VerifyMatMul(1, 1, 64, false, false); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_64x1x1) { + this->VerifyMatMul(64, 1, 1, false, false); +} + +// -------------------------------------------------------------------------- // +// MatMul Base Value Test // +// -------------------------------------------------------------------------- // + +// 零值矩阵 +TYPED_TEST_P(MatMulOpTest, MatMul_ZeroMatrix_A) { + int m = 3, k = 4, n = 5; + std::vector A(m * k, TypeParam(0)); // 全零矩阵 + std::vector B(k * n); + std::generate(B.begin(), B.end(), []() { + return static_cast(rand() % 10 - 5); + }); + + this->VerifyMatMulWithInputs(m, k, n, A, B); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_ZeroMatrix_B) { + int m = 3, k = 4, n = 5; + std::vector A(m * k); + std::vector B(k * n, TypeParam(0)); + std::generate(A.begin(), A.end(), []() { + return static_cast(rand() % 10 - 5); + }); + + this->VerifyMatMulWithInputs(m, k, n, A, B); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_ZeroMatrix_AB) { + int m = 3, k = 4, n = 5; + std::vector A(m * k, TypeParam(0)); + std::vector B(k * n, TypeParam(0)); + + this->VerifyMatMulWithInputs(m, k, n, A, B); +} + +// 单位矩阵 +TYPED_TEST_P(MatMulOpTest, MatMul_IdentityMatrix_A) { + int n = 9; // 使用 9x9 单位阵 + int m = n, k = n; + std::vector A(m * k, TypeParam(0)); + for (int i = 0; i < n; ++i) { + A[i * k + i] = TypeParam(1); // 对角线为 1 + } + + std::vector B(k * n); + std::generate(B.begin(), B.end(), []() { + return static_cast(rand() % 10 - 5); + }); + + this->VerifyMatMulWithInputs(m, k, n, A, B); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_IdentityMatrix_B) { + int n = 9; // 使用 9x9 单位阵 + int m = n, k = n; + std::vector A(m * k); + std::generate(A.begin(), A.end(), []() { + return static_cast(rand() % 10 - 5); + }); + + std::vector B(k * n, TypeParam(0)); + for (int i = 0; i < n; ++i) { + B[i * n + i] = TypeParam(1); // 对角线为 1 + } + + this->VerifyMatMulWithInputs(m, k, n, A, B); +} + +TYPED_TEST_P(MatMulOpTest, MatMul_IdentityMatrix_AB) { + int n = 9; // 使用 9x9 单位阵 + int m = n, k = n; + std::vector A(m * k, TypeParam(0)); + for (int i = 0; i < n; ++i) { + A[i * k + i] = TypeParam(1); // 对角线为 1 + } + + std::vector B = A; + + this->VerifyMatMulWithInputs(m, k, n, A, B); +} + +// 浮点特殊值 +// | A = [[+inf]] | B = | Expected = [[+inf]] | +TYPED_TEST_P(MatMulOpTest, MatMul_Inf_Positive) { + std::vector A = {std::numeric_limits::infinity()}; + std::vector B = {TypeParam(2)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// | A = [[-inf]] | B = | Expected = [[-inf]] | +TYPED_TEST_P(MatMulOpTest, MatMul_Inf_Negative) { + std::vector A = {-std::numeric_limits::infinity()}; + std::vector B = {TypeParam(2)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// NaN 传播 +// | A = [[nan]] | B = [[b]] | Expected = [[nan]] | +TYPED_TEST_P(MatMulOpTest, MatMul_NaN) { + int m = 2, k = 2, n = 2; + std::vector A = {1.0, 2.0, + 3.0, std::numeric_limits::quiet_NaN()}; + + std::vector B = {1.0, 1.0, + 1.0, 1.0}; + Tensor lhs(DataTypeToEnum::v(), {m, k}); + Tensor rhs(DataTypeToEnum::v(), {k, n}); + std::copy(A.begin(), A.end(), lhs.flat().data()); + std::copy(B.begin(), B.end(), rhs.flat().data()); + + Tensor output_ref, output_kdnn; + + this->RunRefMatMul(lhs, rhs, false, false, &output_ref); + this->RunKMatMul(lhs, rhs, false, false, &output_kdnn, false); + auto ref_flat = output_ref.flat(); + auto kdnn_flat = output_kdnn.flat(); + int size = ref_flat.size(); + + for (int i = 0; i < size; ++i) { + TypeParam x = ref_flat(i); + TypeParam y = kdnn_flat(i); + + bool both_nan = std::isnan(x) && std::isnan(y); + bool both_inf = std::isinf(x) && std::isinf(y) && (std::signbit(x) == std::signbit(y)); + + TypeParam atol = 1e-5f; + TypeParam rtol = 1e-5f; + TypeParam diff = std::abs(x - y); + TypeParam threshold = atol + rtol * std::abs(y); + bool normal_close = !std::isnan(x) && !std::isnan(y) && + diff <= threshold; // 自定义 atol/rtol + + EXPECT_TRUE(both_nan || both_inf || normal_close) + << "Mismatch at index " << i << ": ref=" << x << ", kdnn=" << y; + } +} + +// 浮点数极限值 +// | A = max | B = 1 | Expected = max | +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMax_Times_One) { + auto max_val = std::numeric_limits::max(); + std::vector A = {max_val}; + std::vector B = {TypeParam(1)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// | A = min | B = 1 | Expected = min | +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMin_Times_One) { + auto min_val = std::numeric_limits::lowest(); + std::vector A = {min_val}; + std::vector B = {TypeParam(1)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// (max-1) × 1 = max-1 +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMax_Minus_One_Times_One) { + auto max_val = std::numeric_limits::max(); + if (std::isfinite(max_val)) { + std::vector A = {max_val - TypeParam(1)}; + std::vector B = {TypeParam(1)}; + this->VerifyMatMulWithInputs(1, 1, 1, A, B); + } +} + +// (min+1) × 1 = min+1 +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMin_Plus_One_Times_One) { + auto min_val = std::numeric_limits::lowest(); + if (std::isfinite(min_val)) { + std::vector A = {min_val + TypeParam(1)}; + std::vector B = {TypeParam(1)}; + this->VerifyMatMulWithInputs(1, 1, 1, A, B); + } +} + +// (max/2) × 2 → 应溢出为 inf +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMax_Half_Times_Two) { + auto max_val = std::numeric_limits::max(); + TypeParam half_max = max_val / TypeParam(2); + std::vector A = {half_max}; + std::vector B = {TypeParam(2)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// (min/2) × 2 → 应溢出为 -inf +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMin_Half_Times_Two) { + auto min_val = std::numeric_limits::lowest(); + TypeParam half_min = min_val / TypeParam(2); + std::vector A = {half_min}; + std::vector B = {TypeParam(2)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// (max-1)/2 × 2 = max-1(不溢出) +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMax_Minus_One_Half_Times_Two) { + auto max_val = std::numeric_limits::max(); + TypeParam val = (max_val - TypeParam(1)) / TypeParam(2); + std::vector A = {val}; + std::vector B = {TypeParam(2)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// (min+1)/2 × 2 = min+1(不溢出) +TYPED_TEST_P(MatMulOpTest, MatMul_FloatMin_Plus_One_Half_Times_Two) { + auto min_val = std::numeric_limits::lowest(); + TypeParam val = (min_val + TypeParam(1)) / TypeParam(2); + std::vector A = {val}; + std::vector B = {TypeParam(2)}; + + this->VerifyMatMulWithInputs(1, 1, 1, A, B); +} + +// 精度敏感测试 +TYPED_TEST_P(MatMulOpTest, MatMul_PrecisionSensitive) { + // 使用小数值 + const TypeParam kSmall = TypeParam(0.01); + + int m = 2, k = 2, n = 2; + std::vector A = {kSmall, kSmall, + kSmall, kSmall}; + std::vector B = {kSmall, kSmall, + kSmall, kSmall}; + + // 期望结果:C[i][j] = 2 * (0.01 * 0.01) = 2 * 0.0001 = 0.0002 + const TypeParam kExpected = TypeParam(2) * kSmall * kSmall; // 0.0002 + + Tensor lhs(DataTypeToEnum::v(), {m, k}); + Tensor rhs(DataTypeToEnum::v(), {k, n}); + std::copy(A.begin(), A.end(), lhs.flat().data()); + std::copy(B.begin(), B.end(), rhs.flat().data()); + + Tensor output; + this->RunKMatMul(lhs, rhs, false, false, &output, false); + + auto out_flat = output.flat(); + for (int i = 0; i < m * n; ++i) { + EXPECT_NEAR(out_flat(i), kExpected, 1e-6) + << "Output at index " << i << " is not close to expected precision."; + } +} + +REGISTER_TYPED_TEST_SUITE_P(MatMulOpTest, + MatMul_256x256x256, + MatMul_1x256x256, + MatMul_256x256x1, + MatMul_1x256x1, + // 中等规模 + MatMul_5530x104x32, + MatMul_5530x116x32, + MatMul_5530x32x16, + MatMul_5530x16x1, + MatMul_7000x104x32, + MatMul_7000x116x32, + MatMul_7000x32x16, + MatMul_7000x16x1, + // 边界测试 + MatMul_0x256x256, + MatMul_256x0x256, + MatMul_256x256x0, + MatMul_0x0x0, + MatMul_0x0x1, + MatMul_0x1x0, + MatMul_1x0x0, + MatMul_257x257x257, + MatMul_250x240x230, + MatMul_123x456x789, + MatMul_64x8192x64, + MatMul_256x1x256, + MatMul_64x4096x512, + MatMul_1x512x1000, + MatMul_1024x256x1, + MatMul_1x1x1, + MatMul_1x1x64, + MatMul_64x1x1, + MatMul_ZeroMatrix_A, + MatMul_ZeroMatrix_B, + MatMul_ZeroMatrix_AB, + MatMul_IdentityMatrix_A, + MatMul_IdentityMatrix_B, + MatMul_IdentityMatrix_AB, + MatMul_Inf_Positive, + MatMul_Inf_Negative, + MatMul_NaN, + MatMul_FloatMax_Times_One, + MatMul_FloatMin_Times_One, + MatMul_FloatMax_Minus_One_Times_One, + MatMul_FloatMin_Plus_One_Times_One, + MatMul_FloatMax_Half_Times_Two, + MatMul_FloatMin_Half_Times_Two, + MatMul_FloatMax_Minus_One_Half_Times_Two, + MatMul_FloatMin_Plus_One_Half_Times_Two, + MatMul_PrecisionSensitive); + + REGISTER_TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest, // MatMul256x256x256, // MatMul1x256x256, // @@ -333,6 +947,10 @@ REGISTER_TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest, // MatMul256x256x1WithActivation, // MatMul1x256x1WithActivation); +using MatMulTestTypes = ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(Test, MatMulOpTest, + MatMulTestTypes); + // TODO(ezhulenev): Add support for more data types. using FusedBiasAddDataTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedMatMulWithBiasOpTest, diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py index 5154ffa82..f4c079565 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py @@ -98,7 +98,6 @@ class TestFusedGather(unittest.TestCase): out_opt_val1, err_msg="Segment count mismatch" ) - np.testing.assert_array_equal( out_ori_val2, out_opt_val2, @@ -131,15 +130,25 @@ class TestFusedGather(unittest.TestCase): def test_kp_embedding_gather(self): base_data = np.linspace(0, 11, num=240, endpoint=False, dtype=np.float32).reshape(20, 12) base_slice_input = np.array([[0, 0], [0, 1], [1, 2]], dtype=np.int64) - base_begin = [0, 1] + base_begin = np.array([0, 1], dtype=np.int32) self._run_kp_gather_test((20, 12), (3, 2), base_data, base_slice_input, base_begin, num_runs=100) + def test_kp_gather_with_duplicates(self): + base_data = np.random.rand(100, 12).astype(np.float32) + base_slice_input = np.array([[5, 3], [7, 3], [9, 4], [5, 3]], dtype=np.int64) + base_begin = np.array([0, 1], dtype=np.int32) + self._run_kp_gather_test((100, 12), (4, 2), base_data, base_slice_input, base_begin, num_runs=100) + + def test_kp_gather_single_unique(self): + base_data = np.random.rand(50, 12).astype(np.float32) + base_slice_input = np.array([[10, 7], [20, 7], [30, 7]], dtype=np.int64) + base_begin = np.array([0, 1], dtype=np.int32) + self._run_kp_gather_test((50, 12), (3, 2), base_data, base_slice_input, base_begin, num_runs=100) def test_kp_gather_262145(self): base_data = np.linspace(0, 11111, num=262145*12, dtype=np.float32).reshape(262145, 12) - # base_slice_input = np.array([[0, 1]], dtype=np.int64) base_slice_input = np.random.randint(0, 262146, size=(46, 2), dtype=np.int64) - base_begin = [0, 0] + base_begin = np.array([0, 0], dtype=np.int32) self._run_kp_gather_test((262145, 12), (46, 2), base_data, base_slice_input, base_begin, num_runs=100) -- Gitee