From 886f25edf631659f67f44e88ca6d97904cab556d Mon Sep 17 00:00:00 2001 From: sunyiming Date: Sat, 16 Mar 2024 01:55:23 +0000 Subject: [PATCH] update debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py. Signed-off-by: sunyiming --- .../ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py index bb398098916..6c277337356 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py @@ -43,6 +43,8 @@ def correct_data(result): def cosine_similarity(n_value, b_value): np.seterr(divide='ignore', invalid='ignore') + if not np.issubdtype(n_value.dtype, np.floating) or not np.issubdtype(b_value.dtype, np.floating): + return "unsupported", "Cosine similarity comparison is not supported for non-float tensors." if len(n_value) == 1: return "unsupported", "This tensor is scalar." num = n_value.dot(b_value) -- Gitee