diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py index bb39809891611ea0b3d17e660f5841849195b1ff..6c27733735694ad613b8950d009ed8d576c7a2f6 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py @@ -43,6 +43,8 @@ def correct_data(result): def cosine_similarity(n_value, b_value): np.seterr(divide='ignore', invalid='ignore') + if not np.issubdtype(n_value.dtype, np.floating) or not np.issubdtype(b_value.dtype, np.floating): + return "unsupported", "Cosine similarity comparison is not supported for non-float tensors." if len(n_value) == 1: return "unsupported", "This tensor is scalar." num = n_value.dot(b_value)