diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e28e07cd195eafba5807d446d04af82054b431c --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from msprobe.core.common.file_utils import FileChecker +from msprobe.core.common.file_utils import check_file_size +from msprobe.core.common.file_utils import check_path_before_create +from msprobe.core.common.exceptions import FileCheckException + +class TestFileUtils(TestCase): + + def test_check_path_type(self): + with self.assertRaises(FileCheckException) as context: + FileChecker._check_path_type("err_file") + self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PARAM_ERROR) + + def test_check_file_size(self): + with self.assertRaises(FileCheckException) as context: + check_file_size("xxx.txt", 100) + self.assertEqual(context.exception.code, FileCheckException.INVALID_FILE_ERROR) + + def test_check_path_before_create_long_path(self): + with self.assertRaises(FileCheckException) as context: + long_file_path = "xxx"*4094+".txt" + check_path_before_create(long_file_path) + self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) + + def test_check_path_before_create_invalid_char(self): + with self.assertRaises(FileCheckException) as context: + invalid_file_path = "***.txt" + check_path_before_create(invalid_file_path) + self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) + diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_common_data_scope_parser.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_common_data_scope_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9c7ac866cdd80bfa83bf61434b8294fdce62f2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_common_data_scope_parser.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from msprobe.core.compare.layer_mapping.data_scope_parser import DumpDataItem +from msprobe.core.common.utils import CompareException + + +class TestDataScopeParser(TestCase): + + def test_check_stack_valid_invalid_stack_type(self): + stack_info_string = "conv1.Conv2d.forward.input" + with self.assertRaises(CompareException) as context: + DumpDataItem.check_stack_valid(stack_info_string) + self.assertEqual(context.exception.code, CompareException.INVALID_DATA_ERROR) + + def test_check_stack_valid_invalid_stack_info(self): + stack_info_list = ["conv1.Conv2d.forward.input", 1] + with self.assertRaises(CompareException) as context: + DumpDataItem.check_stack_valid(stack_info_list) + self.assertEqual(context.exception.code, CompareException.INVALID_DATA_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb33eb277848fa96bdf5b7456867d8579359723 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from msprobe.core.compare.layer_mapping.postprocess_pass import extract_next_item_last_number +from msprobe.core.compare.layer_mapping.postprocess_pass import replace_next_item_index + + +class TestPostProcessPass(TestCase): + + def test_check_path_type_None(self): + input_data = "conv1.Conv2d.forward.input" + prefix = "Conv2d" + none_result = extract_next_item_last_number(input_data, prefix) + self.assertEqual(none_result, None) + + def test_check_path_type_find_result(self): + input_data = "conv1.Conv2d.forward.input.conv1" + prefix = "conv1" + result_2 = extract_next_item_last_number(input_data, prefix) + self.assertEqual(result_2, 2) + + def test_replace_next_item_index(self): + input_data = "conv1.Conv2d.forward.input.conv1" + prefix = "conv1" + replace_result = replace_next_item_index(input_data, prefix, 1) + self.assertEqual(replace_result, "conv1.1.forward.input.conv1") + + def test_replace_next_item_index_with_inf(self): + input_data = "conv1.Conv2d.forward.input.conv1" + prefix = "conv1" + inf_value = float("inf") + replace_result = replace_next_item_index(input_data, prefix, inf_value) + self.assertEqual(replace_result, input_data) +