From cad66a2f36fd005545c8fc4dd31ee39bc618cb77 Mon Sep 17 00:00:00 2001 From: kai-ma Date: Thu, 3 Jul 2025 21:25:28 +0800 Subject: [PATCH] renew --- accuracy_tools/.clang-format | 36 + accuracy_tools/CMakeLists.txt | 25 + accuracy_tools/LICENSE | 201 ++++++ accuracy_tools/README.md | 17 + accuracy_tools/__init__.py | 13 + accuracy_tools/build.sh | 83 +++ accuracy_tools/cmake/Findcpython.cmake | 16 + accuracy_tools/cmake/Findgtest.cmake | 47 ++ accuracy_tools/cmake/Findmockcpp.cmake | 43 ++ accuracy_tools/cmake/Findnlohmannjson.cmake | 18 + accuracy_tools/cmake/config.ini | 11 + accuracy_tools/cmake/download_opensource.sh | 110 +++ accuracy_tools/cmake/utils.cmake | 42 ++ accuracy_tools/docs/0001.capability_matrix.md | 0 accuracy_tools/docs/0002.installation.md | 63 ++ .../docs/0003.config_introduction.md | 0 accuracy_tools/docs/0004.config_examples.md | 0 .../docs/0101.dump_offline_model.md | 0 accuracy_tools/docs/README.md | 27 + accuracy_tools/msprobe/CMakeLists.txt | 1 + accuracy_tools/msprobe/__init__.py | 16 + accuracy_tools/msprobe/__main__.py | 26 + accuracy_tools/msprobe/base/__init__.py | 18 + accuracy_tools/msprobe/base/cmd.py | 80 +++ .../msprobe/base/component/__init__.py | 13 + .../msprobe/base/component/manager.py | 278 ++++++++ accuracy_tools/msprobe/base/config.py | 117 ++++ .../msprobe/base/service/__init__.py | 13 + .../msprobe/base/service/manager.py | 114 +++ accuracy_tools/msprobe/common/__init__.py | 13 + accuracy_tools/msprobe/common/ascend.py | 93 +++ accuracy_tools/msprobe/common/cli.py | 99 +++ accuracy_tools/msprobe/common/dirs.py | 74 ++ accuracy_tools/msprobe/common/stat.py | 89 +++ accuracy_tools/msprobe/common/validation.py | 181 +++++ accuracy_tools/msprobe/core/__init__.py | 15 + accuracy_tools/msprobe/core/base/__init__.py | 17 + .../msprobe/core/base/dump_actuator.py | 191 +++++ .../msprobe/core/base/dump_dumper.py | 49 ++ .../msprobe/core/base/dump_writer.py | 126 ++++ accuracy_tools/msprobe/core/cli/__init__.py | 15 + 
.../msprobe/core/cli/command_dump.py | 65 ++ .../msprobe/core/components/__init__.py | 26 + .../msprobe/core/components/dumper_acl.py | 150 ++++ .../msprobe/core/components/dumper_atb.py | 120 ++++ .../msprobe/core/components/dumper_caffe.py | 104 +++ .../core/components/dumper_offline_model.py | 20 + .../msprobe/core/components/dumper_om.py | 35 + .../msprobe/core/components/dumper_onnx.py | 161 +++++ .../msprobe/core/components/dumper_tf.py | 164 +++++ .../msprobe/core/components/dumper_writer.py | 149 ++++ .../msprobe/core/config_initiator/__init__.py | 15 + .../core/config_initiator/config_dump.py | 110 +++ .../core/config_initiator/validate_params.py | 325 +++++++++ accuracy_tools/msprobe/core/dump/__init__.py | 19 + .../msprobe/core/dump/acl_manager.py | 109 +++ .../msprobe/core/dump/caffe_model.py | 59 ++ accuracy_tools/msprobe/core/dump/om_model.py | 171 +++++ .../msprobe/core/dump/onnx_model.py | 68 ++ accuracy_tools/msprobe/core/dump/tf_model.py | 171 +++++ .../msprobe/core/service/__init__.py | 15 + accuracy_tools/msprobe/core/service/dump.py | 315 +++++++++ accuracy_tools/msprobe/csrc/CMakeLists.txt | 51 ++ .../msprobe/csrc/acl/core/AclApi.cpp | 411 +++++++++++ .../msprobe/csrc/acl/include/AclApi.h | 152 ++++ .../msprobe/csrc/atb_probe/Override.cpp | 252 +++++++ .../msprobe/csrc/atb_probe/Override.h | 84 +++ .../msprobe/csrc/atb_probe/core/Helper.cpp | 287 ++++++++ .../msprobe/csrc/atb_probe/core/SaveExtra.cpp | 181 +++++ .../msprobe/csrc/atb_probe/core/SaveGraph.cpp | 188 +++++ .../csrc/atb_probe/core/SaveTensor.cpp | 171 +++++ .../msprobe/csrc/atb_probe/core/Stat.cpp | 229 ++++++ .../msprobe/csrc/atb_probe/include/Helper.h | 86 +++ .../csrc/atb_probe/include/SaveExtra.h | 36 + .../csrc/atb_probe/include/SaveGraph.h | 32 + .../csrc/atb_probe/include/SaveTensor.h | 33 + .../msprobe/csrc/atb_probe/include/Stat.h | 29 + .../msprobe/csrc/common/Toolkit.cpp | 82 +++ accuracy_tools/msprobe/csrc/common/Toolkit.h | 29 + 
.../msprobe/csrc/python/PyACLActuator.cpp | 591 ++++++++++++++++ .../msprobe/csrc/python/PyACLActuator.h | 26 + .../msprobe/csrc/python/PyInterface.cpp | 64 ++ accuracy_tools/msprobe/csrc/python/PyLog.cpp | 107 +++ accuracy_tools/msprobe/csrc/python/PyLog.h | 32 + accuracy_tools/msprobe/csrc/utils/Constant.h | 63 ++ accuracy_tools/msprobe/csrc/utils/DataType.h | 88 +++ accuracy_tools/msprobe/csrc/utils/Exception.h | 30 + accuracy_tools/msprobe/csrc/utils/IO.cpp | 195 ++++++ accuracy_tools/msprobe/csrc/utils/IO.h | 44 ++ accuracy_tools/msprobe/csrc/utils/Log.h | 156 +++++ accuracy_tools/msprobe/csrc/utils/Path.cpp | 237 +++++++ accuracy_tools/msprobe/csrc/utils/Path.h | 78 +++ accuracy_tools/msprobe/csrc/utils/Str.cpp | 197 ++++++ accuracy_tools/msprobe/csrc/utils/Str.h | 64 ++ accuracy_tools/msprobe/utils/__init__.py | 13 + accuracy_tools/msprobe/utils/constants.py | 201 ++++++ accuracy_tools/msprobe/utils/dependencies.py | 96 +++ accuracy_tools/msprobe/utils/env.py | 80 +++ accuracy_tools/msprobe/utils/exceptions.py | 22 + accuracy_tools/msprobe/utils/hijack.py | 405 +++++++++++ accuracy_tools/msprobe/utils/io.py | 347 +++++++++ accuracy_tools/msprobe/utils/log.py | 69 ++ accuracy_tools/msprobe/utils/path.py | 340 +++++++++ accuracy_tools/msprobe/utils/toolkits.py | 424 +++++++++++ accuracy_tools/pyproject.toml | 11 + accuracy_tools/requirements/requirements.txt | 7 + .../requirements/requirements_tf.txt | 10 + accuracy_tools/setup.py | 87 +++ accuracy_tools/test/ST/run_st.py | 0 accuracy_tools/test/UT/.coveragerc | 3 + accuracy_tools/test/UT/CMakeLists.txt | 1 + .../component_ut/test_manager_component.py | 269 +++++++ .../service_ut/test_manager_service.py | 79 +++ accuracy_tools/test/UT/base_ut/test_cmd.py | 86 +++ accuracy_tools/test/UT/base_ut/test_config.py | 135 ++++ .../test/UT/common_ut/test_ascend.py | 0 accuracy_tools/test/UT/common_ut/test_cli.py | 52 ++ .../test/UT/common_ut/test_validation.py | 254 +++++++ .../probe_ut/base_ut/test_dump_actuator.py | 
133 ++++ .../probe_ut/base_ut/test_dump_dumper.py | 39 ++ .../probe_ut/base_ut/test_dump_writer.py | 166 +++++ .../test_dumper_offline_model.py | 26 + .../test_validate_params.py | 325 +++++++++ .../probe_ut/dump_ut/test_caffe_model.py | 118 ++++ .../probe_ut/dump_ut/test_onnx_model.py | 160 +++++ .../core_ut/probe_ut/dump_ut/test_tf_model.py | 269 +++++++ accuracy_tools/test/UT/csrc_ut/CMakeLists.txt | 24 + .../test/UT/csrc_ut/utils_ut/test_log.cpp | 0 accuracy_tools/test/UT/pytest.ini | 3 + accuracy_tools/test/UT/run_ut.py | 35 + accuracy_tools/test/UT/run_ut.sh | 56 ++ accuracy_tools/test/UT/test___main__.py | 38 + .../test/UT/utils_ut/test_dependencies.py | 113 +++ accuracy_tools/test/UT/utils_ut/test_env.py | 98 +++ .../test/UT/utils_ut/test_hijack.py | 294 ++++++++ accuracy_tools/test/UT/utils_ut/test_io.py | 569 +++++++++++++++ accuracy_tools/test/UT/utils_ut/test_log.py | 116 +++ accuracy_tools/test/UT/utils_ut/test_path.py | 660 ++++++++++++++++++ .../test/UT/utils_ut/test_toolkits.py | 446 ++++++++++++ accuracy_tools/third_party/.keep | 0 140 files changed, 15745 insertions(+) create mode 100644 accuracy_tools/.clang-format create mode 100644 accuracy_tools/CMakeLists.txt create mode 100644 accuracy_tools/LICENSE create mode 100644 accuracy_tools/README.md create mode 100644 accuracy_tools/__init__.py create mode 100644 accuracy_tools/build.sh create mode 100644 accuracy_tools/cmake/Findcpython.cmake create mode 100644 accuracy_tools/cmake/Findgtest.cmake create mode 100644 accuracy_tools/cmake/Findmockcpp.cmake create mode 100644 accuracy_tools/cmake/Findnlohmannjson.cmake create mode 100644 accuracy_tools/cmake/config.ini create mode 100644 accuracy_tools/cmake/download_opensource.sh create mode 100644 accuracy_tools/cmake/utils.cmake create mode 100644 accuracy_tools/docs/0001.capability_matrix.md create mode 100644 accuracy_tools/docs/0002.installation.md create mode 100644 accuracy_tools/docs/0003.config_introduction.md create mode 100644 
accuracy_tools/docs/0004.config_examples.md create mode 100644 accuracy_tools/docs/0101.dump_offline_model.md create mode 100644 accuracy_tools/docs/README.md create mode 100644 accuracy_tools/msprobe/CMakeLists.txt create mode 100644 accuracy_tools/msprobe/__init__.py create mode 100644 accuracy_tools/msprobe/__main__.py create mode 100644 accuracy_tools/msprobe/base/__init__.py create mode 100644 accuracy_tools/msprobe/base/cmd.py create mode 100644 accuracy_tools/msprobe/base/component/__init__.py create mode 100644 accuracy_tools/msprobe/base/component/manager.py create mode 100644 accuracy_tools/msprobe/base/config.py create mode 100644 accuracy_tools/msprobe/base/service/__init__.py create mode 100644 accuracy_tools/msprobe/base/service/manager.py create mode 100644 accuracy_tools/msprobe/common/__init__.py create mode 100644 accuracy_tools/msprobe/common/ascend.py create mode 100644 accuracy_tools/msprobe/common/cli.py create mode 100644 accuracy_tools/msprobe/common/dirs.py create mode 100644 accuracy_tools/msprobe/common/stat.py create mode 100644 accuracy_tools/msprobe/common/validation.py create mode 100644 accuracy_tools/msprobe/core/__init__.py create mode 100644 accuracy_tools/msprobe/core/base/__init__.py create mode 100644 accuracy_tools/msprobe/core/base/dump_actuator.py create mode 100644 accuracy_tools/msprobe/core/base/dump_dumper.py create mode 100644 accuracy_tools/msprobe/core/base/dump_writer.py create mode 100644 accuracy_tools/msprobe/core/cli/__init__.py create mode 100644 accuracy_tools/msprobe/core/cli/command_dump.py create mode 100644 accuracy_tools/msprobe/core/components/__init__.py create mode 100644 accuracy_tools/msprobe/core/components/dumper_acl.py create mode 100644 accuracy_tools/msprobe/core/components/dumper_atb.py create mode 100644 accuracy_tools/msprobe/core/components/dumper_caffe.py create mode 100644 accuracy_tools/msprobe/core/components/dumper_offline_model.py create mode 100644 
accuracy_tools/msprobe/core/components/dumper_om.py create mode 100644 accuracy_tools/msprobe/core/components/dumper_onnx.py create mode 100644 accuracy_tools/msprobe/core/components/dumper_tf.py create mode 100644 accuracy_tools/msprobe/core/components/dumper_writer.py create mode 100644 accuracy_tools/msprobe/core/config_initiator/__init__.py create mode 100644 accuracy_tools/msprobe/core/config_initiator/config_dump.py create mode 100644 accuracy_tools/msprobe/core/config_initiator/validate_params.py create mode 100644 accuracy_tools/msprobe/core/dump/__init__.py create mode 100644 accuracy_tools/msprobe/core/dump/acl_manager.py create mode 100644 accuracy_tools/msprobe/core/dump/caffe_model.py create mode 100644 accuracy_tools/msprobe/core/dump/om_model.py create mode 100644 accuracy_tools/msprobe/core/dump/onnx_model.py create mode 100644 accuracy_tools/msprobe/core/dump/tf_model.py create mode 100644 accuracy_tools/msprobe/core/service/__init__.py create mode 100644 accuracy_tools/msprobe/core/service/dump.py create mode 100644 accuracy_tools/msprobe/csrc/CMakeLists.txt create mode 100644 accuracy_tools/msprobe/csrc/acl/core/AclApi.cpp create mode 100644 accuracy_tools/msprobe/csrc/acl/include/AclApi.h create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/Override.cpp create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/Override.h create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/core/Helper.cpp create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/core/SaveExtra.cpp create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/core/SaveGraph.cpp create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/core/SaveTensor.cpp create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/core/Stat.cpp create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/include/Helper.h create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/include/SaveExtra.h create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/include/SaveGraph.h create mode 100644 
accuracy_tools/msprobe/csrc/atb_probe/include/SaveTensor.h create mode 100644 accuracy_tools/msprobe/csrc/atb_probe/include/Stat.h create mode 100644 accuracy_tools/msprobe/csrc/common/Toolkit.cpp create mode 100644 accuracy_tools/msprobe/csrc/common/Toolkit.h create mode 100644 accuracy_tools/msprobe/csrc/python/PyACLActuator.cpp create mode 100644 accuracy_tools/msprobe/csrc/python/PyACLActuator.h create mode 100644 accuracy_tools/msprobe/csrc/python/PyInterface.cpp create mode 100644 accuracy_tools/msprobe/csrc/python/PyLog.cpp create mode 100644 accuracy_tools/msprobe/csrc/python/PyLog.h create mode 100644 accuracy_tools/msprobe/csrc/utils/Constant.h create mode 100644 accuracy_tools/msprobe/csrc/utils/DataType.h create mode 100644 accuracy_tools/msprobe/csrc/utils/Exception.h create mode 100644 accuracy_tools/msprobe/csrc/utils/IO.cpp create mode 100644 accuracy_tools/msprobe/csrc/utils/IO.h create mode 100644 accuracy_tools/msprobe/csrc/utils/Log.h create mode 100644 accuracy_tools/msprobe/csrc/utils/Path.cpp create mode 100644 accuracy_tools/msprobe/csrc/utils/Path.h create mode 100644 accuracy_tools/msprobe/csrc/utils/Str.cpp create mode 100644 accuracy_tools/msprobe/csrc/utils/Str.h create mode 100644 accuracy_tools/msprobe/utils/__init__.py create mode 100644 accuracy_tools/msprobe/utils/constants.py create mode 100644 accuracy_tools/msprobe/utils/dependencies.py create mode 100644 accuracy_tools/msprobe/utils/env.py create mode 100644 accuracy_tools/msprobe/utils/exceptions.py create mode 100644 accuracy_tools/msprobe/utils/hijack.py create mode 100644 accuracy_tools/msprobe/utils/io.py create mode 100644 accuracy_tools/msprobe/utils/log.py create mode 100644 accuracy_tools/msprobe/utils/path.py create mode 100644 accuracy_tools/msprobe/utils/toolkits.py create mode 100644 accuracy_tools/pyproject.toml create mode 100644 accuracy_tools/requirements/requirements.txt create mode 100644 accuracy_tools/requirements/requirements_tf.txt create mode 100644 
accuracy_tools/setup.py create mode 100644 accuracy_tools/test/ST/run_st.py create mode 100644 accuracy_tools/test/UT/.coveragerc create mode 100644 accuracy_tools/test/UT/CMakeLists.txt create mode 100644 accuracy_tools/test/UT/base_ut/component_ut/test_manager_component.py create mode 100644 accuracy_tools/test/UT/base_ut/service_ut/test_manager_service.py create mode 100644 accuracy_tools/test/UT/base_ut/test_cmd.py create mode 100644 accuracy_tools/test/UT/base_ut/test_config.py create mode 100644 accuracy_tools/test/UT/common_ut/test_ascend.py create mode 100644 accuracy_tools/test/UT/common_ut/test_cli.py create mode 100644 accuracy_tools/test/UT/common_ut/test_validation.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/base_ut/test_dump_actuator.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/base_ut/test_dump_dumper.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/base_ut/test_dump_writer.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/components_ut/test_dumper_offline_model.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/config_initiator_ut/test_validate_params.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/dump_ut/test_caffe_model.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/dump_ut/test_onnx_model.py create mode 100644 accuracy_tools/test/UT/core_ut/probe_ut/dump_ut/test_tf_model.py create mode 100644 accuracy_tools/test/UT/csrc_ut/CMakeLists.txt create mode 100644 accuracy_tools/test/UT/csrc_ut/utils_ut/test_log.cpp create mode 100644 accuracy_tools/test/UT/pytest.ini create mode 100644 accuracy_tools/test/UT/run_ut.py create mode 100644 accuracy_tools/test/UT/run_ut.sh create mode 100644 accuracy_tools/test/UT/test___main__.py create mode 100644 accuracy_tools/test/UT/utils_ut/test_dependencies.py create mode 100644 accuracy_tools/test/UT/utils_ut/test_env.py create mode 100644 accuracy_tools/test/UT/utils_ut/test_hijack.py create mode 100644 
accuracy_tools/test/UT/utils_ut/test_io.py create mode 100644 accuracy_tools/test/UT/utils_ut/test_log.py create mode 100644 accuracy_tools/test/UT/utils_ut/test_path.py create mode 100644 accuracy_tools/test/UT/utils_ut/test_toolkits.py create mode 100644 accuracy_tools/third_party/.keep diff --git a/accuracy_tools/.clang-format b/accuracy_tools/.clang-format new file mode 100644 index 00000000000..095d6f03539 --- /dev/null +++ b/accuracy_tools/.clang-format @@ -0,0 +1,36 @@ +BasedOnStyle: LLVM + +IndentWidth: 4 +TabWidth: 4 +UseTab: Never + +ColumnLimit: 120 +BreakBeforeBraces: Attach + +BraceWrapping: + AfterNamespace: false + AfterFunction: false + AfterClass: false + AfterControlStatement: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + SplitEmptyRecord: false + SplitEmptyFunction: false + AfterEnum: true + +AccessModifierOffset: -4 +IndentCaseLabels: true +SpaceBeforeParens: ControlStatements + +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortBlocksOnASingleLine: false + +BinPackParameters: false +BinPackArguments: false + +NamespaceIndentation: All +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true diff --git a/accuracy_tools/CMakeLists.txt b/accuracy_tools/CMakeLists.txt new file mode 100644 index 00000000000..1122a2304b1 --- /dev/null +++ b/accuracy_tools/CMakeLists.txt @@ -0,0 +1,25 @@ +project(accuracy_tools) +cmake_minimum_required(VERSION 3.14) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED OFF) +set(CMAKE_CXX_EXTENSIONS OFF) + +execute_process( + COMMAND uname -m + OUTPUT_VARIABLE machine_arch + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +if (DEFINED ARCH_TYPE AND NOT "${ARCH_TYPE}" STREQUAL "${machine_arch}") + message(FATAL_ERROR "Cross-compilation is not supported currently. 
(compile ${ARCH_TYPE} on ${machine_arch})") +endif() + +set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") +set(ENV{PROJECT_ROOT_PATH} "${CMAKE_SOURCE_DIR}") +include("${CMAKE_SOURCE_DIR}/cmake/utils.cmake") +add_subdirectory(msprobe) + +if (DEFINED BUILD_TEST_CASE AND "${BUILD_TEST_CASE}" STREQUAL "True") + add_subdirectory(test/UT) +endif() diff --git a/accuracy_tools/LICENSE b/accuracy_tools/LICENSE new file mode 100644 index 00000000000..261eeb9e9f8 --- /dev/null +++ b/accuracy_tools/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/accuracy_tools/README.md b/accuracy_tools/README.md new file mode 100644 index 00000000000..2cb4079e193 --- /dev/null +++ b/accuracy_tools/README.md @@ -0,0 +1,17 @@ +# 📖 msprobe 使用手册 + +![platform](https://img.shields.io/badge/platform-Linux-yellow) +![License: Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-green) + +## 用前必看 + +使用工具前,请先浏览[工具模块简介、适用场景和当前版本局限](./docs/0001.capability_matrix.md)。 + +## ⚙️ [安装](./docs/0002.installation.md) +## 🛠️ config.json [介绍](./docs/0003.config_introduction.md) 和 [示例](./docs/0004.config_examples.md) + +## 🧰 主要功能 + +### 1 数据采集 + +[离线模型 ONNX、TensorFlow (.pb, saved model)、Ascend OM、Caffe 场景](./docs/0101.dump_offline_model.md) diff --git a/accuracy_tools/__init__.py b/accuracy_tools/__init__.py new file mode 100644 index 00000000000..53529bc8d31 --- /dev/null +++ b/accuracy_tools/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accuracy_tools/build.sh b/accuracy_tools/build.sh new file mode 100644 index 00000000000..90b120efaae --- /dev/null +++ b/accuracy_tools/build.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +set -e + +BUILD_PATH=$(pwd) + +BUILD_ARGS=$(getopt \ + -o ha:v:j:ft \ + --long help,release,debug,arch:,python-version:,jobs:,force-rebuild,local,test-cases \ + -- "$@") +eval set -- "${BUILD_ARGS}" + +ARCH_TYPE=$(uname -m) +BUILD_TYPE=release +CONCURRENT_JOBS=16 +BUILD_TEST_CASE=False +USE_LOCAL_FIRST=False +PYTHON_VERSION="" + +HELP_DOC=$(cat << EOF +Usage: build.sh [OPTION]...\n +Build the C++ part of msprobe.\n +\n +Arguments:\n + -a, --arch Specify the schema, which generally does not need to be set up.\n + -j, --jobs Specify the number of compilation jobs(default 16).\n + -f, --force-rebuild Clean up the cache before building.\n + -t, --test-cases Build test cases.\n + --local Prioritize the use of on-premises, third-party resources as dependencies.\n + --release Build the release version(default).\n + --debug Build the debug version. + -v, --python-version Specify version of python. +EOF +) + +while true; do + case "$1" in + -h | --help) + echo -e ${HELP_DOC} + exit 0 ;; + -a | --arch) + ARCH_TYPE="$2" ; shift 2 ;; + -v | --python-version) + PYTHON_VERSION="$2" ; shift 2 ;; + --release) + BUILD_TYPE=release ; shift ;; + --debug) + BUILD_TYPE=debug ; shift ;; + -j | --jobs) + CONCURRENT_JOBS="$2" ; shift 2 ;; + --local) + USE_LOCAL_FIRST=True ; shift ;; + -f | --force-rebuild) + rm -rf "${BUILD_PATH}/lib" "${BUILD_PATH}/output" "${BUILD_PATH}/msprobe/lib/msprobe_c.so" + shift ;; + -t | --test-cases) + BUILD_TEST_CASE=True ; shift ;; + --) + shift ; break ;; + *) + echo "Unknow argument $1" + exit 1 ;; + esac +done + +BUILD_OUTPUT_PATH=${BUILD_PATH}/output/${BUILD_TYPE} + +cmake -B ${BUILD_OUTPUT_PATH} -S . 
-DARCH_TYPE=${ARCH_TYPE} -DBUILD_TYPE=${BUILD_TYPE} \ + -DUSE_LOCAL_FIRST=${USE_LOCAL_FIRST} -DBUILD_TEST_CASE=${BUILD_TEST_CASE} \ + -DPYTHON_VERSION=${PYTHON_VERSION} +cd ${BUILD_OUTPUT_PATH} +make -j${CONCURRENT_JOBS} + +if [[ ! -e ${BUILD_OUTPUT_PATH}/msprobe/csrc/libmsprobe_c.so ]]; then + echo "Failed to build libmsprobe_c.so." + exit 1 +fi + +if [[ ! -e ${BUILD_PATH}/msprobe/lib ]]; then + mkdir ${BUILD_PATH}/msprobe/lib +fi + +cp ${BUILD_OUTPUT_PATH}/msprobe/csrc/libmsprobe_c.so ${BUILD_PATH}/msprobe/lib/msprobe_c.so diff --git a/accuracy_tools/cmake/Findcpython.cmake b/accuracy_tools/cmake/Findcpython.cmake new file mode 100644 index 00000000000..815fbc638de --- /dev/null +++ b/accuracy_tools/cmake/Findcpython.cmake @@ -0,0 +1,16 @@ +set(PKG_NAME cpython) + +if (NOT ${PKG_NAME}_FOUND) + +find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) +if (NOT Python3_FOUND) + message(FATAL_ERROR "${Python3} is not found.") +endif() + +set(PACKAGE_VERSION ${Python3_VERSION}) + +include_directories(${Python3_INCLUDE_DIRS}) +set(${PKG_NAME}_LIBRARIES ${Python3_LIBRARIES}) +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/accuracy_tools/cmake/Findgtest.cmake b/accuracy_tools/cmake/Findgtest.cmake new file mode 100644 index 00000000000..66297bf16f2 --- /dev/null +++ b/accuracy_tools/cmake/Findgtest.cmake @@ -0,0 +1,47 @@ +set(PACKAGE_VERSION 1.12.1) + +set(PKG_NAME gtest) +set(SHA256_VALUE "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2") +set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") +set(DIR_NAME "${DOWNLOAD_PATH}/googletest") + +if (NOT ${PKG_NAME}_FOUND) + +download_opensource_pkg(${PKG_NAME} + SHA256 ${SHA256_VALUE} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +include_directories(${DIR_NAME}/googletest/include) +include_directories(${DIR_NAME}/googlemock/include) + +set(BUILD_DEPENDENCY_PATH "$ENV{PROJECT_ROOT_PATH}/build_dependency") +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND cmake . 
-DBUILD_SHARED_LIBS=ON + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build gtest. ${RESULT}") +endif() +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND make -j16 + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build gtest. ${RESULT}") +endif() + +file(GLOB GTEST_SO "${DIR_NAME}/lib/libgtest.so") +file(GLOB GMOCK_SO "${DIR_NAME}/lib/libgmock.so") +file(GLOB GTEST_MAIN_SO "${DIR_NAME}/lib/libgtest_main.so") +file(GLOB GMOCK_MAIN_SO "${DIR_NAME}/lib/libgmock_main.so") +if (NOT GTEST_SO OR NOT GMOCK_SO OR NOT GTEST_MAIN_SO OR NOT GMOCK_MAIN_SO) + message(FATAL_ERROR "Failed to build gtest.") +endif() + +set(${PKG_NAME}_LIBRARIES "${GTEST_SO};${GMOCK_SO};${GTEST_MAIN_SO};${GMOCK_MAIN_SO}") +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/accuracy_tools/cmake/Findmockcpp.cmake b/accuracy_tools/cmake/Findmockcpp.cmake new file mode 100644 index 00000000000..9bc71c23701 --- /dev/null +++ b/accuracy_tools/cmake/Findmockcpp.cmake @@ -0,0 +1,43 @@ +set(PACKAGE_VERSION 2.7) + +set(PKG_NAME mockcpp) +set(SHA256_VALUE "0dc7111c5be9785d0550ed3b68db7e12fd5d7802b7bc6548c52ac7b9e727fcc1") +set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") +set(DIR_NAME "${DOWNLOAD_PATH}/mockcpp") + +if (NOT ${PKG_NAME}_FOUND) + +download_opensource_pkg(${PKG_NAME} + SHA256 ${SHA256_VALUE} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +include_directories(${DIR_NAME}/include) +include_directories(${DIR_NAME}/3rdparty) + +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND cmake . + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build mockcpp. ${RESULT}") +endif() +execute_process( + WORKING_DIRECTORY ${DIR_NAME} + COMMAND make -j16 + RESULT_VARIABLE RESULT +) +if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to build mockcpp. 
${RESULT}") +endif() + +file(GLOB MOCKCPP_LIB "${DIR_NAME}/src/libmockcpp.a") +if (NOT MOCKCPP_LIB) + message(FATAL_ERROR "Failed to build mockcpp.") +endif() + +set(${PKG_NAME}_LIBRARIES "${MOCKCPP_LIB}") +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/accuracy_tools/cmake/Findnlohmannjson.cmake b/accuracy_tools/cmake/Findnlohmannjson.cmake new file mode 100644 index 00000000000..63c77137703 --- /dev/null +++ b/accuracy_tools/cmake/Findnlohmannjson.cmake @@ -0,0 +1,18 @@ +set(PACKAGE_VERSION 3.10.1) + +set(PKG_NAME nlohmannjson) +set(SHA256_VALUE "5c7d0a0542431fef628f8dc4c34fd022fe8747ccb577012d58f38672d8747e0d") +set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") +set(DIR_NAME "${DOWNLOAD_PATH}/JSON-for-Modern-CPP") + +if (NOT ${PKG_NAME}_FOUND) + +download_opensource_pkg(${PKG_NAME} + SHA256 ${SHA256_VALUE} + DOWNLOAD_PATH ${DOWNLOAD_PATH} +) + +include_directories(${DIR_NAME}/include) +set(${PKG_NAME}_FOUND TRUE) + +endif() diff --git a/accuracy_tools/cmake/config.ini b/accuracy_tools/cmake/config.ini new file mode 100644 index 00000000000..9a9166b5a8a --- /dev/null +++ b/accuracy_tools/cmake/config.ini @@ -0,0 +1,11 @@ +[nlohmannjson] +url = https://gitee.com/mirrors/JSON-for-Modern-CPP.git +tag = v3.10.1 + +[gtest] +url = https://gitee.com/mirrors/googletest.git +tag = release-1.12.1 + +[mockcpp] +url = https://gitee.com/sinojelly/mockcpp.git +tag = v2.7 diff --git a/accuracy_tools/cmake/download_opensource.sh b/accuracy_tools/cmake/download_opensource.sh new file mode 100644 index 00000000000..4beec113362 --- /dev/null +++ b/accuracy_tools/cmake/download_opensource.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +if [ "$#" -lt 2 ]; then + echo "Usage: $0 [ ] [ ]" + exit 1 +fi + +pkg_name=$1 +path=$2 + +if [ "$#" -ge 3 ]; then + sha256_value=$3 +fi +if [ "$#" -ge 4 ]; then + tag=$4 +fi + +url=$(awk -F " = " '/\['${pkg_name}'\]/{a=1}a==1&&$1~/url/{print $2;exit}' config.ini) +tag=$(awk -F " = " '/\['${pkg_name}'\]/{a=1} a==1 && $1 ~ /tag/ {print $2; exit}' 
config.ini) + +if [[ ! $url = https* ]]; then + echo "[ERROR] The URL of $pkg_name is illegal." + exit 1 +fi + +echo "[INFO] Start to download ${url}..." + +if [ ! -d "$path" ]; then + echo "[ERROR] The specified path does not exist: $path" + exit 1 +fi +cd ${path} + +extension=$(echo "${url}" | awk -F'[./]' '{print $NF}') +fullname="${path}/$(basename "${url}")" +if [[ "${extension}" == "gz" || "${extension}" == "zip" ]]; then + if [[ -e "${fullname}" ]]; then + echo "[INFO] Source ${fullname} already exists, skipping download." + else + echo "[INFO] Start downloading: ${url}" + curl -L --fail --retry 3 --connect-timeout 10 -o "${fullname}" "${url}" + if [[ $? -ne 0 ]]; then + echo "[ERROR] Download failed: ${url}" + rm -f "${fullname}" + exit 1 + fi + + filesize=$(stat -c%s "${fullname}") + if [[ "${filesize}" -lt 10240 ]]; then + echo "[ERROR] Downloaded file too small (<10KB), possible error page: ${url}" + rm -f "${fullname}" + exit 1 + fi + + if file "${fullname}" | grep -q "HTML"; then + echo "[ERROR] Downloaded file is HTML, not a zip archive." + rm -f "${fullname}" + exit 1 + fi + + echo "[INFO] Download success: ${url} (${filesize} bytes)" + fi + + if [[ ! -z "${sha256_value}" ]]; then + sha256data=$(sha256sum "${fullname}" | cut -d' ' -f1) + if [[ "${sha256data}" != "${sha256_value}" ]]; then + echo "[ERROR] SHA256 verification failed: ${url}" + echo "[ERROR] Expected: ${sha256_value}" + echo "[ERROR] Actual : ${sha256data}" + exit 1 + fi + fi + + if [[ "${extension}" == "gz" ]]; then + tar -zxvf "${fullname}" -C ./ -n > /dev/null + elif [[ "${extension}" == "zip" ]]; then + unzip -n "${fullname}" -d ./ > /dev/null + fi +elif [[ "${extension}" == "git" ]]; then + repo_dir=$(basename "${url}" .git) + + if [[ -d "${repo_dir}" ]]; then + echo "[INFO] Repository already exists: ${repo_dir}, skipping clone." + if [[ -n "${tag}" ]]; then + cd "${repo_dir}" + echo "[INFO] Checking out ${tag}..." 
+ git fetch origin + git checkout "${tag}" || { + echo "[ERROR] Failed to checkout ${tag}" + exit 1 + } + cd - + fi + else + if [[ -n "${tag}" ]]; then + git clone --progress -b "${tag}" "${url}" + else + git clone --progress "${url}" + fi + if [[ $? -eq 0 ]]; then + echo "[INFO] Clone success: ${url}" + else + echo "[ERROR] Clone failed: ${url}" + exit 1 + fi + fi +else + echo "[ERROR] Unknown url type: ${url}" + exit 1 +fi diff --git a/accuracy_tools/cmake/utils.cmake b/accuracy_tools/cmake/utils.cmake new file mode 100644 index 00000000000..95038933c25 --- /dev/null +++ b/accuracy_tools/cmake/utils.cmake @@ -0,0 +1,42 @@ +function(download_opensource_pkg pkg_name) + message("start to download ${pkg_name}...") + set(options) + set(oneValueArgs SHA256 GIT_TAG DOWNLOAD_PATH DIR_NAME BUILD_CMD) + set(multiValueArgs PATCHES) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if (NOT PKG_DOWNLOAD_PATH) + set(PKG_DOWNLOAD_PATH "${CMAKE_SOURCE_DIR}/../third_party") + endif() + file(MAKE_DIRECTORY ${PKG_DOWNLOAD_PATH}) + + execute_process( + WORKING_DIRECTORY $ENV{PROJECT_ROOT_PATH}/cmake + COMMAND bash download_opensource.sh ${pkg_name} ${PKG_DOWNLOAD_PATH} ${PKG_SHA256} ${PKG_GIT_TAG} + RESULT_VARIABLE RESULT + ) + if (NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to download ${pkg_name}(${RESULT}).") + endif() + if (PKG_BUILD_CMD) + execute_process(COMMAND bash -c "cd ${PKG_DOWNLOAD_PATH}/${DIR_NAME};${PKG_BUILD_CMD}") + + endif() +endfunction() + +function(compile_protobuf_file output_path) + if (NOT PROTOC_EXECUTABLE) + message(FATAL_ERROR "You shall install protobuf first.") + endif() + file(MAKE_DIRECTORY ${output_path}) + foreach(file ${ARGN}) + get_filename_component(abs_file_path ${file} ABSOLUTE) + get_filename_component(file_name ${file} NAME_WE) + get_filename_component(file_dir ${abs_file_path} PATH) + file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir}) + execute_process( + COMMAND 
${PROTOC_EXECUTABLE} -I${file_dir} --cpp_out=${output_path} ${abs_file_path} + ) + message("Compile protobuf file ${file}") + endforeach() +endfunction() diff --git a/accuracy_tools/docs/0001.capability_matrix.md b/accuracy_tools/docs/0001.capability_matrix.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/accuracy_tools/docs/0002.installation.md b/accuracy_tools/docs/0002.installation.md new file mode 100644 index 00000000000..75e69006579 --- /dev/null +++ b/accuracy_tools/docs/0002.installation.md @@ -0,0 +1,63 @@ +# 安装 + +## 1 依赖 + +### 1.1 硬件环境 + +[昇腾产品形态说明](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/quickstart/quickstart/quickstart_18_0002.html) + +### 1.2 软件环境 + +[固件和驱动](https://www.hiascend.com/hardware/firmware-drivers/community?product=1&model=30&cann=8.2.RC1.alpha001&driver=Ascend+HDK+25.0.RC1) + +| 框架 | 是否必选 | 版本 | +| -------------------------------------------------------------------------------------------- | -------- | ----------------------------------------------------------- | +| [Python](https://www.python.org/) | 是 | 3.7 ~ 3.12 | +| [GCC](https://gcc.gnu.org/) | 是 | 需支持 C++17 标准 | +| [git](https://git-scm.com/) | 是 | 推荐稳定版本 2.34.x - 2.42.x | +| [CANN](https://www.hiascend.cn/developer/download/community/result?module=cann)*1 | 否 | 完全兼容,根据 CPU 架构和 NPU 型号选择 toolkit 和 kernel 包 | +| [PyTorch (CPU, GPU)](https://pytorch.org/) | 否 | 1.11、2.1 ~ 2.6,对应的 Python 版本最低为 3.8 | +| [PyTorch (NPU)](https://gitee.com/ascend/pytorch) | 否 | 1.11、2.1 ~ 2.6,对应的 Python 版本最低为 3.8 | +| [MindIE-LLM](https://gitee.com/ascend/MindIE-LLM)*2 | 否 | 1.0 | +| [MindSpore](https://www.mindspore.cn/) | 否 | 2.4 ~ 2.6,对应的 Python 版本为 3.9 ~ 3.11*3 | +| [MSAdapter](https://gitee.com/mindspore/msadapter)*4 | 否 | 2.1.0 | +| [TensorFlow](https://github.com/tensorflow/tensorflow/releases/tag/v2.6.5)*5 | 否 | 仅支持 2.6.5 版本,对应的 Python 版本为 3.7 ~ 3.9 | +| [Caffe](https://caffe.berkeleyvision.org/) | 否 | 1.0,仅支持 Python 3.7 版本 | + +*1: **CANN** 
安装参见[社区资料](https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/softwareinst/instg/instg_0002.html)。 + +*2: **MindIE-LLM** 非开源,如需查看请联系该组件的华为工程师。 + +*3: **MindSpore** 历史版本参见[官网](https://www.mindspore.cn/versions)。 + +*4: **MSAdapter** 非开源,如需查看请联系该组件的华为工程师。 + +*5: **TensorFlow** 模型在 **Ascend NPU** 的迁移,还需要安装 [TF 插件](https://gitee.com/ascend/tensorflow/releases)。 + +用户可以根据使用场景自行安装适配的 Python 和其他软件包,并在使用 msprobe 前确保所依赖的框架可以正常运行。 + +## 2 安装 msprobe + +### 2.1 从源码安装 + +```sh +git clone https://gitee.com/ascend/mstt.git +cd mstt/accuracy_tools + +pip install setuptools wheel + +python setup.py bdist_wheel [--compat tf] +cd ./dist +pip install mindstudio_probe*.whl +python -c 'import os; import site; import subprocess; parent_path=site.getsitepackages()[0]; subprocess.run(["chmod", "550", os.path.join(parent_path, "msprobe"), "-R"])' +``` +**注意**:`--compat` 参数非必选,默认为无,当前支持 tf。 + + + + +# 3 查看 msprobe 工具信息 + +```sh +pip show mindstudio_probe +``` diff --git a/accuracy_tools/docs/0003.config_introduction.md b/accuracy_tools/docs/0003.config_introduction.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/accuracy_tools/docs/0004.config_examples.md b/accuracy_tools/docs/0004.config_examples.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/accuracy_tools/docs/0101.dump_offline_model.md b/accuracy_tools/docs/0101.dump_offline_model.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/accuracy_tools/docs/README.md b/accuracy_tools/docs/README.md new file mode 100644 index 00000000000..935f5a80c00 --- /dev/null +++ b/accuracy_tools/docs/README.md @@ -0,0 +1,27 @@ +# msprobe 文档编写查阅指南 + +## 1 文档编号 + +0. 公共文档:0001 - 0099 +1. 数据采集:0101 - 0199 +2. 溢出检测:0201 - 0299 +3. 精度预检:0301 - 0399 +4. 精度比对:0401 - 0499 +5. 模型改图:0501 - 0599 +6. 状态监控:0601 - 0699 +7. 数据解析:0701 - 0799 +8. 
参数检查:0801 - 0899 + +## 2 文档模板 + +```md +# 简介 + +# 接口介绍 + +# 使用示例 + +# 输出件介绍 + +# 约束 +``` diff --git a/accuracy_tools/msprobe/CMakeLists.txt b/accuracy_tools/msprobe/CMakeLists.txt new file mode 100644 index 00000000000..86735ca287f --- /dev/null +++ b/accuracy_tools/msprobe/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(csrc) diff --git a/accuracy_tools/msprobe/__init__.py b/accuracy_tools/msprobe/__init__.py new file mode 100644 index 00000000000..e24ad32c29b --- /dev/null +++ b/accuracy_tools/msprobe/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe import core +from msprobe.base import Service diff --git a/accuracy_tools/msprobe/__main__.py b/accuracy_tools/msprobe/__main__.py new file mode 100644 index 00000000000..17336619123 --- /dev/null +++ b/accuracy_tools/msprobe/__main__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.common.cli import MainCommand + + +def main(): + msprobe_command = MainCommand() + msprobe_command.register() + args = msprobe_command.parse() + msprobe_command.execute(args) + + +if __name__ == "__main__": + main() diff --git a/accuracy_tools/msprobe/base/__init__.py b/accuracy_tools/msprobe/base/__init__.py new file mode 100644 index 00000000000..f57ae818ae0 --- /dev/null +++ b/accuracy_tools/msprobe/base/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.base.cmd import BaseCommand, Command +from msprobe.base.component.manager import BaseComponent, Component, ConsumerComp, ProducerComp, Scheduler +from msprobe.base.config import SIZE_1M, BaseConfig, Dict2Class +from msprobe.base.service.manager import BaseService, Service diff --git a/accuracy_tools/msprobe/base/cmd.py b/accuracy_tools/msprobe/base/cmd.py new file mode 100644 index 00000000000..9a0022bce7d --- /dev/null +++ b/accuracy_tools/msprobe/base/cmd.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from argparse import RawTextHelpFormatter +from sys import argv + +from msprobe.utils.constants import CmdConst, MsgConst +from msprobe.utils.exceptions import MsprobeException + + +class Command: + """ + A hierarchical command registration system that supports multi-level command structures. + """ + + _cmd_map = {} # Internal storage: {parent_cmd: {name: command_class}} + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(Command, cls).__new__(cls) + return cls._instance + + @classmethod + def register(cls, parent_cmd, name): + def decorator(command_cls): + if parent_cmd not in cls._cmd_map: + cls._cmd_map[parent_cmd] = {} + cls._cmd_map[parent_cmd][name] = command_cls + return command_cls + + return decorator + + @classmethod + def get(cls, parent_cmd): + return cls._cmd_map.get(parent_cmd, {}) + + +class BaseCommand(ABC): + def __init__(self): + self.formatter_class = RawTextHelpFormatter + + @property + def service_key(self): + if isinstance(self.subcommand_level, int) and self.subcommand_level > 0: + return argv[self.subcommand_level] if len(argv) > self.subcommand_level else None + else: + raise MsprobeException(MsgConst.INVALID_ARGU, "Subcommand level must be a positive integer.") + + @abstractmethod + def add_arguments(self, parse): + pass + + def build_parser(self, parent_parser, parent_cmd_class): + if self.subcommand_level > MsgConst.MAX_RECURSION_DEPTH: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Maximum recursion depth of 
{MsgConst.MAX_RECURSION_DEPTH} exceeded." + ) + subcommands = Command.get(parent_cmd_class) + if subcommands: + self.subcommand_level += 1 + subparsers = parent_parser.add_subparsers(dest=f"L{self.subcommand_level}command") + for name, cmd_class in subcommands.items(): + cmd_parser = subparsers.add_parser( + name=name, help=CmdConst.HELP_TASK_MAP.get(name), formatter_class=self.formatter_class + ) + cmd_class.add_arguments(cmd_parser) + self.build_parser(cmd_parser, cmd_class) diff --git a/accuracy_tools/msprobe/base/component/__init__.py b/accuracy_tools/msprobe/base/component/__init__.py new file mode 100644 index 00000000000..53529bc8d31 --- /dev/null +++ b/accuracy_tools/msprobe/base/component/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accuracy_tools/msprobe/base/component/manager.py b/accuracy_tools/msprobe/base/component/manager.py new file mode 100644 index 00000000000..fdc3078e185 --- /dev/null +++ b/accuracy_tools/msprobe/base/component/manager.py @@ -0,0 +1,278 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections import deque +from threading import RLock + +from msprobe.utils.constants import MsgConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.toolkits import register + + +class BaseComponent(object): + """ + Methods that need to be implemented: + activate: Called when service.start() is invoked. + deactivate: Called when service.stop() is invoked. + """ + + def __init__(self, priority=100): + self.activated = False + self.priority = priority + + @property + def is_activated(self): + return self.activated + + def activate(self, *args, **kwargs): + pass + + def deactivate(self, *args, **kwargs): + pass + + def do_activate(self): + if self.activated: + return + self.activate() + self.activated = True + + def do_deactivate(self): + if not self.activated: + return + self.deactivate() + self.activated = False + + +class ProducerComp(BaseComponent, ABC): + """ + A ProducerComp can generate data. + If the data is passively generated (e.g., when a consumer applies the data), implement "load_data". + If the data is actively generated (e.g., when an interest event occurs), + call "publish" to send it to subscribers. + """ + + def __init__(self, priority): + super(ProducerComp, self).__init__(priority) + self.output_buffer = deque() + self.subscribers = set() + + @property + def is_ready(self): + return len(self.output_buffer) > 0 + + @abstractmethod + def load_data(self): + pass + + def publish(self, data, msg_id=0): + """ + Wrap the data and pack it into the output buffer. 
+ """ + self.output_buffer.append([self, data, msg_id]) + Scheduler().enqueue([self]) + + def on_subscribe(self, comp): + if not isinstance(comp, ConsumerComp): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Only ConsumerComp can subscribe to ProducerComp.") + self.subscribers.add(comp) + + def retrieve(self): + if self.output_buffer: + return self.output_buffer.popleft() + else: + return None + + def do_load_data(self): + if self.output_buffer: + return + data = self.load_data() + if data: + self.publish(data) + + def get_subscribers(self): + return self.subscribers + + +class ConsumerComp(BaseComponent, ABC): + """ + A ConsumerComp can consume data. + Call "subscribe" to subscribe data from a ProducerComp. + Implement "consume" to process data. + """ + + def __init__(self, priority): + super(ConsumerComp, self).__init__(priority) + self.dependencies = {} + + def subscribe(self, comp): + if not isinstance(comp, ProducerComp): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Only ProducerComp can subscribe to ConsumerComp.") + if self.is_activated: + raise MsprobeException( + MsgConst.INVALID_DATA_TYPE, f"Component {comp} must be subscribed before activation." + ) + if self.is_cycle(comp): + raise MsprobeException(MsgConst.RISK_ALERT, "Cycle dependency detected! 
Subscription denied.") + comp.on_subscribe(self) + if comp not in self.dependencies: + self.dependencies[comp] = None + + @abstractmethod + def consume(self, packages): + pass + + def is_cycle(self, comp, visited=None, stack=None): + if visited is None: + visited = set() + if stack is None: + stack = set() + if comp in stack: + return True + if comp in visited: + return False + visited.add(comp) + stack.add(comp) + if isinstance(comp, ConsumerComp): + for producer in comp.dependencies: + if self.is_cycle(producer, visited, stack): + return True + stack.remove(comp) + return False + + def on_receive(self, package): + try: + self.dependencies[package[0]] = package + except Exception as e: + raise MsprobeException( + MsgConst.PARSING_FAILED, + "The first element in the data (self.output_buffer) published by the producer must be itself.", + ) from e + + def get_empty_dependencies(self): + dependencies_list = [] + for k, v in self.dependencies.items(): + if v is None: + dependencies_list.append(k) + return dependencies_list + + def do_consume(self): + """ + Encapsulate the data in "dependencies" and invoke it using "consume". 
+ """ + if self.get_empty_dependencies(): + return + packages = [] + for key in self.dependencies: + packages.append(self.dependencies[key]) + self.dependencies[key] = None + self.consume(packages) + + +class Component: + _component_type_map = {} + + @classmethod + def register(cls, name): + return register(name, cls._component_type_map) + + @classmethod + def get(cls, name): + return cls._component_type_map.get(name) + + +class Scheduler: + _instance = None + _lock = RLock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + self.comp_ref = {} + self.queue = deque() + self.enqueued = set() + self.in_loop = False + self._initialized = True + + def add(self, components): + for comp in components: + if comp in self.comp_ref: + self.comp_ref[comp] += 1 + else: + self.comp_ref[comp] = 1 + comp.do_activate() + self.enqueue([comp]) + self.run_loop() + + def remove(self, components): + for comp in components: + if comp not in self.comp_ref: + continue + if self.comp_ref[comp] > 1: + self.comp_ref[comp] -= 1 + else: + comp.do_deactivate() + del self.comp_ref[comp] + + def enqueue(self, comps): + for comp in comps: + if comp not in self.enqueued: + self.queue.append(comp) + self.enqueued.add(comp) + + def run_loop(self): + if self.in_loop: + return + self.in_loop = True + try: + while self.queue: + comp = self.queue.popleft() + self.enqueued.remove(comp) + if isinstance(comp, ConsumerComp): + self._schedule_consumer(comp) + if isinstance(comp, ProducerComp): + self._schedule_producer(comp) + finally: + self.in_loop = False + + def _schedule_producer(self, comp: ProducerComp): + if not comp.is_ready: + return + package = comp.retrieve() + if not package: + return + subscribers = comp.get_subscribers() + if not subscribers: + return + for subscriber in subscribers: + subscriber.on_receive(package) + 
self.enqueue([subscriber]) + + def _schedule_consumer(self, comp: ConsumerComp): + dependencies = comp.get_empty_dependencies() + if not dependencies: + comp.do_consume() + self.enqueue([comp]) + return + for dependency in dependencies: + dependency.do_load_data() + if dependency.is_ready: + self.enqueue([dependency]) diff --git a/accuracy_tools/msprobe/base/config.py b/accuracy_tools/msprobe/base/config.py new file mode 100644 index 00000000000..49d57925cde --- /dev/null +++ b/accuracy_tools/msprobe/base/config.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod + +from msprobe.common.validation import ( + valid_buffer_size, + valid_framework, + valid_level, + valid_log_level, + valid_seed, + valid_step_or_rank, + valid_task, +) +from msprobe.utils.constants import CfgConst, MsgConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.io import load_json +from msprobe.utils.log import logger + +SIZE_1M = 1_048_576 # 1024 * 1024 + + +class BaseConfig(ABC): + def __init__(self, config_path, task="", framework="", step: list = None, level: list = None): + self.config_path = config_path + self.config = load_json(self.config_path) + self.task_config = {} + self.task = task + self.framework = framework + self.step = step + self.level = level + + def __getattribute__(self, name): + attr = object.__getattribute__(self, name) + if name == "check_config" and callable(attr): + + def wrapper(*args, **kwargs): + self._common_check() + self._get_task_dict() + result = attr(*args, **kwargs) + return result + + return wrapper + return attr + + @abstractmethod + def check_config(self): + pass + + def _get_task_dict(self): + self.task_config = self.config.get(self.config.get(CfgConst.TASK)) + if not self.task_config: + raise MsprobeException( + MsgConst.REQUIRED_ARGU_MISSING, f'Missing dictionary for key "{self.config.get(CfgConst.TASK)}".' 
+ ) + + def _common_check(self): + logger.info("Validating configuration file parameters.") + self._update_config(self.config, CfgConst.TASK, valid_task, self.task or self.config.get(CfgConst.TASK, None)) + self._update_config( + self.config, CfgConst.FRAMEWORK, valid_framework, self.framework or self.config.get(CfgConst.FRAMEWORK, None) + ) + self._update_config( + self.config, CfgConst.STEP, valid_step_or_rank, self.step or self.config.get(CfgConst.STEP, []) + ) + self._update_config(self.config, CfgConst.RANK, valid_step_or_rank, self.config.get(CfgConst.RANK, [])) + self._update_config( + self.config, + CfgConst.LEVEL, + valid_level, + self.level or self.config.get(CfgConst.LEVEL, [CfgConst.LEVEL_API]), + ) + self._update_config( + self.config, CfgConst.LOG_LEVEL, valid_log_level, self.config.get(CfgConst.LOG_LEVEL, "info") + ) + self._update_config(self.config, CfgConst.SEED, valid_seed, self.config.get(CfgConst.SEED, None)) + self._update_config( + self.config, CfgConst.BUFFER_SIZE, valid_buffer_size, self.config.get(CfgConst.BUFFER_SIZE, SIZE_1M) + ) + + def _update_config(self, dic: dict, key: str, check_fun, value: str): + dic[key] = check_fun(value) + + +class Dict2Class: + def __init__(self, data: dict, depth: int = 0): + if depth > MsgConst.MAX_RECURSION_DEPTH: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Maximum recursion depth of {MsgConst.MAX_RECURSION_DEPTH} exceeded." 
+ ) + if data.get(CfgConst.TASK) in data: + data_pop = data.pop(data.get(CfgConst.TASK)) + for key, value in data_pop.items(): + if key == "input" and len(value) == 2: + setattr(self, "input_shape", value[0]) + setattr(self, "input_path", value[1]) + setattr(self, key, value) + for key, value in data.items(): + if isinstance(value, dict): + setattr(self, key, Dict2Class(value, depth + 1)) + else: + setattr(self, key, value) + + @classmethod + def __getattr__(cls, item): + raise MsprobeException(MsgConst.ATTRIBUTE_ERROR, f"{cls.__name__} object has no attribute {item}.") diff --git a/accuracy_tools/msprobe/base/service/__init__.py b/accuracy_tools/msprobe/base/service/__init__.py new file mode 100644 index 00000000000..53529bc8d31 --- /dev/null +++ b/accuracy_tools/msprobe/base/service/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accuracy_tools/msprobe/base/service/manager.py b/accuracy_tools/msprobe/base/service/manager.py new file mode 100644 index 00000000000..756f3609811 --- /dev/null +++ b/accuracy_tools/msprobe/base/service/manager.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
from abc import ABC, abstractmethod

from msprobe.base import BaseComponent, Scheduler
from msprobe.common.validation import valid_task
from msprobe.utils.constants import CfgConst, CmdConst
from msprobe.utils.io import load_json
from msprobe.utils.toolkits import get_current_rank, register

# Tasks that are ultimately handled by the dump service.
_TASK_SERVICE_MAP = {CfgConst.TASK_STAT: CmdConst.DUMP, CfgConst.TASK_TENSOR: CmdConst.DUMP}


class Service:
    """Factory/facade: resolve a registered service class by name (or by the
    task declared in a config file) and proxy all attribute access to the
    constructed instance."""

    _services_map = {}

    def __init__(self, *args, **kwargs):
        namespace = kwargs.get("cmd_namespace")
        resolved_name = kwargs.get("serv_name")
        # When a config path is present, the task decides which service runs.
        if hasattr(namespace, CfgConst.CONFIG_PATH):
            explicit_task = kwargs.get(CfgConst.TASK)
            if explicit_task:
                task = valid_task(explicit_task)
            else:
                task = valid_task(load_json(namespace.config_path).get(CfgConst.TASK))
            resolved_name = _TASK_SERVICE_MAP.get(task)
        self.service_class = self.get(resolved_name)
        self.service_instance = self.service_class(*args, **kwargs)

    def __getattr__(self, name):
        # Delegate anything not found on the facade to the wrapped service.
        return getattr(self.service_instance, name)

    @classmethod
    def register(cls, name):
        return register(name, cls._services_map)

    @classmethod
    def get(cls, name):
        return cls._services_map.get(name)


class BaseService(ABC):
    """Skeleton of a service lifecycle: start() -> step()* -> stop()."""

    def __init__(self):
        self.comps = []
        self.current_step = 0
        self.scheduler = Scheduler()

    @property
    def is_skip(self):
        # Subclasses may override to short-circuit the whole lifecycle.
        return False

    @property
    def current_rank(self):
        try:
            return int(get_current_rank())
        except Exception:
            return None

    @abstractmethod
    def construct(self):
        pass

    def start(self, *args, **kwargs):
        """
        Service startup workflow:
        1. Configure services (init_start).
        2. Build components (construct).
        3. Filter/prioritize components (ignore_actuator), then schedule execution.
        4. Schedule execution and cleanup.
        5. Post-processing (finalize_start).
        """
        if self.is_skip:
            return
        self.init_start()
        self.construct()
        # Pick up every component attribute created by construct().
        for candidate in self.__dict__.values():
            if isinstance(candidate, BaseComponent) and candidate not in self.comps:
                self.comps.append(candidate)
                self.ignore_actuator(candidate)
        self.comps.sort(key=lambda comp: comp.priority)
        self.scheduler.add(self.comps)
        self.finalize_start()

    def init_start(self):
        pass

    def ignore_actuator(self, attr):
        pass

    def finalize_start(self):
        pass

    def step(self, *args, **kwargs):
        if self.is_skip:
            return
        self.current_step += 1

    def stop(self, *args, **kwargs):
        if self.is_skip:
            return
        self.scheduler.remove(self.comps)
        self.comps.clear()
from msprobe.utils.constants import MsgConst, PathConst
from msprobe.utils.env import evars
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger
from msprobe.utils.path import SafePath, SoftLinkLevel, join_path
from msprobe.utils.toolkits import run_subprocess

_ENVVAR_ASCEND_TOOLKIT_HOME = "ASCEND_TOOLKIT_HOME"
_DEFAULT_ASCEND_TOOLKIT_HOME = "/usr/local/Ascend/ascend-toolkit/latest"
_ENVVAR_ATB_HOME_PATH = "ATB_HOME_PATH"
_SUFFIX_CONVERT_MODEL = (".om", ".txt")
_ATC_BIN_PATH = "compiler/bin/atc"
_OLD_ATC_BIN_PATH = "atc/bin/atc"
_ATC_MODE_OM2JSON = "1"
_ATC_MODE_GETXT2JSON = "5"


class CANN:
    """Singleton helper around a local CANN toolkit installation (atc, ATB)."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Classic lazy singleton: the first construction wins.
        if not cls._instance:
            cls._instance = super(CANN, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        self.cann_home = evars.get(_ENVVAR_ASCEND_TOOLKIT_HOME, _DEFAULT_ASCEND_TOOLKIT_HOME)

    @property
    def lib_atb_path(self):
        """Validated path of libatb.so, located via ATB_HOME_PATH."""
        atb_home = evars.get(_ENVVAR_ATB_HOME_PATH)
        candidate = SafePath(join_path(atb_home, "lib", "libatb.so"), PathConst.FILE, "r", PathConst.SIZE_20M)
        return candidate.check(soft_link_level=SoftLinkLevel.IGNORE)

    @property
    def probe_symbols(self):
        """Exported (type 'T') symbols of libatb.so whose name contains 'Probe'."""
        nm_output = run_subprocess(["nm", "-D", self.lib_atb_path], capture_output=True)
        found = []
        for raw_line in (nm_output or "").splitlines():
            fields = raw_line.strip().split()
            # nm -D rows of interest look like: "<addr> <type> <name>".
            if len(fields) != 3:
                continue
            sym_type, sym_name = fields[1], fields[2]
            if sym_type == "T" and "Probe" in sym_name:
                found.append(sym_name)
        return found

    def model2json(self, model_path: str, json_path: str):
        """Convert an .om/.txt offline model into a JSON dump via atc."""
        model_path = SafePath(model_path, PathConst.FILE, "r", PathConst.SIZE_30G, _SUFFIX_CONVERT_MODEL).check()
        json_path = SafePath(json_path, PathConst.FILE, "w", suffix=".json").check(path_exist=False)
        atc = self._get_atc_path()
        # atc mode 1 handles .om models; mode 5 handles GE text models.
        mode_type = _ATC_MODE_OM2JSON if model_path.endswith(".om") else _ATC_MODE_GETXT2JSON
        logger.info("Start converting the model format to JSON.")
        run_subprocess([atc, "--mode=" + mode_type, "--om=" + model_path, "--json=" + json_path])
        logger.info(f"The model has been converted to a JSON file, located at {json_path}.")

    def _get_atc_path(self):
        """Locate the atc binary, falling back to the legacy install layout."""
        try:
            return SafePath(
                join_path(self.cann_home, _ATC_BIN_PATH), PathConst.FILE, "e", PathConst.SIZE_20M
            ).check(soft_link_level=SoftLinkLevel.IGNORE)
        except Exception as e1:
            logger.error(str(e1))
        try:
            return SafePath(
                join_path(self.cann_home, _OLD_ATC_BIN_PATH), PathConst.FILE, "e", PathConst.SIZE_20M
            ).check(soft_link_level=SoftLinkLevel.IGNORE)
        except Exception as e2:
            raise MsprobeException(MsgConst.CANN_FAILED) from e2


cann = CANN()
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from pathlib import Path +from sys import argv + +from msprobe.base import BaseCommand, Command, Service +from msprobe.utils.constants import CfgConst, CmdConst, MsgConst, PathConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.path import SafePath +from msprobe.utils.toolkits import run_subprocess, set_ld_preload + +_DESCRIPTION = """ + _ + _ __ ___ ___ _ __ _ __ ___ | |__ ___ + | '_ ` _ \/ __| '_ \| '__/ _ \| '_ \ / _ \ + | | | | | \__ \ |_) | | | (_) | |_) | __/ + |_| |_| |_|___/ .__/|_| \___/|_.__/ \___| + |_| + +msprobe (MindStudio-Probe), [Powered by MindStudio]. +A set of tools for diagnosing and improving model accuracy on Ascend NPU, +including API accuracy, args checker, grad tool etc. 
_L2COMMAND = "L2command"
_ROOT_LEVEL = 1
# Bug fix: constant was misspelled `_SECEND_LEVEL`; module-private, so the
# rename is invisible to callers.
_SECOND_LEVEL = 2


class MainCommand(BaseCommand):
    """Top-level `msprobe` CLI: registers second-level subcommands and routes
    a parsed namespace to the matching service."""

    def __init__(self):
        super().__init__()
        self.parser = ArgumentParser(prog="msprobe", description=_DESCRIPTION, formatter_class=self.formatter_class)
        self.subparser = self.parser.add_subparsers(dest=_L2COMMAND)
        self.second_commands = Command.get("msprobe")
        self.subcommand_level = _ROOT_LEVEL

    @property
    def _msprobe_so_path(self):
        """Validated path of the bundled msprobe_c.so next to this package."""
        current_file = Path(__file__).resolve()
        lib_path = str(current_file.parent.parent / "lib" / "msprobe_c.so")
        return SafePath(lib_path, PathConst.FILE, "r", PathConst.SIZE_500M, ".so").check()

    def add_arguments(self, parse):
        pass

    def register(self):
        """Attach every registered second-level command as a subparser."""
        for name, cmd_class in self.second_commands.items():
            cmd_parser = self.subparser.add_parser(
                name=name, help=CmdConst.HELP_SERVICE_MAP.get(name), formatter_class=self.formatter_class
            )
            # NOTE(review): testing `self.service_key in self.second_commands`
            # inside a loop over names looks like it may have been intended as
            # `name` — confirm against BaseCommand.service_key semantics.
            if self.service_key in self.second_commands:
                cmd_class.add_arguments(cmd_parser)
                self.subcommand_level = _SECOND_LEVEL
            self.build_parser(cmd_parser, cmd_class)

    def parse(self):
        return self.parser.parse_args()

    def execute(self, args):
        """Dispatch the parsed namespace to its service (or re-exec the target)."""
        if len(argv) <= self.subcommand_level:
            self.parser.print_help()
            return
        serv_name = argv[self.subcommand_level - 1]
        if Service.get(serv_name):
            logger.info(f"Preparing to launch {serv_name} service.")
            if args.framework:
                self._set_env(args.framework)
            if not args.msprobex:
                serv = Service(cmd_namespace=args, serv_name=serv_name)
                serv.run_cli()
            else:
                run_subprocess(args.exec)
        else:
            raise MsprobeException(MsgConst.CALL_FAILED, f"The {serv_name} service is not registered. Please check it.")

    def _set_env(self, framework):
        """Run the framework-specific environment hook, if one exists."""
        env_func = {CfgConst.FRAMEWORK_MINDIE_LLM: self._set_mindie_llm_env}
        frame_init = env_func.get(framework)
        if frame_init:
            frame_init()
        else:
            raise MsprobeException(MsgConst.CALL_FAILED, f"The {framework} framework is not supported.")

    def _set_mindie_llm_env(self):
        # Inject the probe shared object via LD_PRELOAD for MindIE-LLM.
        set_ld_preload(self._msprobe_so_path)
from datetime import datetime

from msprobe.utils.log import logger
from msprobe.utils.path import DirSafeHandler
from msprobe.utils.toolkits import get_current_rank, get_current_timestamp, timestamp_sync


class DirPool:
    """Lazily-created output directory tree:
    msprobe_<ts>/stepN/rankM/{input, dump_tensor_data}.

    The msprobe root and model dir are process-wide (class attributes); step,
    rank, input and tensor dirs are per-instance.
    """

    msprobe_path = None
    model_dir = None

    def __init__(self):
        self.step_dir = None
        self.rank_dir = None
        self.input_dir = None
        self.tensor_dir = None

    @classmethod
    def make_msprobe_dir(cls, path: str):
        # timestamp_sync keeps the directory name identical across ranks.
        timestamp = get_current_timestamp(microsecond=False)
        timestamp = timestamp_sync(timestamp)
        formatted_date = datetime.fromtimestamp(timestamp).strftime("%Y%m%d_%H%M%S")
        cls.msprobe_path = DirSafeHandler.join_and_create(path, f"msprobe_{formatted_date}/")

    @classmethod
    def get_msprobe_dir(cls):
        return DirSafeHandler.get_or_raise(cls.msprobe_path, "Dump dir has not been set.")

    @classmethod
    def make_model_dir(cls):
        cls.model_dir = DirSafeHandler.join_and_create(cls.get_msprobe_dir(), "model")

    @classmethod
    def get_model_dir(cls):
        return DirSafeHandler.get_or_raise(cls.model_dir, "Model dir has not been set.")

    def make_step_dir(self, current_step: int):
        self.step_dir = DirSafeHandler.join_and_create(self.get_msprobe_dir(), f"step{current_step}")
        # Bug fix: this log lived in get_step_dir(), claiming a "switch" on
        # every read (even while step_dir was still None). Log it once, here,
        # where the switch actually happens.
        logger.info(f"Step dir has switched to {self.step_dir}.")

    def get_step_dir(self):
        return DirSafeHandler.get_or_raise(self.step_dir, "Step dir has not been set.")

    def make_rank_dir(self):
        self.rank_dir = DirSafeHandler.join_and_create(self.get_step_dir(), f"rank{get_current_rank()}")

    def get_rank_dir(self):
        return DirSafeHandler.get_or_raise(self.rank_dir, "Rank dir has not been set.")

    def make_input_dir(self):
        self.input_dir = DirSafeHandler.join_and_create(self.get_rank_dir(), "input")

    def get_input_dir(self):
        return DirSafeHandler.get_or_raise(self.input_dir, "Input dir has not been set.")

    def make_tensor_dir(self):
        self.tensor_dir = DirSafeHandler.join_and_create(self.get_rank_dir(), "dump_tensor_data")

    def get_tensor_dir(self):
        return DirSafeHandler.get_or_raise(self.tensor_dir, "dump_tensor_data dir has not been set.")
from zlib import crc32

import numpy as np

from msprobe.utils.constants import DumpConst, MsgConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger
from msprobe.utils.toolkits import safely_compute


class DataStat:
    """Per-tensor summary statistics (dtype/shape/min/max/mean/norm, optional
    CRC32 digest). Each getter is wrapped by safely_compute so one failing
    statistic does not abort the whole collection."""

    @staticmethod
    def get_valid_type(np_data):
        try:
            klass = np_data.__class__
            return f"{klass.__module__}.{klass.__name__}"
        except Exception:
            logger.warning(f"Unrecognized type pattern: {type(np_data)}.")
            return None

    @staticmethod
    @safely_compute
    def get_dtype(npy):
        return npy.dtype

    @staticmethod
    @safely_compute
    def get_shape(npy):
        return npy.shape

    @staticmethod
    @safely_compute
    def get_max(npy):
        return float(npy.max())

    @staticmethod
    @safely_compute
    def get_min(npy):
        return float(npy.min())

    @staticmethod
    @safely_compute
    def get_mean(npy):
        return float(npy.mean())

    @staticmethod
    @safely_compute
    def get_norm(npy):
        return float(np.linalg.norm(npy))

    @staticmethod
    @safely_compute
    def get_crc32_hash(npy):
        # Zero-padded 8-hex-digit CRC32 of the raw tensor bytes.
        return f"{crc32(npy.tobytes()):08x}"

    @classmethod
    def collect_stats_for_numpy(cls, npy: np.ndarray, summary_mode: str):
        """Collect the full statistics dict for one array-like value."""
        try:
            npy = np.asarray(npy)
        except Exception as e:
            raise MsprobeException(MsgConst.CONVERSION_FAILED, "Failed to convert to numpy array.") from e
        stat_dict = {
            "type": cls.get_valid_type(npy),
            "dtype": cls.get_dtype(npy),
            "shape": cls.get_shape(npy),
            "Max": cls.get_max(npy),
            "Min": cls.get_min(npy),
            "Mean": cls.get_mean(npy),
            "Norm": cls.get_norm(npy),
        }
        if summary_mode == DumpConst.SUMMARY_MD5:
            stat_dict["md5"] = cls.get_crc32_hash(npy)
        return stat_dict
import re
from argparse import Action

from msprobe.utils.constants import CfgConst, MsgConst, PathConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import LOG_LEVEL
from msprobe.utils.path import SafePath, is_dir, is_file
from msprobe.utils.toolkits import check_int_border

# Accepts "start-end" or "start-end-step", digits only.
_HYPHEN_NUM_PATTERN = r"^(?:\d+-\d+|\d+-\d+-\d+)$"


def valid_task(value: str):
    """Validate the `task` field: must be a string from CfgConst.ALL_TASK."""
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"task" must be a string.')
    if value not in CfgConst.ALL_TASK:
        raise MsprobeException(MsgConst.INVALID_ARGU, f'"task" must be one of {CfgConst.ALL_TASK}, currently: {value}.')
    return value


def _valid_suffix_for_exec(value: str, extension: str, error_msg: str):
    """Check the script suffix, then path-validate it."""
    try:
        if not value.endswith(extension):
            raise MsprobeException(MsgConst.INVALID_ARGU, error_msg)
    except Exception as e:
        raise MsprobeException(MsgConst.PARSING_FAILED) from e
    _ = SafePath(value, PathConst.FILE, "r", PathConst.SIZE_30G).check()


def valid_exec(values: str):
    """Validate the `--exec` command string; returns it split on spaces."""
    if values is None:
        return values
    if not isinstance(values, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"exec" must be a string.')
    values = values.split(" ")
    first_keyword = values[0]
    # Robustness fix: "bash"/"python" with no following script previously
    # escaped as a raw IndexError.
    if first_keyword in {"bash", "python", "python3"} and len(values) < 2:
        raise MsprobeException(MsgConst.INVALID_ARGU, f"Please check the `--exec (-e)`, currently: {values}.")
    if is_dir(first_keyword):
        _ = SafePath(first_keyword, PathConst.DIR, "r", PathConst.SIZE_50G).check()
    elif first_keyword == "bash":
        _valid_suffix_for_exec(values[1], ".sh", "The interpreter must start with bash when the script ends with .sh.")
    elif first_keyword in {"python", "python3"}:
        _valid_suffix_for_exec(
            values[1], ".py", "The interpreter must start with python, python3 when the script ends with .py."
        )
    elif is_file(first_keyword):
        _valid_suffix_for_exec(
            first_keyword,
            (PathConst.SUFFIX_OFFLINE_MODEL + PathConst.SUFFIX_ONLINE_SCRIPT),
            "A single readable or executable file must end with "
            f"{PathConst.SUFFIX_OFFLINE_MODEL + PathConst.SUFFIX_ONLINE_SCRIPT}.",
        )
    else:
        raise MsprobeException(MsgConst.INVALID_ARGU, f"Please check the `--exec (-e)`, currently: {values}.")
    return values


class CheckExec(Action):
    def __call__(self, parser, namespace, values, option_string=None):
        values = valid_exec(values)
        setattr(namespace, self.dest, values)


def valid_config_path(value: str):
    """Validate that the config path is a readable, size-bounded .json file."""
    return SafePath(value, PathConst.FILE, "r", PathConst.SIZE_2G, ".json").check()


class CheckConfigPath(Action):
    def __call__(self, parser, namespace, values, option_string=None):
        values = valid_config_path(values)
        setattr(namespace, self.dest, values)


def valid_framework(value: str):
    """Validate the `framework` field (empty values pass through)."""
    if not value:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"framework" must be a string.')
    if value not in CfgConst.ALL_FRAMEWORK:
        raise MsprobeException(
            MsgConst.INVALID_ARGU, f'"framework" must be one of {CfgConst.ALL_FRAMEWORK}, currently: {value}.'
        )
    return value


class CheckFramework(Action):
    def __call__(self, parser, namespace, values, option_string=None):
        values = valid_framework(values)
        setattr(namespace, self.dest, values)


def parse_hyphen(element, tag=None):
    """Expand "start-end" or "start-end-step" into the list of integers it denotes."""
    if not re.match(_HYPHEN_NUM_PATTERN, element):
        msg = 'Only accepts numbers or a range like "123-456", "123-456-2".'
        if tag:
            msg += f" Context: {tag}."
        raise MsprobeException(MsgConst.INVALID_ARGU, msg)
    split_ele = element.split("-")
    start = int(split_ele[0])
    end = int(split_ele[1])
    check_int_border(start, end, tag="Hyphen-connected integer")
    if start > end:
        msg = f"The left value must be smaller than the right, currently: {start} v.s. {end}."
        # Bug fix: this guard was `if msg:`, which is always truthy and
        # appended " Context: None." when no tag was supplied.
        if tag:
            msg += f" Context: {tag}."
        raise MsprobeException(MsgConst.INVALID_ARGU, msg)
    step = int(split_ele[2]) if len(split_ele) == 3 else 1
    return list(range(start, end + 1, step))


def valid_step_or_rank(values: list):
    """Validate/normalize a step or rank list: expand ranges, dedupe, sort."""
    if not values:
        return values
    if not isinstance(values, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"rank" or "step" must be a list.')
    res = []
    for element in values:
        if isinstance(element, str):
            # Typo fix: the context tag read "strp or rank".
            res.extend(parse_hyphen(element, tag="step or rank"))
        elif isinstance(element, int):
            check_int_border(element, tag="Element in the 'rank' or 'step' list")
            res.append(element)
        else:
            raise MsprobeException(
                MsgConst.INVALID_DATA_TYPE, 'Elements in the "rank" or "step" support only strings and integers.'
            )
    return sorted(set(res))


def valid_level(values: list):
    """Validate that every requested dump level is a known level."""
    if not values:
        return values
    if not isinstance(values, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"level" must be a list.')
    for value in values:
        if value not in CfgConst.ALL_LEVEL:
            raise MsprobeException(
                MsgConst.INVALID_ARGU, f'"level" must be one of {CfgConst.ALL_LEVEL}, currently: {value}.'
            )
    return values


def valid_log_level(value: str):
    """Validate the log level against the lower-cased LOG_LEVEL set."""
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"log_level" must be a string.')
    log_level = {level.lower() for level in LOG_LEVEL}
    if value not in log_level:
        raise MsprobeException(MsgConst.INVALID_ARGU, f'"log_level" must be one of {log_level}, currently: {value}.')
    return value


def valid_seed(value: int):
    """Validate the random seed (None passes through)."""
    if value is None:
        return value
    check_int_border(value, tag="seed number")
    return value


def valid_buffer_size(value: int):
    """Validate the write-buffer size (None passes through)."""
    if value is None:
        return value
    check_int_border(value, tag="buffer size")
    return value
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.base.dump_actuator import OfflineModelActuator +from msprobe.core.base.dump_dumper import BaseDumper +from msprobe.core.base.dump_writer import RankDirFile, SaveBinTensor, SaveNpyTensor, SaveTensor diff --git a/accuracy_tools/msprobe/core/base/dump_actuator.py b/accuracy_tools/msprobe/core/base/dump_actuator.py new file mode 100644 index 00000000000..e399b0216c1 --- /dev/null +++ b/accuracy_tools/msprobe/core/base/dump_actuator.py @@ -0,0 +1,191 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import numpy as np

from msprobe.utils.constants import MsgConst
from msprobe.utils.dependencies import dependent
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.io import load_bin_data, load_npy, save_npy
from msprobe.utils.log import logger
from msprobe.utils.path import join_path


def get_tf_type2dtype_map():
    """Map TensorFlow dtypes to numpy dtypes; empty dict when TF is unavailable."""
    pons = dependent.get_tensorflow()
    if None in pons:
        return {}
    tf, _, _ = pons
    return {
        tf.float16: np.float16,
        tf.float32: np.float32,
        tf.float64: np.float64,
        tf.int8: np.int8,
        tf.int16: np.int16,
        tf.int32: np.int32,
        tf.int64: np.int64,
    }


class OfflineModelActuator:
    """Prepares input tensors for running an offline model: fixes dynamic
    shapes from user-supplied `input_shape`, and either loads user-provided
    input files or generates (and saves) random inputs."""

    def __init__(self, model_path: str, input_shape: dict, input_path: str, **kwargs):
        self.model_path = model_path
        self.input_shape = input_shape or {}
        # NOTE(review): _read_input_data zips over this value, so it is
        # expected to be a sequence of file paths, not one path string —
        # a bare string would be iterated character by character; confirm
        # against the "input" config entry.
        self.input_path = input_path or ""
        self.kwargs = kwargs
        self.dir_pool = kwargs.get("dir_pool")

    @staticmethod
    def _is_dynamic_shape(tensor_shape):
        # Any None / non-int dimension marks the shape as dynamic.
        for shape in tensor_shape:
            if shape is None or not isinstance(shape, int):
                return True
        return False

    @staticmethod
    def _tensor2numpy_for_type(tensor_type):
        """Resolve a framework tensor-type tag to the matching numpy dtype."""
        base_type2dtype_map = {
            "tensor(int)": np.int32,
            "tensor(int8)": np.int8,
            "tensor(int16)": np.int16,
            "tensor(int32)": np.int32,
            "tensor(int64)": np.int64,
            "tensor(uint8)": np.uint8,
            "tensor(uint16)": np.uint16,
            "tensor(uint32)": np.uint32,
            "tensor(uint64)": np.uint64,
            "tensor(float)": np.float32,
            "tensor(float16)": np.float16,
            "tensor(double)": np.double,
            "tensor(bool)": np.bool_,
            "tensor(complex64)": np.complex64,
            # Bug fix: np.complex_ was removed in NumPy 2.0; np.complex128 is
            # the identical dtype and works on both NumPy 1.x and 2.x.
            "tensor(complex128)": np.complex128,
            "float32": np.float32,
            "float16": np.float16,
        }
        numpy_data_type = {**base_type2dtype_map, **get_tf_type2dtype_map()}.get(tensor_type)
        if numpy_data_type:
            return numpy_data_type
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"Tensor type {tensor_type} not provided.")

    @staticmethod
    def _generate_random_input_data(save_dir, names, shapes, dtypes, is_byte_data=False):
        """Generate random inputs per tensor spec and persist each as .npy."""
        input_map = {}
        for tensor_name, tensor_shape, tensor_dtype in zip(names, shapes, dtypes):
            if is_byte_data:
                input_data = np.random.randint(0, 256, int(np.prod(tensor_shape))).astype(np.uint8)
            else:
                input_data = np.random.random(tensor_shape).astype(tensor_dtype)
            input_map[tensor_name] = input_data
            shape_str = "_".join(list(map(str, tensor_shape)))
            file_name = "_".join([tensor_name, "shape", shape_str, ".npy"])
            save_npy(input_data, join_path(save_dir, file_name))
            logger.info(
                f"Save input file path: {join_path(save_dir, file_name)}, "
                f"shape: {input_data.shape}, dtype: {input_data.dtype}."
            )
        return input_map

    @staticmethod
    def _read_input_data(input_paths, names, shapes, dtypes, is_byte_data=False):
        """Load .bin/.npy input files, checking each against the model shape."""
        input_map = {}
        for input_path, name, shape, dtype in zip(input_paths, names, shapes, dtypes):
            if input_path.endswith(".bin"):
                input_data = load_bin_data(input_path, dtype, shape, is_byte_data)
            elif input_path.endswith(".npy"):
                input_data = load_npy(input_path)
                if np.prod(input_data.shape) != np.prod(shape) and not is_byte_data:
                    raise MsprobeException(
                        MsgConst.INVALID_ARGU,
                        "The shape of the input data does not match the model's shape, "
                        f"input path: {input_path}, input shape: {input_data.shape}, "
                        f"model's shape: {shape}.",
                    )
                if not is_byte_data:
                    input_data = input_data.reshape(shape)
            input_map[name] = input_data
            logger.info(f"Load input file path: {input_path}, shape: {input_data.shape}, dtype: {input_data.dtype}.")
        return input_map

    @staticmethod
    def _check_input_shape(op_name, model_shape, input_shape):
        """Validate a user-supplied shape against the model-declared shape."""
        if not input_shape:
            raise MsprobeException(
                MsgConst.REQUIRED_ARGU_MISSING,
                f"{op_name}'s input_shape is missing. "
                f'Please set `shape: [xxx]` in "input" according to {model_shape}.',
            )
        if len(model_shape) != len(input_shape):
            raise MsprobeException(
                MsgConst.INVALID_ARGU,
                f"Unequal lengths for the shape of {op_name}. "
                f"Model shape: {model_shape}, input shape: {input_shape}.",
            )
        for index, value in enumerate(model_shape):
            # None/str dims are dynamic — the user value fixes them freely.
            if value is None or isinstance(value, str):
                continue
            if input_shape[index] != value:
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    "The input shape does not match the model shape. "
                    f"Tensor name: {op_name}, {str(input_shape)} v.s. {str(model_shape)}.",
                )

    @classmethod
    def _get_input_shape_info(cls, tensor_name, tensor_shape, input_shape, tensor_type):
        cls._check_input_shape(tensor_name, tensor_shape, input_shape)
        tensor_shape_info = {"name": tensor_name, "shape": input_shape, "type": tensor_type}
        logger.info(f"The dynamic shape of {tensor_name} has been fixed to {input_shape}.")
        return tensor_shape_info

    def get_inputs_data(self, inputs_tensor_info, is_byte_data=False):
        """Return a name->ndarray map, reading user files or generating inputs."""
        names, shapes, dtypes = [], [], []
        for x in inputs_tensor_info:
            names.append(x["name"])
            shapes.append(x["shape"])
            # read raw byte data (memory) regardless of type; defaults to int8.
            dtypes.append(self._tensor2numpy_for_type(x["type"]) if not is_byte_data else np.int8)
        if not self.input_path:
            self.dir_pool.make_input_dir()
            input_map = self._generate_random_input_data(
                self.dir_pool.get_input_dir(), names, shapes, dtypes, is_byte_data
            )
        else:
            input_map = self._read_input_data(self.input_path, names, shapes, dtypes, is_byte_data)
        return input_map

    def process_tensor_shape(self, tensor_name, tensor_type, tensor_shape):
        """Resolve one tensor's final shape, fixing dynamic dims from config."""
        tensor_shape_info_list = []
        if self._is_dynamic_shape(tensor_shape):
            if not self.input_shape:
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    f"The dynamic shape {tensor_shape} are not supported. Please "
                    f'set "shape" of {tensor_name} in "input" to fix the dynamic shape.',
                )
            if tensor_name not in self.input_shape:
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    f'{tensor_name} has a dynamic shape, but its shape is not defined in the "input".',
                )
        if self.input_shape:
            inshape = self.input_shape.get(tensor_name)
            tensor_shape_info = self._get_input_shape_info(tensor_name, tensor_shape, inshape, tensor_type)
            tensor_shape_info_list.append(tensor_shape_info)
        else:
            tensor_shape_info = {"name": tensor_name, "shape": tensor_shape, "type": tensor_type}
            tensor_shape_info_list.append(tensor_shape_info)
        return tensor_shape_info_list
class BaseDumper(ABC):
    """Abstract base for framework-specific dump producers.

    Holds the shared capture state (input/output tensor maps, hook handles,
    the lazy dump-record iterator) and the common record layout produced by
    ``through_nodes``. Subclasses implement ``register_hook`` to install the
    framework hooks that populate the maps.
    """

    def __init__(self, data_mode):
        # Which tensor kinds to record (e.g. inputs, outputs, or all).
        self.data_mode = data_mode
        self.data_for_save = {}
        self.input_map = {}   # arg name -> captured input tensor
        self.output_map = {}  # arg name -> captured output tensor
        self.handler = []     # hook handles, released in release_hook()
        self._data_iter = None

    @staticmethod
    def through_nodes(nodes, node_name, in_or_out, data_map):
        """Yield one record ``[node, direction, arg, index, data]`` per node entry.

        Each entry in ``nodes`` is either a plain string or an object exposing
        a ``name`` attribute; anything else is rejected. Missing tensors yield
        ``None`` as the data element.
        """
        for idx, entry in enumerate(nodes):
            if isinstance(entry, str):
                arg_name = entry
            elif hasattr(entry, "name"):
                arg_name = entry.name
            else:
                raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Unsupported node type.")
            yield [node_name, in_or_out, arg_name, idx, data_map.get(arg_name)]

    @abstractmethod
    def register_hook(self):
        """Install the framework-specific capture hooks."""
        pass

    def release_hook(self):
        """Detach every hook installed by register_hook()."""
        for hook_handle in self.handler:
            release(hook_handle)
class RankDirFile(ABC):
    """Buffered per-rank sink: accumulates entries in memory, flushing via _save().

    ``cover`` tracks an approximate cached size and flushes once it crosses
    ``max_cache_size``; ``clear_cache`` performs the final flush.
    """

    def __init__(self, buffer_size):
        self.max_cache_size = buffer_size  # flush threshold in bytes
        self.cache_file = None             # subclass-managed cache object (often a dict)
        self.cache_file_size = 0           # approximate cached bytes (sys.getsizeof-based)
        self.rank_dir = None

    @abstractmethod
    def _save(self):
        """Persist the current cache to disk (subclass responsibility)."""
        pass

    def add_rank_dir(self, rank_dir):
        self.rank_dir = rank_dir

    def cover(self, data):
        """Account for one entry; flush when the running size reaches the limit.

        Note: sys.getsizeof is shallow, so the tracked size is approximate.
        """
        self.cache_file_size += sys.getsizeof(data)
        if self.cache_file_size >= self.max_cache_size:
            self._save()
            self.cache_file_size = 0

    def clear_cache(self):
        """Flush any remaining cached data and empty the cache container."""
        if self.cache_file_size:
            self._save()
            self.cache_file_size = 0
            if isinstance(self.cache_file, dict):
                self.cache_file.clear()


class SaveTensorStrategy(ABC):
    """Strategy interface: write one tensor to a timestamped file in tensor_dir."""

    def __init__(self):
        self.tensor_dir = None
        self.tensor_path = None  # path of the most recently written tensor
        self.suffix = None       # file extension, set by concrete strategies

    @abstractmethod
    def _save(self, data):
        """Write ``data`` to ``self.tensor_path`` (subclass responsibility)."""
        pass

    def add_tensor_dir(self, tensor_dir):
        self.tensor_dir = tensor_dir

    def save_tensor_data(self, node_name, args_name, data):
        """Build a unique file name under tensor_dir and write ``data`` there."""
        target_name = self._generate_file_name(node_name, args_name)
        self.tensor_path = self._generate_path(target_name)
        self._save(data)

    def _generate_path(self, tensor_name):
        # SafePath validates the target location and rejects pre-existing files.
        candidate = join_path(self.tensor_dir, tensor_name)
        return SafePath(candidate, PathConst.FILE, "w", suffix=self.suffix).check(path_exist=False)

    def _generate_file_name(self, node_name, args_name):
        # <timestamp>.<node>.<arg><suffix> — the timestamp keeps names collision-free.
        parts = [
            str(get_current_timestamp(microsecond=True)),
            get_valid_name(node_name),
            get_valid_name(args_name) + str(self.suffix),
        ]
        return ".".join(parts)


class SaveTensor:
    """Registry mapping a dump-format name to its SaveTensorStrategy class."""

    _fmt_map = {}

    @classmethod
    def register(cls, name):
        return register(name, cls._fmt_map)

    @classmethod
    def get(cls, name):
        return cls._fmt_map.get(name)


@SaveTensor.register(DumpConst.NPY_FORMAT)
class SaveNpyTensor(SaveTensorStrategy):
    """Writes tensors as .npy files."""

    def __init__(self):
        super().__init__()
        self.suffix = ".npy"

    def _save(self, data):
        save_npy(data, self.tensor_path)


@SaveTensor.register(DumpConst.BIN_FORMAT)
class SaveBinTensor(SaveTensorStrategy):
    """Writes tensors as raw .bin files from an ndarray or a bytes payload."""

    def __init__(self):
        super().__init__()
        self.suffix = ".bin"

    def _save(self, data):
        if isinstance(data, np.ndarray):
            save_bin_from_ndarray(data, self.tensor_path)
        elif isinstance(data, bytes):
            save_bin_from_bytes(data, self.tensor_path)
        else:
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"Unsupported data type: {type(data).__name__}.")
@Command.register("msprobe", CmdConst.DUMP)
class DumpCommand(BaseCommand):
    """`msprobe dump` sub-command: declares its CLI arguments on an argparse parser."""

    @staticmethod
    def add_required_arguments(parser):
        # --exec is the single mandatory argument: either an offline model file
        # or a quoted executable command line; CheckExec validates it on parse.
        req = parser.add_argument_group("Required arguments")
        req.add_argument(
            "-e",
            "--exec",
            dest=CfgConst.EXEC,
            action=CheckExec,
            required=True,
            help=f""" Supports two input types:
    1. An offline model file with {("saved_model",) + PathConst.SUFFIX_OFFLINE_MODEL} extension;
    2. An executable CLI scripts enclosed in quotes end with {PathConst.SUFFIX_ONLINE_SCRIPT}. Default: None""",
        )

    @staticmethod
    def add_optional_arguments(parser):
        opt = parser.add_argument_group("Optional arguments")
        opt.add_argument(
            "-cfg",
            "--config",
            dest=CfgConst.CONFIG_PATH,
            # CheckConfigPath validates the JSON config path on parse.
            action=CheckConfigPath,
            help=""" A config JSON file for storing data dump settings. Default: None""",
        )
        opt.add_argument(
            "-f",
            "--framework",
            dest=CfgConst.FRAMEWORK,
            # CheckFramework restricts the value to the supported frameworks.
            action=CheckFramework,
            help=f""" Required when using: {CfgConst.ALL_FRAMEWORK}. Default: None""",
        )
        opt.add_argument(
            "-x",
            "--msprobex",
            dest="msprobex",
            default=False,
            action="store_true",
            help=""" Use msprobe extended API. Default: False""",
        )

    @classmethod
    def add_arguments(cls, parser):
        # Entry point called by the CLI framework: installs both argument groups.
        cls.add_required_arguments(parser)
        cls.add_optional_arguments(parser)
+ +from msprobe.core.components.dumper_acl import ACLCompatibleComp, ACLDumperComp +from msprobe.core.components.dumper_atb import AtbActuatorComp +from msprobe.core.components.dumper_caffe import CaffeActuatorComp, CaffeDumperComp +from msprobe.core.components.dumper_om import OmActuatorComp +from msprobe.core.components.dumper_onnx import OnnxActuatorComp, OnnxDumperComp +from msprobe.core.components.dumper_tf import ( + FrozenGraphActuatorCompCPU, + FrozenGraphActuatorCompNPU, + FrozenGraphDumperCompCPU, + FrozenGraphSetGECompNPU, +) +from msprobe.core.components.dumper_writer import DumpWriterComp diff --git a/accuracy_tools/msprobe/core/components/dumper_acl.py b/accuracy_tools/msprobe/core/components/dumper_acl.py new file mode 100644 index 00000000000..e5429dcb38e --- /dev/null +++ b/accuracy_tools/msprobe/core/components/dumper_acl.py @@ -0,0 +1,150 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class ACLDumpDataProcessor:
    """Reassembles one tensor's dump payload from the chunked ACL callback stream.

    ACL delivers a tensor as a sequence of chunks; each chunk carries its buffer,
    its length, and an is-last flag. The owner refreshes ``dump_chunk`` with the
    current chunk before every ``process_data()`` call.
    """

    def __init__(self, chunk):
        self.dump_chunk = chunk      # most recent chunk to fold into the buffer
        self.buffer_bytes = bytes()  # accumulated payload
        self.total_len = 0           # accumulated payload length in bytes
        self.completed = False       # set once a chunk is flagged as the last one

    @property
    def is_completed(self):
        return self.completed

    def get_data(self):
        """Return the full payload, or None while incomplete or empty."""
        if not self.buffer_bytes or not self.completed:
            return None
        return self.buffer_bytes

    def process_data(self):
        """Fold the current chunk into the buffer; ignore chunks after completion."""
        if self.is_completed:
            # Fix: the message is constant — dropped the needless f-string prefix.
            logger.error("DataProcessor receive data when completed. Some errors may occur.")
            return
        self._tackle_last_chunk()
        self._tackle_buf_len()

    def _tackle_last_chunk(self):
        if self.dump_chunk.get(ACLConst.IS_LAST_CHUNK):
            self.completed = True

    def _tackle_buf_len(self):
        buf_len = self.dump_chunk.get(ACLConst.BUF_LEN)
        # Reject empty chunks and anything that would grow the buffer past 4 GB.
        if buf_len == 0 or buf_len + self.total_len > PathConst.SIZE_4G:
            raise MsprobeException(
                MsgConst.RISK_ALERT,
                f"Buffer overflow (cached size {self.total_len}), receiving size: {buf_len}.",
            )
        self.total_len += buf_len
        self.buffer_bytes += self.dump_chunk.get(ACLConst.DATA_BUF)


@Component.register(CompConst.ACL_DUMPER_COMP)
class ACLDumperComp(ProducerComp):
    """Producer that registers an ACL dump callback and republishes completed tensors."""

    def __init__(self, priority, data_mode, rank, **kwargs):
        super().__init__(priority)
        self.data_processor_map = {}  # node name -> in-flight ACLDumpDataProcessor
        self.data_map = {}            # node name -> completed bytes, drained by load_data()
        self.data_mode = data_mode
        self.model_path = kwargs.get("model_path")
        self.acl_resource_manager = acl_device_manager.get_acl_resource_manager(rank)

    @staticmethod
    def _get_node_name(chunk):
        """Derive "<type>-<name>" from the dump file path reported by ACL."""
        dump_file_path = get_abs_path(chunk.get(ACLConst.FILE_NAME))
        file_name = dump_file_path.split("/")[-1]
        file_name = file_name.split(".")
        if len(file_name) >= 2:
            type_and_name = "-".join(file_name[:2])
            return type_and_name
        else:
            raise MsprobeException(MsgConst.INVALID_ARGU, "The filename returned by ACL has no dot.")

    def activate(self, *args, **kwargs):
        """Initialize the ACL device and install the dump callback."""
        self.acl_resource_manager.initialize()
        self.acl_resource_manager.set_dump(self._get_dump_json(), self._dump_call_back)

    def deactivate(self, *args, **kwargs):
        self.acl_resource_manager.destroy_resource()

    def load_data(self):
        """Hand over (and reset) everything collected since the previous call.

        NOTE(review): nodes whose data is still incomplete are present here with a
        None value and get republished as None — confirm consumers tolerate this.
        """
        if not self.data_map:
            return None
        data_map = self.data_map
        self.data_map = {}
        return data_map

    def _get_dump_json(self):
        """Write the acl_dump.json describing what ACL should dump; return its path."""
        if self.model_path:
            name, _ = get_name_and_ext(self.model_path)
            dump_list = [{"model_name": name}]
        else:
            dump_list = []
        acl_dump_json_dict = {
            "dump": {
                "dump_path": DirPool.get_msprobe_dir(),
                "dump_mode": self.data_mode[0],
                "dump_list": dump_list,
            }
        }
        acl_dump_json_path = join_path(DirPool.get_msprobe_dir(), "acl_dump.json")
        save_json(acl_dump_json_dict, acl_dump_json_path, indent=4)
        return acl_dump_json_path

    def _dump_call_back(self, chunk, length):
        # The _dump_call_back function must take two arguments; the second one is
        # part of the ACL callback signature and cannot be removed.
        type_and_name = self._get_node_name(chunk)
        self.data_map[type_and_name] = self._get_data(type_and_name, chunk)

    def _get_data(self, type_and_name, chunk):
        """Feed `chunk` to the node's processor; return the full payload once complete."""
        if type_and_name not in self.data_processor_map:
            self.data_processor_map[type_and_name] = ACLDumpDataProcessor(chunk)
        processor = self.data_processor_map.get(type_and_name)
        # Fix: hand the CURRENT chunk to the processor. Previously process_data()
        # only ever saw the chunk captured at construction, so multi-chunk tensors
        # re-consumed the first chunk and all later chunks were silently dropped.
        processor.dump_chunk = chunk
        processor.process_data()
        if not processor.is_completed:
            return None
        bytes_data = processor.get_data()
        if not bytes_data:
            return None
        # Keep completed data in data_map rather than publishing the last chunk
        # directly: the model may already be finalized by the time the callback runs.
        self.data_processor_map.pop(type_and_name)
        return bytes_data


@Component.register(CompConst.ACL_COMPATIBLE_COMP)
class ACLCompatibleComp(ConsumerComp, ProducerComp):
    """Adapter: re-publishes ACL dump maps one node at a time in the common record format."""

    def __init__(self, priority, **kwargs):
        # Cooperative multiple inheritance: super() walks ConsumerComp -> ProducerComp.
        super().__init__(priority)

    def load_data(self):
        # Pure pass-through producer: records are pushed from consume(), never pulled.
        pass

    def consume(self, packages):
        data_map = packages[0][1]
        for node_name, data in data_map.items():
            # Placeholder arg metadata ("args_name"/"x"): ACL dumps are whole-node blobs.
            sealed_data = [node_name, "all", "args_name", "x", data]
            self.publish((sealed_data, None))
@Component.register(CompConst.ATB_ACTUATOR_COMP)
class AtbActuatorComp(BaseComponent):
    """Exports ATB dump settings as environment variables, then runs the user command."""

    def __init__(self, priority, dump_path, **kwargs):
        super().__init__(priority)
        self.dump_path = dump_path
        # Pull every optional setting from kwargs; the dict is rebuilt per call so
        # each instance gets fresh mutable defaults.
        defaults = {
            "task": CfgConst.TASK_STAT,
            "dump_level": [CfgConst.LEVEL_KERNEL],
            "step": [],
            "rank": [],
            "seed": None,
            "log_level": None,
            "summary_mode": CfgConst.TASK_STAT,
            "buffer_size": SIZE_1M,
            "data_mode": ["all"],
            "dump_extra": [],
            "op_id": [],
            "op_name": {},
            "exec": [],
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, kwargs.get(attr, fallback))

    def activate(self, *args, **kwargs):
        """Configure the environment, then launch the wrapped command."""
        self.set_env_vars()
        self.execute_dump()

    def set_env_vars(self):
        """Export every ATB dump setting to the environment, in a fixed order."""
        setters = (
            self._set_dump_path,
            self._set_task,
            self._set_dump_level,
            self._set_step,
            self._set_rank,
            self._set_seed,
            self._set_log_level,
            self._set_summary_mode,
            self._set_buffer_size,
            self._set_data_mode,
            self._set_dump_extra,
            self._set_op_id,
            self._set_op_name,
        )
        for setter in setters:
            setter()
        logger.info("The ATB dump parameters have been set.")

    def execute_dump(self):
        """Run the user command after verifying there is room for the dump output."""
        if not is_enough_disk_space(DirPool.get_msprobe_dir(), PathConst.SIZE_2G):
            raise MsprobeException(
                MsgConst.RISK_ALERT, "Please reserve at least 2GB of disk space for saving dump data."
            )
        run_subprocess(self.exec)

    def _set_dump_path(self):
        evars.set(DumpConst.ENVVAR_LINK_DUMP_PATH, self.dump_path)

    def _set_task(self):
        evars.set(DumpConst.ENVVAR_LINK_DUMP_TASK, self.task)

    def _set_dump_level(self):
        evars.set(DumpConst.ENVVAR_LINK_DUMP_LEVEL, ",".join(self.dump_level))

    def _set_step(self):
        evars.set(DumpConst.ENVVAR_LINK_STEP, ",".join(str(step) for step in self.step))

    def _set_rank(self):
        evars.set(DumpConst.ENVVAR_LINK_RANK, ",".join(str(rank) for rank in self.rank))

    def _set_seed(self):
        # Only seed when explicitly requested; seeding also leaves dropout untouched.
        if self.seed:
            seed_all(self.seed, mode=True, rm_dropout=False)

    def _set_log_level(self):
        evars.set(DumpConst.ENVVAR_LINK_LOG_LEVEL, logger.get_level_id(self.log_level))

    def _set_summary_mode(self):
        evars.set(DumpConst.ENVVAR_LINK_SUMMARY_MODE, self.summary_mode)

    def _set_buffer_size(self):
        evars.set(DumpConst.ENVVAR_LINK_BUFFER_SIZE, self.buffer_size)

    def _set_data_mode(self):
        evars.set(DumpConst.ENVVAR_LINK_DATA_MODE, ",".join(self.data_mode))

    def _set_dump_extra(self):
        # Each optional artefact maps to its own on/off environment switch.
        flag_envs = {
            "tiling": DumpConst.ENVVAR_LINK_SAVE_TILING,
            "cpu_profiling": DumpConst.ENVVAR_LINK_SAVE_CPU_PROFILING,
            "kernel_info": DumpConst.ENVVAR_LINK_SAVE_KERNEL_INFO,
            "op_info": DumpConst.ENVVAR_LINK_SAVE_OP_INFO,
            "param": DumpConst.ENVVAR_LINK_SAVE_PARAM,
        }
        for flag, env_name in flag_envs.items():
            evars.set(env_name, "1" if flag in self.dump_extra else "0")

    def _set_op_id(self):
        evars.set(DumpConst.ENVVAR_LINK_SAVE_TENSOR_IDS, ",".join(str(op_id) for op_id in self.op_id))

    def _set_op_name(self):
        # Collect the op names configured for every requested dump level.
        names = []
        for level in self.dump_level:
            names.extend(self.op_name.get(level, []))
        evars.set(DumpConst.ENVVAR_LINK_SAVE_TENSOR_RUNNER, ",".join(str(name).lower() for name in names))
@Component.register(CompConst.CAFFE_ACTUATOR_COMP)
class CaffeActuatorComp(OfflineModelActuatorComp):
    """Runs one inference pass over a Caffe model so the dumper hooks can fire."""

    def __init__(self, priority, model_path, input_shape, input_path, **kwargs):
        super().__init__(priority)
        self.actuator = CaffeModelActuator(model_path, input_shape, input_path, **kwargs)

    def activate(self, *args, **kwargs):
        """Load the network, build its input feed, and execute a forward pass."""
        self.actuator.load_model()
        tensor_infos = self.actuator.get_input_tensor_info()
        feed = self.actuator.get_inputs_data(tensor_infos)
        _ = self.actuator.infer(feed)


@Component.register(CompConst.CAFFE_DUMPER_COMP)
class CaffeDumperComp(ProducerComp, BaseDumper):
    """Produces per-layer input/output tensors captured from a hooked caffe.Net."""

    def __init__(self, priority, data_mode):
        ProducerComp.__init__(self, priority)
        BaseDumper.__init__(self, data_mode)
        self.caffe_net = None  # live network captured by the post-hook on caffe.Net

    def activate(self, *args, **kwargs):
        self.register_hook()

    def deactivate(self, *args, **kwargs):
        self.release_hook()

    def register_hook(self):
        """Hook caffe.Net construction to grab the live network object."""
        handle = hijacker(stub=self._capture_caffe_net, module="caffe", function="Net", action=ActionType.POST_HOOK)
        self.handler.append(handle)

    def load_data(self):
        """Return the next dump record, or None once the iterator is exhausted."""
        if self._data_iter is None:
            self._get_input_output_map()
            self._data_iter = self._summ_dump_data()
        return next(self._data_iter, None)

    def _get_input_output_map(self):
        self._get_output_map()
        self._augment_input_map()

    def _summ_dump_data(self):
        """Yield (record, net_output_nodes) for every layer — inputs first, then outputs."""
        net_output_nodes = self.caffe_net.outputs
        for layer_name in self.caffe_net.blobs.keys():
            self.data_for_save.setdefault(layer_name, {})
            if any(mode in self.data_mode for mode in DumpConst.INPUT_ALL):
                records = self.through_nodes(
                    self.caffe_net.bottom_names.get(layer_name), layer_name, DumpConst.INPUT_ARGS, self.input_map
                )
                for record in records:
                    yield record, net_output_nodes
            if any(mode in self.data_mode for mode in DumpConst.OUTPUT_ALL):
                records = self.through_nodes(
                    self.caffe_net.top_names.get(layer_name), layer_name, DumpConst.OUTPUT_ARGS, self.output_map
                )
                for record in records:
                    yield record, net_output_nodes

    def _augment_input_map(self):
        """Expose learned weights/biases as pseudo-inputs, then merge blob outputs in.

        NOTE(review): this requires exactly two parameter blobs per layer; layers
        without a bias (or with extra blobs, e.g. BatchNorm) would raise — confirm
        this restriction is intended.
        """
        for layer_name, param in self.caffe_net.params.items():
            if len(param) != 2:
                raise MsprobeException(
                    MsgConst.REQUIRED_ARGU_MISSING,
                    f"The current layer ({layer_name})'s input does not include weights and biases.",
                )
            weight_key = f"{layer_name}_weight"
            bias_key = f"{layer_name}_bias"
            self.input_map[weight_key] = param[0].data
            self.input_map[bias_key] = param[1].data
            bottoms = self.caffe_net.bottom_names.get(layer_name)
            bottoms.append(weight_key)
            bottoms.append(bias_key)
        self.input_map = {**self.input_map, **self.output_map}

    def _get_output_map(self):
        for layer_name, blob in self.caffe_net.blobs.items():
            self.output_map[layer_name] = blob.data

    def _capture_caffe_net(self, ret, *args, **kwargs):
        # POST_HOOK stub: caffe.Net's return value is the constructed network.
        self.caffe_net = ret
        return ret
class OfflineModelActuatorComp(BaseComponent):
    """Common base for components that drive offline-model inference (OM, ONNX, Caffe, TF)."""

    def __init__(self, priority=100):
        # priority defaults to 100 so actuators run after higher-priority components.
        super().__init__(priority)
@Component.register(CompConst.OM_ACTUATOR_COMP)
class OmActuatorComp(OfflineModelActuatorComp):
    """Executes an OM model once on an ACL device, then exports its graph as JSON."""

    def __init__(self, priority, model_path, input_shape, input_path, **kwargs):
        super().__init__(priority)
        self.actuator = OmModelActuator(model_path, input_shape, input_path, **kwargs)
        # One ACL resource manager per device rank (defaults to rank 0).
        self.acl_resource_manager = acl_device_manager.get_acl_resource_manager(kwargs.get("rank", 0))

    def activate(self, *args, **kwargs):
        """Initialize the device, run inference on raw byte inputs, then convert to JSON."""
        self.acl_resource_manager.initialize()
        self.actuator.load_model()
        tensor_infos = self.actuator.get_input_tensor_info()
        # OM inference consumes raw byte buffers, hence is_byte_data=True.
        feed = self.actuator.get_inputs_data(tensor_infos, is_byte_data=True)
        self.actuator.infer(feed)
        self.actuator.convert_om2json()
# ONNX TensorProto.DataType code -> numpy dtype, used to decode initializer raw_data.
# Per the ONNX spec: 1=FLOAT, 2=UINT8, 3=INT8, 4=UINT16, 5=INT16, 6=INT32, 7=INT64,
# 9=BOOL, 10=FLOAT16, 11=DOUBLE, 12=UINT32, 13=UINT64.
# Fix: code 2 is UINT8 — the old map decoded it as float64, which belongs to code 11,
# so uint8 initializers were mis-decoded and real float64 initializers got None.
# STRING (8) and the bfloat16/float8 codes are still unmapped and yield None.
_ONNX_DTYPE = {
    1: np.float32,
    2: np.uint8,
    3: np.int8,
    4: np.uint16,
    5: np.int16,
    6: np.int32,
    7: np.int64,
    9: np.bool_,
    10: np.float16,
    11: np.float64,
    12: np.uint32,
    13: np.uint64,
}
# call_data is 0-based: index 1 is the second recorded InferenceSession.__init__ call.
_SECOND_CALL = 1
# Index of the first positional argument captured by the hook.
_FIRST_PARAM = 0


@Component.register(CompConst.ONNX_ACTUATOR_COMP)
class OnnxActuatorComp(OfflineModelActuatorComp):
    """Drives an ONNX model end-to-end: load, build inputs, export, infer."""

    def __init__(self, priority, model_path, input_shape, input_path, **kwargs):
        super().__init__(priority)
        self.actuator = OnnxModelActuator(model_path, input_shape, input_path, **kwargs)

    def activate(self, *args, **kwargs):
        """Load the model, prepare its input feed, and run one inference pass."""
        self.actuator.load_model()
        inputs_tensor_info = self.actuator.get_input_tensor_info()
        input_map = self.actuator.get_inputs_data(inputs_tensor_info)
        uninfer_model_path = self.actuator.export_uninfer_model()
        _ = self.actuator.infer(uninfer_model_path, input_map)


@Component.register(CompConst.ONNX_DUMPER_COMP)
class OnnxDumperComp(ProducerComp, BaseDumper):
    """Captures per-node input/output tensors of an onnxruntime session via hooks."""

    def __init__(self, priority, data_mode):
        ProducerComp.__init__(self, priority)
        BaseDumper.__init__(self, data_mode)
        self.output_list = []     # raw InferenceSession.run results, in node-output order
        self.origin_model = None  # onnx.ModelProto captured from onnx.load_model
        self.model_session = None  # InferenceSession captured from its __init__ pre-hook

    def activate(self, *args, **kwargs):
        self.register_hook()

    def deactivate(self, *args, **kwargs):
        self.release_hook()

    def register_hook(self):
        """Install pre/post hooks on onnxruntime and onnx entry points."""
        # Kept as an attribute: _get_model_session later reads handler_session.call_data.
        self.handler_session = hijacker(
            stub=self._capture_model_session,
            module="onnxruntime",
            cls="InferenceSession",
            function="__init__",
            action=ActionType.PRE_HOOK,
            priority=10,
        )
        self.handler.append(self.handler_session)
        self.handler.append(
            hijacker(
                stub=self._capture_input_map,
                module="onnxruntime",
                cls="InferenceSession",
                function="run",
                action=ActionType.PRE_HOOK,
                priority=10,
            )
        )
        self.handler.append(
            hijacker(
                stub=self._capture_output_list,
                module="onnxruntime",
                cls="InferenceSession",
                function="run",
                action=ActionType.POST_HOOK,
                priority=20,
            )
        )
        self.handler.append(
            hijacker(
                stub=self._capture_origin_model,
                module="onnx",
                function="load_model",
                action=ActionType.POST_HOOK,
                priority=20,
            )
        )

    def load_data(self):
        """Return the next dump record, or None once the iterator is exhausted."""
        if self._data_iter is None:
            self._get_model_session()
            self._get_input_output_map()
            self._data_iter = self._summ_dump_data()
        try:
            return next(self._data_iter)
        except StopIteration:
            return None

    def _get_output_map(self):
        # Map every node output name to the corresponding inference result.
        # NOTE(review): assumes output_list is aligned 1:1 with the graph's node
        # outputs (i.e. the session was run with all node outputs requested) — confirm.
        res_idx = 0
        for node in self.origin_model.graph.node:
            for node_output in node.output:
                self.output_map[node_output] = self.output_list[res_idx]
                res_idx += 1

    def _augment_input_map(self):
        """Expose graph initializers (weights) as inputs, then merge node outputs in."""
        for temp in self.origin_model.graph.initializer:
            npy_data = load_npy_from_buffer(temp.raw_data, _ONNX_DTYPE.get(temp.data_type), temp.dims)
            self.input_map[temp.name] = npy_data
        self.input_map = {**self.input_map, **self.output_map}

    def _get_input_output_map(self):
        self._get_output_map()
        self._augment_input_map()

    def _summ_dump_data(self):
        """Yield (record, net_output_nodes) for every node — inputs first, then outputs."""
        net_output_nodes = [item.name for item in self.model_session.get_outputs()]
        for node in self.origin_model.graph.node:
            self.data_for_save.setdefault(node.name, {})
            if any(x in self.data_mode for x in DumpConst.INPUT_ALL):
                input_data = self.through_nodes(node.input, node.name, DumpConst.INPUT_ARGS, self.input_map)
                for item in input_data:
                    yield item, net_output_nodes
            if any(x in self.data_mode for x in DumpConst.OUTPUT_ALL):
                output_data = self.through_nodes(node.output, node.name, DumpConst.OUTPUT_ARGS, self.output_map)
                for item in output_data:
                    yield item, net_output_nodes

    def _get_model_session(self):
        try:
            self.model_session = self.handler_session.call_data.get(_SECOND_CALL).get("args")[_FIRST_PARAM]
        except Exception as e:
            raise MsprobeException(
                MsgConst.VALUE_NOT_FOUND, "The hook function failed to capture the model_session."
            ) from e

    def _capture_model_session(self, *args, **kwargs):
        # PRE_HOOK stub: only needs to pass the call through; the handle records args.
        return args, kwargs

    def _capture_input_map(self, *args, **kwargs):
        # Pre-hook on InferenceSession.run(output_names, input_feed, ...).
        # args[2] is presumably the input_feed dict — TODO confirm hook arg layout.
        self.input_map = args[2]
        return args, kwargs

    def _capture_output_list(self, output, *args, **kwargs):
        self.output_list = output
        return output

    def _capture_origin_model(self, origin_model, *args, **kwargs):
        self.origin_model = origin_model
        return origin_model
from msprobe.base import BaseComponent, Component, ProducerComp
from msprobe.core.base import BaseDumper
from msprobe.core.components.dumper_offline_model import OfflineModelActuatorComp
from msprobe.core.dump import FrozenGraphActuatorCPU, FrozenGraphActuatorNPU
from msprobe.utils.constants import CompConst, DumpConst
from msprobe.utils.env import evars
from msprobe.utils.hijack import ActionType, hijacker
from msprobe.utils.log import logger
from msprobe.utils.toolkits import get_net_output_nodes_from_graph_def


@Component.register(CompConst.FROZEN_GRAPH_ACTUATOR_COMP_CPU)
class FrozenGraphActuatorCompCPU(OfflineModelActuatorComp):
    """Runs one CPU inference of a TF frozen graph so the dumper hooks can observe it."""

    def __init__(self, priority, model_path, input_shape, input_path, **kwargs):
        super().__init__(priority)
        self.actuator = FrozenGraphActuatorCPU(model_path, input_shape, input_path, **kwargs)

    def activate(self, *args, **kwargs):
        """Load the model, build input data from the tensor info, and run inference.

        The inference result itself is discarded; the companion dumper component
        captures inputs/outputs via its Session.run hooks.
        """
        self.actuator.load_model()
        inputs_tensor_info = self.actuator.get_input_tensor_info()
        input_map = self.actuator.get_inputs_data(inputs_tensor_info)
        _ = self.actuator.infer(input_map)


@Component.register(CompConst.FROZEN_GRAPH_DUMPER_COMP_CPU)
class FrozenGraphDumperCompCPU(ProducerComp, BaseDumper):
    """Producer that hooks TF graph import and Session.run to dump tensor data."""

    def __init__(self, priority, data_mode):
        ProducerComp.__init__(self, priority)
        BaseDumper.__init__(self, data_mode)
        self.graph_def = None    # GraphDef captured from _import_graph_def_internal
        self.infer_output = []   # results returned by Session.run
        self.tf_ops = []         # fetches argument passed to Session.run

    def activate(self, *args, **kwargs):
        self.register_hook()

    def deactivate(self, *args, **kwargs):
        self.release_hook()

    def register_hook(self):
        """Install pre/post hooks on TF internals to capture graph, fetches and results."""
        self.handler.append(
            hijacker(
                stub=self._capture_graph_def,
                module="tensorflow.python.framework.importer",
                function="_import_graph_def_internal",
                action=ActionType.PRE_HOOK,
                priority=20,
            )
        )
        self.handler.append(
            hijacker(
                stub=self._capture_tf_ops,
                module="tensorflow.python.client.session",
                cls="Session",
                function="run",
                action=ActionType.PRE_HOOK,
                priority=20,
            )
        )
        self.handler.append(
            hijacker(
                stub=self._capture_output,
                module="tensorflow.python.client.session",
                cls="Session",
                function="run",
                action=ActionType.POST_HOOK,
                priority=25,
            )
        )

    def load_data(self):
        """Return the next dump item, or None when the data is exhausted.

        The generator is built lazily on first call, after the hooks have
        populated tf_ops / infer_output / graph_def.
        """
        if self._data_iter is None:
            self._get_input_output_map()
            self._data_iter = self._summ_dump_data()
        try:
            return next(self._data_iter)
        except StopIteration:
            return None

    def _get_input_output_map(self):
        # Output map must be built first; input data is looked up from it.
        self._get_output_map()
        self._get_input_map()

    def _get_output_map(self):
        # Pairs each fetched tensor with its Session.run result by position.
        self.output_map = {tensor.name: result for tensor, result in zip(self.tf_ops, self.infer_output)}

    def _get_input_map(self):
        # An op's input tensors are outputs of upstream ops; resolve them from
        # output_map. Inputs that were not fetched cannot be resolved.
        node_names = [tensor.op.name for tensor in self.tf_ops]
        for idx, node_name in enumerate(node_names):
            for input_tensor in self.tf_ops[idx].op.inputs:
                input_data = self.output_map.get(input_tensor.name)
                if input_data is None:
                    logger.warning(f"Input {input_tensor.name} for {node_name} not found.")
                    continue
                self.input_map[input_tensor.name] = input_data

    def _summ_dump_data(self):
        """Yield (item, net_output_nodes) pairs for every node's inputs/outputs per data_mode."""
        net_output_nodes = get_net_output_nodes_from_graph_def(self.graph_def)
        for node in self.tf_ops:
            self.data_for_save.setdefault(node.name, {})
            if any(x in self.data_mode for x in DumpConst.INPUT_ALL):
                input_data = self.through_nodes(node.op.inputs, node.name, DumpConst.INPUT_ARGS, self.input_map)
                for item in input_data:
                    yield item, net_output_nodes
            if any(x in self.data_mode for x in DumpConst.OUTPUT_ALL):
                output_data = self.through_nodes(node.op.outputs, node.name, DumpConst.OUTPUT_ARGS, self.output_map)
                for item in output_data:
                    yield item, net_output_nodes

    def _capture_graph_def(self, *args, **kwargs):
        # Pre-hook on _import_graph_def_internal: first positional arg is the GraphDef.
        self.graph_def = args[0]
        return args, kwargs

    def _capture_tf_ops(self, *args, **kwargs):
        # Pre-hook on Session.run: args[0] is self (the Session), args[1] the fetches.
        self.tf_ops = args[1]
        return args, kwargs

    def _capture_output(self, output, *args, **kwargs):
        # Post-hook on Session.run: capture the returned results.
        self.infer_output = output
        return output


@Component.register(CompConst.FROZEN_GRAPH_ACTUATOR_COMP_NPU)
class FrozenGraphActuatorCompNPU(OfflineModelActuatorComp):
    """Runs one NPU inference of a TF frozen graph, then converts dump txt to json."""

    def __init__(self, priority, model_path, input_shape, input_path, **kwargs):
        super().__init__(priority)
        self.actuator = FrozenGraphActuatorNPU(model_path, input_shape, input_path, **kwargs)

    def activate(self, *args, **kwargs):
        self.actuator.load_model()
        inputs_tensor_info = self.actuator.get_input_tensor_info()
        input_map = self.actuator.get_inputs_data(inputs_tensor_info)
        _ = self.actuator.infer(input_map)
        # NPU path additionally post-processes the dumped artifacts.
        self.actuator.convert_txt2json()


@Component.register(CompConst.FROZEN_GRAPH_SET_GE_COMP_NPU)
class FrozenGraphSetGECompNPU(BaseComponent):
    """Exports the GE graph-dump environment variables before NPU execution."""

    def __init__(self, priority, work_path, dump_ge_graph, dump_graph_level, dump_graph_path):
        super().__init__(priority)
        self.work_path = work_path
        self.dump_ge_graph = dump_ge_graph
        self.dump_graph_level = dump_graph_level
        self.dump_graph_path = dump_graph_path

    def activate(self, *args, **kwargs):
        # evars.set presumably exports to os.environ -- TODO confirm against msprobe.utils.env.
        evars.set(DumpConst.ENVVAR_ASCEND_WORK_PATH, self.work_path)
        evars.set(DumpConst.ENVVAR_DUMP_GE_GRAPH, self.dump_ge_graph)
        evars.set(DumpConst.ENVVAR_DUMP_GRAPH_LEVEL, self.dump_graph_level)
        evars.set(DumpConst.ENVVAR_DUMP_GRAPH_PATH, self.dump_graph_path)
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import stack

from msprobe.base import Component, ConsumerComp
from msprobe.common.dirs import DirPool
from msprobe.common.stat import DataStat
from msprobe.core.base import RankDirFile, SaveTensor
from msprobe.utils.constants import CfgConst, CompConst, DumpConst
from msprobe.utils.io import save_json
from msprobe.utils.log import logger
from msprobe.utils.path import join_path

# Frames whose path contains one of these fragments are tooling internals and
# are filtered out of the recorded call stacks.
_STACK_FILTER_PATH = ["msprobe/core", "msprobe/base", "msprobe/common", "msprobe/utils", "torch/nn/modules/module.py"]
_WITHOUT_CALL_STACK = "The call stack retrieval failed."


class DumpJson(RankDirFile):
    """In-memory cache for dump.json: per-node input/output tensor statistics."""

    def __init__(self, buffer_size, task, level, framework, summary_mode):
        super().__init__(buffer_size)
        self.cache_file = {}
        self._init(task=task, level=level, framework=framework)
        self.summary_mode = summary_mode

    def update_stat(self, node_name=None, in_out=None, args_name=None, npy_data=None):
        """Append the statistics of one numpy tensor under data[node_name][in_out]."""
        if node_name not in self.cache_file[DumpConst.DATA]:
            self.cache_file[DumpConst.DATA][node_name] = {}
        self._update_dump_json(
            self.cache_file[DumpConst.DATA][node_name],
            in_out,
            {**{"data_name": args_name}, **DataStat.collect_stats_for_numpy(npy_data, self.summary_mode)},
        )

    def _save(self):
        # Flush the whole cache to <rank_dir>/dump.json.
        dump_json_path = join_path(self.rank_dir, DumpConst.DUMP_JSON)
        save_json(self.cache_file, dump_json_path, indent=4)

    def _init(self, **kwargs):
        # Seed the fixed top-level keys of dump.json.
        self.cache_file.update(
            {
                CfgConst.TASK: kwargs.get("task", None),
                CfgConst.LEVEL: kwargs.get("level", None),
                CfgConst.FRAMEWORK: kwargs.get("framework", None),
                DumpConst.DUMP_DATA_DIR: kwargs.get(DumpConst.DUMP_DATA_DIR, None),
                DumpConst.DATA: {},
            }
        )

    def _update_dump_json(self, dump_dic, in_out, kwargs: dict):
        # setdefault replaces the original membership-test + get().append double lookup.
        dump_dic.setdefault(in_out, []).append(kwargs)
        # cover() is inherited from RankDirFile; presumably it enforces the
        # buffer_size policy on the cache -- TODO confirm.
        self.cover(dump_dic)


class StackJson(RankDirFile):
    """In-memory cache for stack.json: caller stacks recorded per dumped name."""

    def __init__(self, buffer_size):
        super().__init__(buffer_size)
        self.cache_file = {}

    @staticmethod
    def _call_stack(name: str):
        """Return {name: [formatted frames]} for the current call site (top 5 frames)."""
        try:
            _stack = stack()[:5]
        except Exception as e:
            logger.warning(f"The call stack of {name} failed to retrieve, {e}.")
            _stack = None
        stack_str = []
        if _stack:
            for _, path, line, func, code, _ in _stack:
                if not code:
                    continue
                # Skip frames belonging to msprobe itself (and torch module glue).
                if any(filter_path in path for filter_path in _STACK_FILTER_PATH):
                    continue
                stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
                stack_str.append(stack_line)
        else:
            stack_str.append(_WITHOUT_CALL_STACK)
        stack_info = {name: stack_str}
        return stack_info

    def update_stack(self, name):
        """Capture and cache the call stack for name."""
        self.cache_file.update(self._call_stack(name))
        self.cover(self.cache_file)

    def _save(self):
        stack_json_path = join_path(self.rank_dir, DumpConst.STACK_JSON)
        save_json(self.cache_file, stack_json_path, indent=4)


@Component.register(CompConst.DUMP_WRITER_COMP)
class DumpWriterComp(ConsumerComp):
    """Consumer that persists dumped data into dump.json, stack.json and tensor files."""

    def __init__(self, priority, task, level, framework, summary_mode, strategy, buffer_size: int, dir_pool: DirPool):
        super().__init__(priority)
        self.net_output_nodes = None
        self.task = task
        self.dump_json = DumpJson(buffer_size, task, level, framework, summary_mode)
        self.stack_json = StackJson(buffer_size)
        self.strategy = strategy
        self.save_strategy = SaveTensor.get(strategy)()
        self.dir_pool = dir_pool
        self.add_path()

    def add_path(self):
        """Point the json caches (and the tensor writer, for tensor tasks) at the rank dir."""
        self.dump_json.add_rank_dir(self.dir_pool.rank_dir)
        self.stack_json.add_rank_dir(self.dir_pool.rank_dir)
        if self.task == CfgConst.TASK_TENSOR:
            self.save_strategy.add_tensor_dir(self.dir_pool.tensor_dir)

    def consume(self, packages):
        """Unpack one produced package and write it.

        packages[0][1] is (sealed_data, net_output_nodes); the node list is
        remembered so finalize() can persist it once at the end.
        """
        if self.task == CfgConst.TASK_TENSOR and not self.dump_json.cache_file.get(DumpConst.DUMP_DATA_DIR):
            self.dump_json.cache_file[DumpConst.DUMP_DATA_DIR] = self.dir_pool.get_tensor_dir()
        received_data = packages[0][1]
        sealed_data, self.net_output_nodes = received_data
        self.write(sealed_data)  # original ended with a pointless bare `return`

    def write(self, sealed_data):
        """Record one sealed item.

        sealed_data layout: [node_name, in_or_out, args_name, index, npy_data].
        """
        if self.strategy == DumpConst.NPY_FORMAT:
            self.dump_json.update_stat(sealed_data[0], sealed_data[1], sealed_data[2], sealed_data[4])
            # The original if/elif/else on input-vs-output performed the
            # byte-identical save call in every branch, so the branching was dead
            # code and is collapsed to a single call.
            if self.task == CfgConst.TASK_TENSOR:
                self.save_strategy.save_tensor_data(sealed_data[0], sealed_data[2], sealed_data[4])

    def finalize(self):
        """Persist the net-output-node names (if captured) and flush caches."""
        if self.net_output_nodes:
            save_json(self.net_output_nodes, join_path(self.dir_pool.get_model_dir(), DumpConst.NET_OUTPUT_NODES_JSON))
        self._flush_remaining_cache()

    def _flush_remaining_cache(self):
        self.dump_json.clear_cache()
        self.stack_json.clear_cache()
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from msprobe.base import BaseConfig
from msprobe.core.config_initiator.validate_params import (
    valid_data_mode,
    valid_device,
    valid_dump_extra,
    valid_dump_ge_graph,
    valid_dump_graph_level,
    valid_dump_path,
    valid_fusion_switch_file,
    valid_input,
    valid_list,
    valid_onnx_fusion_switch,
    valid_op_id,
    valid_saved_model_signature,
    valid_saved_model_tag,
    valid_summary_mode,
    valid_weight_path,
)
from msprobe.utils.constants import CfgConst, DumpConst


class DumpConfig(BaseConfig):
    """Validates and normalizes the dump-task section of the msprobe config.

    Each field of the task config is passed through its valid_* function with a
    default applied when the field is absent; the validated value is written
    back via _update_config (inherited from BaseConfig).
    """

    def check_config(self, dump_path: str = None):
        """Validate the task sub-config and store it under config[<task name>].

        dump_path: optional CLI override for the configured dump_path.
        Returns the full (mutated) config dict.
        """
        self.config[self.config.get(CfgConst.TASK)] = self._check_dump_dic(dump_path)
        return self.config

    def _check_dump_dic(self, dump_path: str = None):
        """Run every dump option through its validator, applying defaults."""
        # An explicit dump_path argument wins over the configured value; "./" is
        # the last-resort default.
        self._update_config(
            self.task_config,
            DumpConst.DUMP_PATH,
            valid_dump_path,
            dump_path or self.task_config.get(DumpConst.DUMP_PATH, "./"),
        )
        # valid_list needs the level as well to decide the list/dict format.
        self._update_config(
            self.task_config,
            DumpConst.LIST,
            valid_list,
            (self.task_config.get(DumpConst.LIST, []), self.config.get(CfgConst.LEVEL)),
        )
        self._update_config(
            self.task_config, DumpConst.DATA_MODE, valid_data_mode, self.task_config.get(DumpConst.DATA_MODE, ["all"])
        )
        self._update_config(
            self.task_config,
            DumpConst.SUMMARY_MODE,
            valid_summary_mode,
            self.task_config.get(DumpConst.SUMMARY_MODE, CfgConst.TASK_STAT),
        )
        self._update_config(
            self.task_config, DumpConst.DUMP_EXTRA, valid_dump_extra, self.task_config.get(DumpConst.DUMP_EXTRA, [])
        )
        self._update_config(self.task_config, DumpConst.OP_ID, valid_op_id, self.task_config.get(DumpConst.OP_ID, []))
        self._update_config(
            self.task_config,
            DumpConst.DUMP_GE_GRAPH,
            valid_dump_ge_graph,
            self.task_config.get(DumpConst.DUMP_GE_GRAPH, "2"),
        )
        self._update_config(
            self.task_config,
            DumpConst.DUMP_GRAPH_LEVEL,
            valid_dump_graph_level,
            self.task_config.get(DumpConst.DUMP_GRAPH_LEVEL, "3"),
        )
        self._update_config(
            self.task_config,
            DumpConst.FUSION_SWITCH_FILE,
            valid_fusion_switch_file,
            self.task_config.get(DumpConst.FUSION_SWITCH_FILE, None),
        )
        self._update_config(
            self.task_config, DumpConst.DEVICE, valid_device, self.task_config.get(DumpConst.DEVICE, None)
        )
        self._update_config(self.task_config, DumpConst.INPUT, valid_input, self.task_config.get(DumpConst.INPUT, []))
        # NOTE(review): ONNX_FUSION_switch is oddly cased compared to the other
        # DumpConst members referenced here -- confirm the constant name in
        # msprobe.utils.constants before renaming anything.
        self._update_config(
            self.task_config,
            DumpConst.ONNX_FUSION_switch,
            valid_onnx_fusion_switch,
            self.task_config.get(DumpConst.ONNX_FUSION_switch, True),
        )
        self._update_config(
            self.task_config,
            DumpConst.SAVED_MODEL_TAG,
            valid_saved_model_tag,
            self.task_config.get(DumpConst.SAVED_MODEL_TAG, ["serve"]),
        )
        self._update_config(
            self.task_config,
            DumpConst.SAVED_MODEL_SIGN,
            valid_saved_model_signature,
            self.task_config.get(DumpConst.SAVED_MODEL_SIGN, "serving_default"),
        )
        self._update_config(
            self.task_config, DumpConst.WEIGHT_PATH, valid_weight_path, self.task_config.get(DumpConst.WEIGHT_PATH, None)
        )
        return self.task_config
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from itertools import product

from msprobe.common.validation import parse_hyphen
from msprobe.utils.constants import DumpConst, MsgConst, PathConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger
from msprobe.utils.path import SafePath
from msprobe.utils.toolkits import check_int_border

_OP_ID_PATTERN = r"^\d{1,10}(_\d{1,10}){0,9}$"
_ALL_DEVICE = {"cpu", "npu"}
# FIX: the previous pattern [a-zA-Z0-9_.-:] contained the character range ".-:"
# (0x2E-0x3A), which silently admitted "/" and did NOT admit a literal hyphen.
# The hyphen is now placed last so it is literal; "/" is rejected.
_VALID_CHAR = r"^[a-zA-Z0-9_.:-]+$"


def check_special_char(value: str):
    """Raise MsprobeException unless value is a non-empty [A-Za-z0-9_.:-] string."""
    if not (isinstance(value, str) and re.match(_VALID_CHAR, value)):
        raise MsprobeException(MsgConst.RISK_ALERT, f"Invalid input: contains unsafe characters: {value}.")


def valid_dump_path(value: str):
    """Validate dump_path as a writable directory path; returns the checked path."""
    return SafePath(value, PathConst.DIR, "w").check()


def valid_list(value: tuple):
    """Validate the dump "list" option.

    value is a (list_or_dict, levels) tuple. A plain list is allowed only when a
    single level is configured; a dict maps each level to its own list.
    Returns a {level: [names]} dict.
    """

    def re_format(value: tuple):
        # Expand a flat name list to every configured level.
        ret = {}
        for ii in value[1]:
            for vv in value[0]:
                check_special_char(vv)
            ret[ii] = value[0]
        return ret

    if not value[0] or (isinstance(value[0], list) and len(value[1]) == 1):
        return re_format(value)
    elif isinstance(value[0], dict):
        for key, vv in value[0].items():
            if key not in value[1]:
                raise MsprobeException(MsgConst.INVALID_ARGU, f"Key not in allowed list {value[1]}, currently: {key}.")
            if not isinstance(vv, list):
                raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"Value must be a list, got {type(vv)} instead.")
            for v in vv:
                check_special_char(v)
        return value[0]
    else:
        raise MsprobeException(
            MsgConst.INVALID_DATA_TYPE,
            """The list parameter supports two types:
            1. List, which requires "level" to be set with only one element.
            2. Dictionary, which allows "level" to be set with multiple elements.""",
        )


def valid_data_mode(value: list):
    """Validate data_mode: a single-item list drawn from DumpConst.ALL_DATA_MODE."""
    if not value:
        return value
    if not isinstance(value, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"data_mode" must be a list.')
    if len(value) == 1:
        if value[0] not in DumpConst.ALL_DATA_MODE:
            raise MsprobeException(
                MsgConst.INVALID_ARGU, f'"data_mode" must be one of {DumpConst.ALL_DATA_MODE}, currently: {value[0]}.'
            )
    else:
        raise MsprobeException(MsgConst.INVALID_ARGU, '"data_mode" only accepts a single-item list.')
    return value


def valid_dump_extra(values: list):
    """Validate dump_extra: every element must be in DumpConst.ALL_DUMP_EXTRA."""
    if not values:
        return values
    if not isinstance(values, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_extra" must be a list.')
    for value in values:
        if value not in DumpConst.ALL_DUMP_EXTRA:
            raise MsprobeException(
                MsgConst.INVALID_ARGU, f'"dump_extra" must be one of {DumpConst.ALL_DUMP_EXTRA}, currently: {value}.'
            )
    return values


def valid_op_id(value: list):
    """Validate op_id entries: plain ints or underscore-joined id strings like "3_1_2"."""
    if not value:
        return value
    if not isinstance(value, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"op_id" must be a list.')
    res = []
    for element in value:
        if isinstance(element, int):
            check_int_border(element, tag="the integer part of op_id")
            res.append(element)
        elif isinstance(element, str) and re.match(_OP_ID_PATTERN, element):
            res.append(element)
        else:
            raise MsprobeException(
                MsgConst.INVALID_DATA_TYPE,
                '"op_id" is only supported in the ATB dump scenario, '
                f"with formats like 2, 3_1, or 3_1_2, currently: {element}.",
            )
    return res


def valid_dump_ge_graph(value: str):
    """Validate dump_ge_graph: None or a string member of DumpConst.ALL_DUMP_GE_GRAPH."""
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_ge_graph" must be a string.')
    if value not in DumpConst.ALL_DUMP_GE_GRAPH:
        raise MsprobeException(
            MsgConst.INVALID_ARGU, f'"dump_ge_graph" must be one of {DumpConst.ALL_DUMP_GE_GRAPH}, currently: {value}.'
        )
    return value


def valid_dump_graph_level(value: str):
    """Validate dump_graph_level: None or a string member of DumpConst.ALL_DUMP_GRAPH_LEVEL."""
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_graph_level" must be a string.')
    if value not in DumpConst.ALL_DUMP_GRAPH_LEVEL:
        raise MsprobeException(
            MsgConst.INVALID_ARGU,
            f'"dump_graph_level" must be one of {DumpConst.ALL_DUMP_GRAPH_LEVEL}, currently: {value}.',
        )
    return value


def valid_fusion_switch_file(value: str):
    """Validate fusion_switch_file: a readable .json/.cfg file up to 500 MB."""
    if value is None:
        return value
    return SafePath(value, PathConst.FILE, "r", PathConst.SIZE_500M, (".json", ".cfg")).check()


def valid_device(value: str):
    """Validate device: None, "cpu" or "npu"."""
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"device" must be a string.')
    if value not in _ALL_DEVICE:
        raise MsprobeException(MsgConst.INVALID_ARGU, f'"device" must be one of {_ALL_DEVICE}, currently: {value}.')
    return value


def valid_input(value: list):
    """Parse and validate the model input description list; returns (shapes, paths)."""
    if not value:
        return value
    return OfflineModelInput(value).parse()


class OfflineModelInput:
    """Parses the per-input dicts (name/shape/path/dym_shape) of an offline model."""

    def __init__(self, input_list):
        self.input_list = input_list
        self._check_form()
        self.is_need_expand_shape = False  # set when any input uses dym_shape

    @staticmethod
    def _check_name(infile: dict):
        # Every input entry must carry a non-empty "name".
        if not infile.get("name"):
            raise MsprobeException(MsgConst.PARSING_FAILED, "Each input must have a name.")
        return infile.get("name")

    @staticmethod
    def _check_input_shape(infile: dict, name):
        # "shape", when present, must be a list of bounded integers.
        inshape = infile.get("shape")
        if inshape:
            if not isinstance(inshape, list):
                raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"shape" of the input {name} must be a list.')
            for vv in inshape:
                check_int_border(vv, tag=f'Elements in "shape" of the input {name}')

    @staticmethod
    def _check_input_path(infile: dict, name):
        # "path", when present, must be a readable .bin/.npy file up to 10 GB.
        inpath = infile.get("path")
        if inpath:
            if not isinstance(inpath, str):
                raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"path" of the input {name} must be a string.')
            if not inpath.endswith((".bin", ".npy")):
                raise MsprobeException(
                    MsgConst.INVALID_ARGU, f'"path" of {name} can only accept .npy or .bin files, currently: {inpath}.'
                )
            _ = SafePath(inpath, PathConst.FILE, "r", PathConst.SIZE_10G).check()

    @staticmethod
    def _parse_shape_range_for_str(shape):
        """Parse one dynamic-shape string: "a-b" (range) or "a,b" (pair)."""
        if "-" in shape:
            ranges = parse_hyphen(shape, tag="Elements in a dynamic shape")
        elif "," in shape and shape.count(",") == 1:
            try:
                ranges = list(map(int, shape.split(",")))
            except Exception as e:
                # FIX: the original message blamed the hyphen (-) in this
                # comma-parsing branch.
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    f"Both sides of the comma (,) in the input must be numbers, currently: {shape}.",
                ) from e
        else:
            raise MsprobeException(
                MsgConst.INVALID_ARGU, 'The "dym_shape" of the input can only contain hyphen (-) or a comma (,).'
            )
        return ranges

    def parse(self):
        """Validate every entry and return (shapes, paths) for the actuator."""
        logger.info("Start parsing the input list.")
        modify_file = []
        for infile in self.input_list:
            name = self._check_name(infile)
            self._check_input_shape(infile, name)
            self._check_input_path(infile, name)
            infile = self._check_dym_shape(infile, name)
            modify_file.append(infile)
        shapes, paths = self._draw_shape_and_path(modify_file)
        return shapes, paths

    def _parse_dym_shape_range(self, shapes, name):
        """Expand a dym_shape spec into the cartesian product of concrete shapes."""
        if not isinstance(shapes, list):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"dym_shape" of the input {name} must be a list.')
        shapes_list = []
        for shape in shapes:
            if isinstance(shape, str):
                ranges = self._parse_shape_range_for_str(shape)
            elif isinstance(shape, int):
                check_int_border(shape, tag="Integer in a dynamic shape")
                ranges = [shape]
            else:
                raise MsprobeException(
                    MsgConst.INVALID_DATA_TYPE,
                    f'Elements in "dym_shape" of the input support only string and integers, currently: {shape}.',
                )
            shapes_list.append(ranges)
        return [list(s) for s in list(product(*shapes_list))]

    def _check_form(self):
        # The top-level input must be a list of dicts.
        if isinstance(self.input_list, list):
            for vv in self.input_list:
                if not isinstance(vv, dict):
                    raise MsprobeException(
                        MsgConst.INVALID_DATA_TYPE, "Each element in the input must be a dictionary."
                    )
        else:
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "The input must be a list.")

    def _check_dym_shape(self, infile: dict, name: str):
        # dym_shape supersedes shape/path; both are cleared when it is present.
        if infile.get("dym_shape"):
            self.is_need_expand_shape = True
            infile["dym_shape"] = self._parse_dym_shape_range(infile["dym_shape"], name)
            infile["shape"] = []
            if infile.get("path"):
                infile["path"] = ""
            logger.warning('Since "dym_shape" is used, "shape" and "path" will not take effect.')
        return infile

    def _draw_shape_and_path(self, modify_file):
        """Combine the per-input entries into (shapes, paths).

        With dym_shape: all inputs must expand to the same number of shape
        combinations; shapes is a list of {name: shape} dicts, paths is None.
        Without: shapes is a single {name: shape} dict and paths collects the
        non-empty file paths.
        """
        if self.is_need_expand_shape:
            dym_shapes = [item["dym_shape"] for item in modify_file]
            if all(len(shapes) == len(dym_shapes[0]) for shapes in dym_shapes):
                shapes = [dict(zip([item["name"] for item in modify_file], shapes)) for shapes in zip(*dym_shapes)]
                paths = None
            else:
                raise MsprobeException(
                    MsgConst.INVALID_ARGU, "Ensure all inputs have the same expanded dynamic shape length."
                )
        else:
            shapes, paths = {}, []
            for item in modify_file:
                shapes[item["name"]] = item.get("shape")
                if item.get("path"):
                    paths.append(item["path"])
        return shapes, paths


def valid_onnx_fusion_switch(value: bool):
    """Validate onnx_fusion_switch: None or a real bool.

    FIX: the original `if not value: return value` shortcut let non-bool falsy
    values (0, "", []) bypass the type check entirely; the type is now enforced
    for every non-None value.
    """
    if value is None:
        return value
    if not isinstance(value, bool):
        raise MsprobeException(
            MsgConst.INVALID_DATA_TYPE, f'"onnx_fusion_switch" must be a boolean, currently: {value}.'
        )
    return value


def valid_saved_model_tag(value: list):
    """Validate saved_model_tag: a list of safe-character strings."""
    if not value:
        return value
    if not isinstance(value, list):
        # FIX: typo "msut" in the original error message.
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "saved_model_tag must be a list.")
    for vv in value:
        check_special_char(vv)
    return value


def valid_saved_model_signature(value: str):
    """Validate saved_model_signature: None or a safe-character string."""
    if value is None:
        return value
    check_special_char(value)
    return value


def valid_weight_path(value: str):
    """Validate weight_path: None or a readable .caffemodel file up to 50 GB."""
    if value is None:
        return value
    return SafePath(value, PathConst.FILE, "r", PathConst.SIZE_50G, ".caffemodel").check()


def valid_summary_mode(value: str):
    """Validate summary_mode: None or a member of DumpConst.ALL_SUMMARY_MODE."""
    if value is None:
        return value
    if value not in DumpConst.ALL_SUMMARY_MODE:
        raise MsprobeException(
            MsgConst.INVALID_ARGU,
            f'"summary_mode" must be one of {DumpConst.ALL_SUMMARY_MODE}, currently: {value}.',
        )
    return value
+ +from msprobe.core.dump.acl_manager import acl_device_manager +from msprobe.core.dump.caffe_model import CaffeModelActuator +from msprobe.core.dump.om_model import OmModelActuator +from msprobe.core.dump.onnx_model import OnnxModelActuator +from msprobe.core.dump.tf_model import FrozenGraphActuatorCPU, FrozenGraphActuatorNPU diff --git a/accuracy_tools/msprobe/core/dump/acl_manager.py b/accuracy_tools/msprobe/core/dump/acl_manager.py new file mode 100644 index 00000000000..59e1785c252 --- /dev/null +++ b/accuracy_tools/msprobe/core/dump/acl_manager.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from msprobe.lib.msprobe_c import acl
from msprobe.utils.constants import ACLConst, MsgConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger


class ACLDeviceManager:
    """Process-wide singleton handing out one ACLResourceManager per rank."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(ACLDeviceManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # FIX: __init__ runs on EVERY ACLDeviceManager() call even though
        # __new__ returns the shared instance, so the original unconditionally
        # reset acl_device_manager_map and dropped previously created managers.
        # Guard the first-time initialization instead.
        if not hasattr(self, "acl_device_manager_map"):
            self.acl_device_manager_map = {}

    def get_acl_resource_manager(self, rank):
        """Return the (lazily created) ACLResourceManager for rank."""
        if rank not in self.acl_device_manager_map:
            self.acl_device_manager_map[rank] = ACLResourceManager(rank)
        return self.acl_device_manager_map[rank]


class ACLResourceManager:
    """Owns the ACL runtime resources (device, context, dump registration) of one rank."""

    def __init__(self, rank=0):
        self.ptr_context = None
        self.is_acl_initialized = False
        self.is_set_dump = False
        self.rank = rank

    def initialize(self):
        """Init ACL, bind the device and create a context; idempotent.

        Raises MsprobeException on any non-SUCCESS ACL return code.
        """
        if self.is_acl_initialized:
            return
        ret = acl.init()
        if ret == ACLConst.SUCCESS:
            logger.info("Acl init success!")
        else:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl init failed! ErrorCode = {ret}.")
        ret = acl.rt_set_device(self.rank)
        if ret == ACLConst.SUCCESS:
            logger.info(f"Set device:{self.rank} success!")
        else:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl set device:{self.rank} failed! ErrorCode = {ret}.")
        self.ptr_context, ret = acl.rt_create_context(self.rank)
        if ret == ACLConst.SUCCESS:
            logger.info("Create new context success!")
        else:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl create context failed! ErrorCode = {ret}.")
        self.is_acl_initialized = True

    def set_dump(self, dump_cfg_path, message_call_back):
        """Enable ACL dump with the given config and message callback; idempotent.

        Requires initialize() to have succeeded first.
        """
        if not self.is_acl_initialized or self.is_set_dump:
            return
        ret = acl.init_dump()
        if ret != ACLConst.SUCCESS:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl init dump failed! ErrorCode = {ret}.")
        ret = acl.dump_reg_callback(message_call_back, 0)
        if ret != ACLConst.SUCCESS:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl dump reg callback failed! ErrorCode = {ret}.")
        ret = acl.set_dump(dump_cfg_path)
        if ret != ACLConst.SUCCESS:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl set dump failed! ErrorCode = {ret}.")
        self.is_set_dump = True

    def destroy_resource(self):
        """Tear everything down in reverse order: dump, context, device, finalize.

        Errors during teardown are logged, not raised, so cleanup always
        proceeds as far as possible.
        """
        if not self.is_acl_initialized:
            return
        self._finalize_dump()
        if self.ptr_context is not None:
            ret = acl.rt_destroy_context(self.ptr_context)
            if ret != ACLConst.SUCCESS:
                logger.error(f"Destroy context failed! ErrorCode = {ret}.")
        ret = acl.rt_reset_device(self.rank)
        if ret != ACLConst.SUCCESS:
            # FIX: typo "deivce" in the original log message.
            logger.error(f"Reset device failed! DeviceId = {self.rank}, ErrorCode = {ret}.")
        else:
            logger.info(f"End to reset device:{self.rank}.")
        ret = acl.finalize()
        if ret != ACLConst.SUCCESS:
            logger.error(f"Finalize failed! ErrorCode = {ret}.")
        else:
            logger.info("End to finalize.")
        self.is_acl_initialized = False

    def _finalize_dump(self):
        # Unregister the callback and finalize dump; no-op when dump never started.
        if not self.is_set_dump:
            return
        acl.dump_unreg_callback()
        ret = acl.finalize_dump()
        if ret != ACLConst.SUCCESS:
            logger.error(f"Finalize dump failed! ErrorCode = {ret}.")
        self.is_set_dump = False


acl_device_manager = ACLDeviceManager()
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from msprobe.core.base import OfflineModelActuator
from msprobe.utils.constants import MsgConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.io import load_caffe_model
from msprobe.utils.log import logger


class CaffeModelActuator(OfflineModelActuator):
    """Actuator that loads a Caffe model and runs a forward pass for dumping.

    Requires a ``weight_path`` keyword argument pointing at the .caffemodel file.
    """

    def __init__(self, model_path, input_shape, input_path, **kwargs):
        super().__init__(model_path, input_shape, input_path, **kwargs)
        self.weight_path = kwargs.get("weight_path", "")
        if not self.weight_path:
            raise MsprobeException(
                MsgConst.REQUIRED_ARGU_MISSING,
                "When using Caffe for inference, a weight file (.caffemodel) is required.",
            )

    def load_model(self):
        """Load the prototxt plus weights into ``self.model``."""
        self.model = load_caffe_model(self.model_path, self.weight_path)

    def get_input_tensor_info(self):
        """Describe each input blob (name, shape, dtype) as declared by the model."""
        # The first len(inputs) blobs are the network inputs.
        input_blob_names = list(self.model.blobs.keys())[: len(self.model.inputs)]
        inputs_tensor_info = [
            {
                "name": blob_name,
                "shape": tuple(self.model.blobs[blob_name].data.shape),
                "type": str(self.model.blobs[blob_name].data.dtype),
            }
            for blob_name in input_blob_names
        ]
        logger.warning(
            "Caffe model doesn't support dynamic shapes and "
            "will use the input shape defined in the model for inference."
        )
        logger.info(f"Model input tensor info: {inputs_tensor_info}.")
        return inputs_tensor_info

    def infer(self, input_map):
        """Copy host data into the input blobs and run one forward pass."""
        try:
            for blob_name, blob_value in input_map.items():
                np.copyto(self.model.blobs[blob_name].data, blob_value)
            return self.model.forward()
        except Exception as e:
            raise MsprobeException(
                MsgConst.CALL_FAILED, "Please check if the input shape or data matches the model requirements."
            ) from e
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from msprobe.common.ascend import cann
from msprobe.core.base import OfflineModelActuator
from msprobe.lib.msprobe_c import acl
from msprobe.utils.constants import ACLConst, MsgConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.io import load_om_model
from msprobe.utils.log import logger
from msprobe.utils.path import get_name_and_ext, join_path

# Per-direction size query used when allocating device buffers.
_BUFFER_METHOD_MAP = {"input": acl.get_input_size_by_index, "output": acl.get_output_size_by_index}


class OmModelActuator(OfflineModelActuator):
    """Actuator that loads an Ascend .om model and executes it via ACL.

    Fixes vs. the previous version:
      * ``_create_data_buffer`` no longer reads a local dataset handle it
        immediately overwrites, and records the buffer list back under the
        attribute it actually reads later (``{mode}_ptr_size``); before, it
        wrote a stray ``{mode}_size_map`` attribute and only worked through
        list aliasing.
      * ``_destroy_data_buffer`` uses ``continue`` instead of ``return`` so a
        missing/failed input-side dataset no longer skips output-side cleanup
        (device-memory leak).
    """

    def __init__(self, model_path, input_shape, input_path, **kwargs):
        super().__init__(model_path, input_shape, input_path, **kwargs)
        self.ptr_model_desc = None      # aclmdlDesc handle
        self.ptr_input_dataset = None   # aclmdlDataset handle for inputs
        self.ptr_output_dataset = None  # aclmdlDataset handle for outputs
        self.model_id = None            # id returned by model load
        self.input_size = 0             # number of model inputs
        self.output_size = 0            # number of model outputs
        self.input_ptr_size = []        # [{"buffer": ptr, "size": bytes}, ...]
        self.output_ptr_size = []

    def load_model(self):
        """Load the .om file and remember the ACL model id."""
        self.model_id = load_om_model(self.model_path)

    def get_input_tensor_info(self):
        """Query name/shape/dtype of every model input from the model description."""
        inputs_tensor_info = []
        self._get_model_info()
        for index in range(self.input_size):
            name = acl.get_input_name_by_index(self.ptr_model_desc, index)
            if name is None:
                raise MsprobeException(MsgConst.CALL_FAILED, f"Get input name by index:{index} failed!")
            shape, ret = acl.get_input_dims(self.ptr_model_desc, index)
            if shape is None or ret != ACLConst.SUCCESS:
                raise MsprobeException(MsgConst.CALL_FAILED, f"Get input shape by index:{index} failed!")
            dtype = acl.get_input_data_type(self.ptr_model_desc, index)
            if dtype is None:
                raise MsprobeException(MsgConst.CALL_FAILED, f"Get input type by index:{index} failed!")
            inputs_tensor_info.append({"name": name, "shape": shape["dims"], "type": dtype})
        logger.info(f"Model input tensor info: {inputs_tensor_info}.")
        return inputs_tensor_info

    def infer(self, input_map):
        """Allocate device buffers, copy inputs, execute once, then release everything.

        NOTE(review): outputs are not copied back to host here — presumably the
        kernel dump callback captures the data; confirm against the dumper.
        """
        self._create_data_buffer()
        self._copy_data_from_host_to_device(input_map)
        self._run()
        self._destroy_data_buffer()
        self._destroy_resource()

    def convert_om2json(self):
        """Convert the .om model to a JSON graph next to the model dump dir."""
        name, _ = get_name_and_ext(self.model_path)
        json_path = join_path(self.dir_pool.get_model_dir(), name + ".json")
        cann.model2json(self.model_path, json_path)

    def _run(self):
        """Execute the loaded model with the prepared input/output datasets."""
        ret = acl.execute(self.model_id, self.ptr_input_dataset, self.ptr_output_dataset)
        if ret != ACLConst.SUCCESS:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Model execute failed! ErrorCode = {ret}.")
        else:
            logger.info("Model execute success!")

    def _get_model_info(self):
        """Create the model description and cache input/output counts."""
        self.ptr_model_desc = acl.create_desc()
        if self.ptr_model_desc is None:
            raise MsprobeException(MsgConst.CALL_FAILED, "Create model description Failed!")
        ret = acl.get_desc(self.ptr_model_desc, self.model_id)
        if ret != ACLConst.SUCCESS:
            raise MsprobeException(MsgConst.CALL_FAILED, f"Get model description failed! ErrorCode = {ret}.")
        self.input_size = acl.get_num_inputs(self.ptr_model_desc)
        if self.input_size is None:
            raise MsprobeException(MsgConst.CALL_FAILED, "Get input nums failed!")
        self.output_size = acl.get_num_outputs(self.ptr_model_desc)
        if self.output_size is None:
            raise MsprobeException(MsgConst.CALL_FAILED, "Get output nums failed!")
        logger.info("Create model description Success!")

    def _create_data_buffer(self):
        """Allocate one device buffer per input/output tensor and build the datasets."""
        for mode in ["input", "output"]:
            data_size = getattr(self, f"{mode}_size", 0)
            ptr_size_map = getattr(self, f"{mode}_ptr_size", [])
            ptr_dataset = acl.create_dataset()
            if ptr_dataset is None:
                raise MsprobeException(MsgConst.CALL_FAILED, f"Create {mode} dataset failed!")
            for index in range(data_size):
                temp_buffer_size = _BUFFER_METHOD_MAP.get(mode)(self.ptr_model_desc, index)
                if temp_buffer_size is None:
                    raise MsprobeException(MsgConst.CALL_FAILED, f"Get {mode} size by index:{index} failed!")
                temp_ptr, ret = acl.rt_malloc(temp_buffer_size)
                if ret != ACLConst.SUCCESS:
                    raise MsprobeException(MsgConst.CALL_FAILED, f"{mode.title()} malloc failed! ErrorCode = {ret}.")
                # Track the raw pointer/size so the buffer can be freed later.
                ptr_size_map.append({"buffer": temp_ptr, "size": temp_buffer_size})
                temp_buffer = acl.create_databuffer(temp_ptr, temp_buffer_size)
                if temp_buffer is None:
                    acl.rt_free(temp_ptr)
                    raise MsprobeException(MsgConst.CALL_FAILED, f"Create {mode} buffer failed!")
                ret = acl.add_dataset_buffer(ptr_dataset, temp_buffer)
                if ret != ACLConst.SUCCESS:
                    acl.rt_free(temp_ptr)
                    raise MsprobeException(
                        MsgConst.CALL_FAILED, f"Add {mode} buffer to dataset failed! ErrorCode = {ret}."
                    )
            setattr(self, f"ptr_{mode}_dataset", ptr_dataset)
            # Store back under the attribute that the copy/destroy paths read.
            setattr(self, f"{mode}_ptr_size", ptr_size_map)

    def _destroy_resource(self):
        """Unload the model and destroy its description (best-effort, log-only)."""
        ret = acl.unload(self.model_id)
        if ret != ACLConst.SUCCESS:
            logger.error(f"Unload model failed! ErrorCode = {ret}.")
        else:
            logger.info("End to unload model.")
        if self.ptr_model_desc is not None:
            ret = acl.destroy_desc(self.ptr_model_desc)
            if ret != ACLConst.SUCCESS:
                logger.error(f"Destroy model description failed! ErrorCode = {ret}.")

    def _destroy_data_buffer(self):
        """Destroy every data buffer/dataset and free the device memory for both sides."""
        for mode in ["input", "output"]:
            dataset = getattr(self, f"ptr_{mode}_dataset", None)
            ptr_size_map = getattr(self, f"{mode}_ptr_size", [])
            if dataset is None or not ptr_size_map:
                continue  # was `return`: would have skipped cleanup of the other side
            buffer_nums = acl.get_dataset_num_buffers(dataset)
            if buffer_nums is None:
                logger.error("Get dataset num buffers failed!")
                continue
            for index in range(buffer_nums):
                data_buffer = acl.get_dataset_buffer(dataset, index)
                if data_buffer is None:
                    logger.error(f"From {mode} dataset get dataBuffer failed!")
                    continue
                ret = acl.destroy_databuffer(data_buffer)
                if ret != ACLConst.SUCCESS:
                    logger.error(f"Destroy dataBuffer failed! ErrorCode = {ret}.")
            ret = acl.destroy_dataset(dataset)
            if ret != ACLConst.SUCCESS:
                logger.error(f"Destroy {mode} dataset failed! ErrorCode = {ret}.")
            for items in ptr_size_map:
                ptr = items.get("buffer", None)
                ret = acl.rt_free(ptr)
                if ret != ACLConst.SUCCESS:
                    logger.error(f"Free Failed! ErrorCode = {ret}.")

    def _copy_data_from_host_to_device(self, input_map):
        """Copy each host input array into its pre-allocated device buffer, in order."""
        if len(input_map) != len(self.input_ptr_size):
            logger.warning(f"input_map size:{len(input_map)} not equal input_ptr_size:{len(self.input_ptr_size)}")
            return
        # Relies on input_map preserving model input order — TODO confirm at call site.
        for index, (_, input_data) in enumerate(input_map.items()):
            dest_ptr = self.input_ptr_size[index].get("buffer", None)
            dest_size = self.input_ptr_size[index].get("size", 0)
            byte_data = input_data.tobytes()
            ret = acl.rt_memcpy(dest_ptr, dest_size, byte_data, len(byte_data), ACLConst.MEMCPY_HOST_TO_DEVICE)
            if ret != ACLConst.SUCCESS:
                logger.error(f"Memcpy Input data from host to device failed! ErrorCode = {ret}.")
                return
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from msprobe.core.base import OfflineModelActuator
from msprobe.utils.constants import MsgConst, PathConst
from msprobe.utils.dependencies import dependent
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.io import load_onnx_model, load_onnx_session, save_onnx_model
from msprobe.utils.log import logger
from msprobe.utils.path import convert_bytes, get_basename_from_path, is_file, join_path


class OnnxModelActuator(OfflineModelActuator):
    """Actuator that runs an ONNX model through onnxruntime for dumping."""

    def __init__(self, model_path, input_shape, input_path, **kwargs):
        super().__init__(model_path, input_shape, input_path, **kwargs)

    @staticmethod
    def infer(uninfer_model_path, input_map):
        """Run the exported every-node-output model and return all fetched values."""
        model_session = load_onnx_session(uninfer_model_path)
        output_name = [node.name for node in model_session.get_outputs()]
        try:
            return model_session.run(output_name, input_map)
        except Exception as e:
            raise MsprobeException(
                MsgConst.CALL_FAILED, "Please check if the input shape or data matches the model requirements."
            ) from e

    def load_model(self):
        """Load both the raw model proto and an inference session for it."""
        self.origin_model = load_onnx_model(self.model_path)
        self.model_session = load_onnx_session(self.model_path, self.kwargs.get("onnx_fusion_switch", True))

    def get_input_tensor_info(self):
        """Describe each session input, expanding dynamic dims via process_tensor_shape."""
        inputs_tensor_info = []
        for input_item in self.model_session.get_inputs():
            expanded = self.process_tensor_shape(input_item.name, input_item.type, tuple(input_item.shape))
            inputs_tensor_info.extend(expanded)
        logger.info(f"Model input tensor info: {inputs_tensor_info}.")
        return inputs_tensor_info

    def export_uninfer_model(self):
        """Save (once) a copy of the model whose outputs are every node's outputs.

        Returns the path of that modified model; reuses it if already present.
        """
        model_name = "inferential_" + get_basename_from_path(self.model_path)
        uninfer_model_path = join_path(self.dir_pool.get_model_dir(), model_name)
        if is_file(uninfer_model_path):
            return uninfer_model_path
        onnx = dependent.get("onnx")
        # Replace the declared graph outputs with one ValueInfo per node output
        # so the session fetches every intermediate tensor.
        del self.origin_model.graph.output[:]
        for node in self.origin_model.graph.node:
            for tensor_name in node.output:
                self.origin_model.graph.output.append(onnx.ValueInfoProto(name=tensor_name))
        model_size = self.origin_model.ByteSize()
        logger.info(f"The size of the modified ONNX model to be saved is {convert_bytes(model_size)}.")
        if model_size < 0 or model_size > PathConst.SIZE_2G:
            logger.warning("The modified ONNX model size has exceeded 2GB, posing a risk of numerical overflow.")
        save_onnx_model(self.origin_model, uninfer_model_path)
        logger.info(f"The modified ONNX model has been successfully saved to {uninfer_model_path}.")
        return uninfer_model_path
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from glob import glob

from msprobe.common.ascend import cann
from msprobe.core.base import OfflineModelActuator
from msprobe.utils.constants import MsgConst
from msprobe.utils.dependencies import dependent
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.io import load_pb_frozen_graph_model
from msprobe.utils.log import logger
from msprobe.utils.path import get_name_and_ext, join_path
from msprobe.utils.toolkits import get_net_output_nodes_from_graph_def


class FrozenGraphActuator(OfflineModelActuator):
    """Base actuator for frozen-graph (.pb) TensorFlow models.

    Subclasses provide the concrete session via ``_open_session`` (CPU or NPU).
    """

    def __init__(self, model_path, input_shape, input_path, **kwargs):
        super().__init__(model_path, input_shape, input_path, **kwargs)
        self.tf, self.rewriter_config = self._import_tf()
        self.sess = None
        self.graph_def = None
        self.all_node_names = []

    @staticmethod
    def _import_tf():
        """Import TensorFlow lazily; returns (tf, rewriter_config) or (None, None)."""
        pons = dependent.get_tensorflow()
        if None in pons:
            return None, None
        tf, rewriter_config, _ = pons
        # Frozen-graph execution requires TF1-style sessions.
        tf.compat.v1.disable_eager_execution()
        return tf, rewriter_config

    @staticmethod
    def _get_tensor_name(name: str):
        """Strip the output index suffix, e.g. 'node:0' -> 'node'."""
        return name.partition(":")[0]

    @staticmethod
    def _tf_shape_to_list(tensor_shape):
        """Convert a TensorShapeProto to a list, mapping unknown dims (-1) to None."""
        return [None if dim.size == -1 else dim.size for dim in tensor_shape.dim]

    def close(self):
        """Close the session if open; tolerate partially constructed sessions."""
        if self.sess is None:
            return
        try:
            self.sess.close()
        except AttributeError:
            pass
        self.sess = None

    def get_input_tensor_info(self):
        """Collect (name, dtype, shape) for every Placeholder and record node names."""
        inputs_tensor_info = []
        for node in self.graph_def.node:
            if node.op == "Placeholder":
                dtype = self.tf.dtypes.as_dtype(node.attr["dtype"].type)
                dims = self._tf_shape_to_list(node.attr["shape"].shape)
                inputs_tensor_info.extend(self.process_tensor_shape(node.name, dtype, dims))
            # NOTE(review): appending every node name so infer() fetches all
            # node outputs — confirm this matches the intended dump scope.
            self.all_node_names.append(node.name)
        logger.info(f"Model input tensor info: {inputs_tensor_info}.")
        return inputs_tensor_info

    def load_model(self):
        """Parse the frozen .pb into a GraphDef."""
        self.graph_def = load_pb_frozen_graph_model(self.model_path)

    def infer(self, input_map: dict):
        """Open a session, fetch all recorded node outputs with the given feeds."""
        self.sess = self._open_session()
        self._renew_all_node_names()
        tf_ops = self._get_tf_ops()
        feed_dict = self._build_feed(input_map)
        try:
            outputs = self.sess.run(tf_ops, feed_dict=feed_dict)
        except Exception as e:
            raise MsprobeException(
                MsgConst.CALL_FAILED, "Please check if the input shape or data matches the model requirements."
            ) from e
        self.close()
        return outputs

    def _open_session(self):
        """Hook: subclasses return a configured tf session."""
        return None

    def _renew_all_node_names(self):
        """Hook: subclasses may replace the fetch list before running."""
        pass

    def _get_tf_ops(self):
        """Resolve every recorded node name to its ':0' output tensor."""
        fetches = []
        for name in self.all_node_names:
            try:
                fetches.append(self.sess.graph.get_tensor_by_name(name + ":0"))
            except Exception as e:
                raise MsprobeException(
                    MsgConst.CALL_FAILED, f'The model lacks the {name + ":0"} node. Please check your model.'
                ) from e
        return fetches

    def _build_feed(self, input_map: dict):
        """Map user-supplied input names (with or without ':idx') to graph tensors."""
        feed_dict = {}
        for name, input_data in input_map.items():
            tensor_name = name if ":" in name else name + ":0"
            try:
                feed_dict[self.sess.graph.get_tensor_by_name(tensor_name)] = input_data
            except Exception as e:
                raise MsprobeException(
                    MsgConst.CALL_FAILED, f"The model lacks the {tensor_name} node. Please check your model."
                ) from e
        return feed_dict


class FrozenGraphActuatorCPU(FrozenGraphActuator):
    """Runs the frozen graph with a plain CPU session."""

    def __init__(self, model_path, input_shape, input_path, **kwargs):
        super().__init__(model_path, input_shape, input_path, **kwargs)

    def _open_session(self):
        return self.tf.compat.v1.Session()


class FrozenGraphActuatorNPU(FrozenGraphActuator):
    """Runs the frozen graph on NPU with GE kernel dump enabled."""

    def __init__(self, model_path, input_shape, input_path, **kwargs):
        super().__init__(model_path, input_shape, input_path, **kwargs)
        self.data_mode = kwargs.get("data_mode", ["all"])
        self.fusion_switch_file = kwargs.get("fsf", "")

    def convert_txt2json(self):
        """Convert the newest GE '*_Build.txt' graph file into JSON alongside it."""
        candidates = sorted(glob(join_path(self.dir_pool.get_model_dir(), "*", "*_Build.txt")))
        if not candidates:
            raise MsprobeException(
                MsgConst.PATH_NOT_FOUND, "No TXT format graph file found in the TensorFlow framework."
            )
        name, _ = get_name_and_ext(candidates[-1])
        cann.model2json(candidates[-1], join_path(self.dir_pool.get_model_dir(), name + ".json"))

    def _open_session(self):
        npu_device = dependent.get("npu_device")
        if not npu_device:
            raise MsprobeException(
                MsgConst.ATTRIBUTE_ERROR, "Please ensure that the TF plugin npu_device is properly installed."
            )
        npu_device.compat.enable_v1()
        config_proto = self.tf.compat.v1.ConfigProto()
        custom_op = config_proto.graph_options.rewrite_options.custom_optimizers.add()
        custom_op.name = "NpuOptimizer"
        # Configure the NPU optimizer to dump kernel data into the rank dir.
        custom_op.parameter_map["enable_dump"].b = True
        custom_op.parameter_map["dump_path"].s = self.tf.compat.as_bytes(self.dir_pool.get_rank_dir())
        custom_op.parameter_map["dump_step"].s = self.tf.compat.as_bytes("0")
        custom_op.parameter_map["data_mode"].s = self.tf.compat.as_bytes(self.data_mode[0])
        if self.fusion_switch_file:
            logger.info(f"Fusion switch settings read from {self.fusion_switch_file}.")
            custom_op.parameter_map["fusion_switch_file"].s = self.tf.compat.as_bytes(self.fusion_switch_file)
        # Disable remapping so dumped tensors map 1:1 to graph nodes.
        config_proto.graph_options.rewrite_options.remapping = self.rewriter_config.OFF
        return self.tf.compat.v1.Session(config=config_proto)

    def _renew_all_node_names(self):
        self.all_node_names = get_net_output_nodes_from_graph_def(self.graph_def)
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tqdm import tqdm

from msprobe.base import BaseService, Component, Dict2Class, Service
from msprobe.common.dirs import DirPool
from msprobe.core.config_initiator import DumpConfig
from msprobe.utils.constants import CfgConst, CmdConst, CompConst, DumpConst, MsgConst, PathConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.io import savedmodel2pb
from msprobe.utils.log import logger, print_log_with_star
from msprobe.utils.path import get_name_and_ext, is_file, is_saved_model_scene


@Service.register(CmdConst.DUMP)
class ServiceDump(BaseService):
    """Service driving the dump task: validates the config, builds the component
    pipeline (actuator + dumper + writer) for the detected model kind, and runs it.
    """

    def __init__(self, *args, **kwargs):
        """Parse CLI kwargs, validate the dump config, and prepare output dirs."""
        super().__init__()
        args, dump_path, namespace = self._parse_kw(**kwargs)
        config = DumpConfig(*args).check_config(dump_path)
        self.cfg = Dict2Class(config)
        # The exec command always comes from the CLI namespace, not the config file.
        setattr(self.cfg, CfgConst.EXEC, namespace.exec)
        logger.set_level(self.cfg.log_level)
        DirPool.make_msprobe_dir(self.cfg.dump_path)
        self.activate_comp_when_init()

    @property
    def is_skip(self):
        """True when the current step/rank is outside the configured goals."""
        if not self._is_step_in_goal or not self._is_rank_in_goal:
            logger.info(f"Skip task {self.cfg.task}, step {self.cfg.step}, rank {self.cfg.rank}.")
            return True
        return False

    @property
    def _is_indep_control(self):
        # MindIE-LLM controls step/rank itself (inside the C++ probe).
        return self.cfg.framework == CfgConst.FRAMEWORK_MINDIE_LLM

    @property
    def _is_step_in_goal(self):
        # Empty step list means "all steps".
        if self._is_indep_control:
            return True
        return not self.cfg.step or self.current_step in self.cfg.step

    @property
    def _is_rank_in_goal(self):
        # Empty rank list means "all ranks".
        if self._is_indep_control:
            return True
        return not self.cfg.rank or self.current_rank in self.cfg.rank

    @property
    def _is_offline_model(self):
        """True when exec is a single offline model file or a saved_model dir."""
        if len(self.cfg.exec) == 1:
            if is_file(self.cfg.exec[0]) and get_name_and_ext(self.cfg.exec[0])[1] in PathConst.SUFFIX_OFFLINE_MODEL:
                return True
            elif is_saved_model_scene(self.cfg.exec[0]):
                return True
            else:
                return False
        return False

    @property
    def _offline_model_comps_map(self):
        """Map device -> (model suffix -> builder method name) for offline models."""
        model_map_for_cpu = {
            ".onnx": "_construct_for_onnx_model",
            ".pb": "_construct_for_frozen_graph_model_on_cpu",
            "saved_model": "_construct_for_saved_model_on_cpu",
            ".prototxt": "_construct_for_caffe_model",
        }
        model_map_for_npu = {
            ".om": "_construct_for_om_model",
            ".pb": "_construct_for_frozen_graph_model_on_npu",
            "saved_model": "_construct_for_saved_model_on_npu",
        }
        device_handlers = {"cpu": model_map_for_cpu, "npu": model_map_for_npu}
        return device_handlers

    @property
    def _online_model_comps_map(self):
        """Map framework -> builder method name for online (framework-hooked) dumps."""
        framework_handlers = {CfgConst.FRAMEWORK_MINDIE_LLM: "_construct_for_atb_model"}
        return framework_handlers

    @property
    def _is_make_model_dir(self):
        # Offline models and MindIE-LLM both need a model/ output directory.
        return self._is_offline_model or self.cfg.framework == CfgConst.FRAMEWORK_MINDIE_LLM

    @staticmethod
    def _parse_kw(**kwargs):
        """Split kwargs into DumpConfig positional args, dump_path and the CLI namespace.

        CLI namespace values (config_path, framework) take precedence over kwargs.
        """
        task = kwargs.get(CfgConst.TASK)
        step = kwargs.get(CfgConst.STEP)
        level = kwargs.get(CfgConst.LEVEL)
        dump_path = kwargs.get(DumpConst.DUMP_PATH)
        cmd_namespace = kwargs.get("cmd_namespace")
        if hasattr(cmd_namespace, CfgConst.CONFIG_PATH):
            config_path = cmd_namespace.config_path
        else:
            config_path = kwargs.get(CfgConst.CONFIG_PATH)
        if hasattr(cmd_namespace, "framework"):
            framework = cmd_namespace.framework
        else:
            framework = kwargs.get(CfgConst.FRAMEWORK)
        return (config_path, task, framework, step, level), dump_path, cmd_namespace

    def activate_comp_when_init(self):
        """Hook for subclasses to activate components during __init__."""
        pass

    def init_start(self):
        """Announce the task and create the per-step/rank directory layout."""
        print_log_with_star(f"Launching {self.cfg.task} task...")
        self.make_dirs()

    def make_dirs(self):
        """Create model/step/rank (and tensor) output directories as needed."""
        if self._is_make_model_dir:
            DirPool.make_model_dir()
        if self._is_indep_control:
            return
        self.dir_pool = DirPool()
        self.dir_pool.make_step_dir(self.current_step)
        self.dir_pool.make_rank_dir()
        if self.cfg.task == CfgConst.TASK_TENSOR:
            self.dir_pool.make_tensor_dir()

    def finalize_start(self):
        """Flush the writer (if any) and announce completion."""
        if hasattr(self, "writer"):
            self.writer.finalize()
        print_log_with_star(f"{self.cfg.task} task completed successfully.")

    def construct(self):
        """Pick and invoke the builder method matching the model kind / framework."""
        if self._is_offline_model:
            device_handler = self._offline_model_comps_map.get(self.cfg.device)
            if not device_handler:
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    '"device" must be set to either "cpu" or "npu" when dumping the offline model.',
                )
            exec_type = self.cfg.exec[0]
            # saved_model is a directory scene; otherwise match by file suffix.
            model_key = (
                "saved_model"
                if is_saved_model_scene(exec_type)
                else next((key for key in device_handler if exec_type.endswith(key)), None)
            )
            handler_name = device_handler.get(model_key)
        else:
            handler_name = self._online_model_comps_map.get(self.cfg.framework)
        if handler_name:
            getattr(self, handler_name)()
        else:
            raise MsprobeException(MsgConst.INVALID_ARGU, "Unsupported framework. Please check parameter settings.")

    def run_cli(self):
        """Entry point from the CLI: run once, or once per input_shape for offline models."""
        if self._is_offline_model:
            if isinstance(self.cfg.input_shape, list) and len(self.cfg.input_shape) > 1:
                # Multiple shapes: re-run the pipeline per shape, advancing the step.
                for inshape in tqdm(self.cfg.input_shape, desc="Processing"):
                    self.cfg.input_shape = inshape
                    self.start()
                    self.step()
                    self.stop()
            else:
                self.start()
                self.stop()
        else:
            self.start()
            self.stop()

    def _construct_for_om_model(self):
        """Wire actuator/dumper/compatible/writer for an Ascend .om model."""
        self.actuator = Component.get(CompConst.OM_ACTUATOR_COMP)(
            priority=20,
            model_path=self.cfg.exec[0],
            input_shape=self.cfg.input_shape,
            input_path=self.cfg.input_path,
            dir_pool=self.dir_pool,
            rank=self.cfg.rank[0] if self.cfg.rank else 0,
        )
        self.dumper = Component.get(CompConst.ACL_DUMPER_COMP)(
            priority=10,
            data_mode=self.cfg.data_mode,
            model_path=self.cfg.exec[0],
            rank=self.cfg.rank[0] if self.cfg.rank else 0,
        )
        self.compatible = Component.get(CompConst.ACL_COMPATIBLE_COMP)(priority=12)
        self.writer = Component.get(CompConst.DUMP_WRITER_COMP)(
            priority=15,
            task=self.cfg.task,
            level=CfgConst.LEVEL_KERNEL,
            framework=CfgConst.FRAMEWORK_OM,
            summary_mode=self.cfg.summary_mode,
            strategy=DumpConst.BIN_FORMAT,
            buffer_size=self.cfg.buffer_size,
            dir_pool=self.dir_pool,
        )
        # Pipeline: dumper -> compatible -> writer.
        self.compatible.subscribe(self.dumper)
        self.writer.subscribe(self.compatible)

    def _construct_for_atb_model(self):
        """Wire the single ATB actuator for MindIE-LLM online dumping."""
        self.actuator = Component.get(CompConst.ATB_ACTUATOR_COMP)(
            priority=100,
            dump_path=DirPool.get_msprobe_dir(),
            task=self.cfg.task,
            dump_level=self.cfg.level,
            step=self.cfg.step,
            rank=self.cfg.rank,
            seed=self.cfg.seed,
            log_level=self.cfg.log_level,
            summary_mode=self.cfg.summary_mode,
            buffer_size=self.cfg.buffer_size,
            data_mode=self.cfg.data_mode,
            dump_extra=self.cfg.dump_extra,
            op_id=self.cfg.op_id,
            op_name=self.cfg.list,
            exec=self.cfg.exec,
        )

    def _construct_for_onnx_model(self):
        """Wire actuator/dumper/writer for an ONNX model on CPU."""
        self.actuator = Component.get(CompConst.ONNX_ACTUATOR_COMP)(
            priority=20,
            model_path=self.cfg.exec[0],
            input_shape=self.cfg.input_shape,
            input_path=self.cfg.input_path,
            dir_pool=self.dir_pool,
            onnx_fusion_switch=self.cfg.onnx_fusion_switch,
        )
        self.dumper = Component.get(CompConst.ONNX_DUMPER_COMP)(priority=10, data_mode=self.cfg.data_mode)
        self.writer = Component.get(CompConst.DUMP_WRITER_COMP)(
            priority=15,
            task=self.cfg.task,
            level=CfgConst.LEVEL_KERNEL,
            framework=CfgConst.FRAMEWORK_ONNX,
            summary_mode=self.cfg.summary_mode,
            strategy=DumpConst.NPY_FORMAT,
            buffer_size=self.cfg.buffer_size,
            dir_pool=self.dir_pool,
        )
        self.writer.subscribe(self.dumper)

    def _construct_for_caffe_model(self):
        """Wire actuator/dumper/writer for a Caffe model on CPU."""
        self.actuator = Component.get(CompConst.CAFFE_ACTUATOR_COMP)(
            priority=20,
            model_path=self.cfg.exec[0],
            input_shape=self.cfg.input_shape,
            input_path=self.cfg.input_path,
            dir_pool=self.dir_pool,
            weight_path=self.cfg.weight_path,
        )
        self.dumper = Component.get(CompConst.CAFFE_DUMPER_COMP)(priority=10, data_mode=self.cfg.data_mode)
        self.writer = Component.get(CompConst.DUMP_WRITER_COMP)(
            priority=15,
            task=self.cfg.task,
            level=CfgConst.LEVEL_MODULE,
            framework=CfgConst.FRAMEWORK_CAFFE,
            summary_mode=self.cfg.summary_mode,
            strategy=DumpConst.NPY_FORMAT,
            buffer_size=self.cfg.buffer_size,
            dir_pool=self.dir_pool,
        )
        self.writer.subscribe(self.dumper)

    def _construct_for_frozen_graph_model_on_cpu(self):
        """Wire actuator/dumper/writer for a frozen .pb graph on CPU."""
        self.actuator = Component.get(CompConst.FROZEN_GRAPH_ACTUATOR_COMP_CPU)(
            priority=20,
            model_path=self.cfg.exec[0],
            input_shape=self.cfg.input_shape,
            input_path=self.cfg.input_path,
            dir_pool=self.dir_pool,
        )
        self.dumper = Component.get(CompConst.FROZEN_GRAPH_DUMPER_COMP_CPU)(priority=10, data_mode=self.cfg.data_mode)
        self.writer = Component.get(CompConst.DUMP_WRITER_COMP)(
            priority=15,
            task=self.cfg.task,
            level=CfgConst.LEVEL_KERNEL,
            framework=CfgConst.FRAMEWORK_TF,
            summary_mode=self.cfg.summary_mode,
            strategy=DumpConst.NPY_FORMAT,
            buffer_size=self.cfg.buffer_size,
            dir_pool=self.dir_pool,
        )
        self.writer.subscribe(self.dumper)

    def _construct_for_saved_model_on_cpu(self):
        """Convert a saved_model directory to a frozen .pb, then build the CPU pipeline."""
        self.cfg.exec[0] = savedmodel2pb(
            self.cfg.exec[0], self.cfg.saved_model_tag, self.cfg.saved_model_signature, DirPool.get_model_dir()
        )
        self._construct_for_frozen_graph_model_on_cpu()

    def _construct_for_frozen_graph_model_on_npu(self):
        """Wire actuator/setter for a frozen .pb graph executed on NPU (GE dump)."""
        self.actuator = Component.get(CompConst.FROZEN_GRAPH_ACTUATOR_COMP_NPU)(
            priority=20,
            model_path=self.cfg.exec[0],
            input_shape=self.cfg.input_shape,
            input_path=self.cfg.input_path,
            data_mode=self.cfg.data_mode,
            fsf=self.cfg.fusion_switch_file,
        )
        self.setter = Component.get(CompConst.FROZEN_GRAPH_SET_GE_COMP_NPU)(
            priority=10,
            work_path=DirPool.get_msprobe_dir(),
            dump_ge_graph=self.cfg.dump_ge_graph,
            dump_graph_level=self.cfg.dump_graph_level,
            dump_graph_path=DirPool.get_model_dir(),
        )

    def _construct_for_saved_model_on_npu(self):
        """Convert a saved_model directory to a frozen .pb, then build the NPU pipeline."""
        self.cfg.exec[0] = savedmodel2pb(
            self.cfg.exec[0], self.cfg.saved_model_tag, self.cfg.saved_model_signature, DirPool.get_model_dir()
        )
        self._construct_for_frozen_graph_model_on_npu()
cmake_minimum_required(VERSION 3.14)
cmake_policy(SET CMP0048 NEW)
project(msprobe VERSION 1.0.0 LANGUAGES CXX C)

set(CMAKE_VERBOSE_MAKEFILE ON)

find_package(cpython MODULE REQUIRED)
find_package(nlohmannjson MODULE REQUIRED)
find_package(ZLIB REQUIRED)

add_library(msprobe_c SHARED)

add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
target_compile_definitions(msprobe_c PRIVATE _GLIBCXX_USE_CXX11_ABI=0)

# Hardened build options
target_compile_options(msprobe_c PRIVATE "-Wall")
target_compile_options(msprobe_c PRIVATE "-fPIC")
target_compile_options(msprobe_c PRIVATE "-fstack-protector-all")
target_compile_options(msprobe_c PRIVATE "-ftrapv")
# NOTE(review): _FORTIFY_SOURCE=2 only takes effect at -O1 or higher, so it is
# inert in the -O0 debug configuration below.
target_compile_options(msprobe_c PRIVATE "-D_FORTIFY_SOURCE=2")

# Fix: "-Wl,-z,relor" was a typo; the correct hardening flag is "relro"
# (combined with "-z,now" this gives full RELRO).
target_link_options(msprobe_c PRIVATE "-Wl,-z,relro")
target_link_options(msprobe_c PRIVATE "-Wl,-z,now")
target_link_options(msprobe_c PRIVATE "-Wl,-z,noexecstack")
target_link_options(msprobe_c PRIVATE "-Wl,--disable-new-dtags")
target_link_options(msprobe_c PRIVATE "-s")

target_link_libraries(msprobe_c PUBLIC dl)
target_link_libraries(msprobe_c PUBLIC pthread)
target_link_libraries(msprobe_c PUBLIC ${cpython_LIBRARIES})
target_link_libraries(msprobe_c PUBLIC ZLIB::ZLIB)

if(DEFINED BUILD_TYPE AND "${BUILD_TYPE}" STREQUAL "debug")
    target_compile_options(msprobe_c PRIVATE "-O0")
    target_compile_options(msprobe_c PRIVATE "-g")
    target_compile_definitions(msprobe_c PRIVATE __DEBUG__)
else()
    target_compile_options(msprobe_c PRIVATE "-O2")
endif()

target_include_directories(msprobe_c BEFORE PRIVATE /usr/include)
target_include_directories(msprobe_c PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

get_target_property(INCLUDE_DIRS msprobe_c INCLUDE_DIRECTORIES)
message(STATUS "msprobe_c include dirs: ${INCLUDE_DIRS}")

file(GLOB_RECURSE SOURCES "*.cpp")
target_sources(msprobe_c PRIVATE ${SOURCES})

install(TARGETS msprobe_c LIBRARY DESTINATION lib)
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl/include/AclApi.h" + +#include "utils/Exception.h" +#include "utils/Log.h" + +namespace Ascendcl { + constexpr const char *ascendAclName = "libascendcl.so"; + + using aclInitFuncType = aclError (*)(const char *); + using aclrtSetDeviceFuncType = aclError (*)(int32_t); + using aclrtCreateContextFuncType = aclError (*)(aclrtContext *, int32_t); + using aclmdlLoadFromFileFuncType = aclError (*)(const char *, uint32_t *); + using aclmdlCreateDescFuncType = aclmdlDesc *(*)(); + using aclmdlGetDescFuncType = aclError (*)(aclmdlDesc *, uint32_t); + using aclmdlGetInputSizeByIndexFuncType = size_t (*)(aclmdlDesc *, size_t); + using aclmdlGetOutputSizeByIndexFuncType = size_t (*)(aclmdlDesc *, size_t); + using aclmdlGetNumInputsFuncType = size_t (*)(aclmdlDesc *); + using aclmdlGetNumOutputsFuncType = size_t (*)(aclmdlDesc *); + using aclmdlGetInputNameByIndexFuncType = const char *(*)(const aclmdlDesc *, size_t); + using aclmdlGetInputDataTypeFuncType = aclDataType (*)(const aclmdlDesc *, size_t); + using aclmdlGetInputDimsFuncType = aclError (*)(const aclmdlDesc *, size_t, aclmdlIODims *); + using aclrtMallocFuncType = aclError (*)(void **, size_t, aclrtMemMallocPolicy); + using aclmdlCreateDatasetFuncType = aclmdlDataset *(*)(); + using aclCreateDataBufferFuncType = aclDataBuffer *(*)(void *, size_t); + using aclmdlAddDatasetBufferFuncType = aclError (*)(aclmdlDataset *, aclDataBuffer *); + using aclmdlExecuteFuncType = aclError (*)(uint32_t, const aclmdlDataset *, aclmdlDataset *); + using aclrtMemcpyFuncType = aclError (*)(void *, 
size_t, const void *, size_t, aclrtMemcpyKind); + using aclmdlGetDatasetNumBuffersFuncType = size_t (*)(const aclmdlDataset *); + using aclmdlGetDatasetBufferFuncType = aclDataBuffer *(*)(const aclmdlDataset *, size_t); + using aclDestroyDataBufferFuncType = aclError (*)(const aclDataBuffer *); + using aclmdlDestroyDatasetFuncType = aclError (*)(const aclmdlDataset *); + using aclrtFreeFuncType = aclError (*)(void *); + using aclFinalizeFuncType = aclError (*)(); + using aclmdlUnloadFuncType = aclError (*)(uint32_t); + using aclmdlDestroyDescFuncType = aclError (*)(aclmdlDesc *); + using aclrtDestroyContextFuncType = aclError (*)(aclrtContext); + using aclrtResetDeviceFuncType = aclError (*)(int32_t); + using aclmdlInitDumpFuncType = aclError (*)(); + using aclmdlSetDumpFuncType = aclError (*)(const char *); + using aclmdlFinalizeDumpFuncType = aclError (*)(); + using acldumpRegCallbackFuncType = aclError (*)(AclDumpCallbackFuncType, int32_t); + using acldumpUnregCallbackFuncType = void (*)(); + + static aclInitFuncType aclInitFunc = nullptr; + static aclrtSetDeviceFuncType aclrtSetDeviceFunc = nullptr; + static aclrtCreateContextFuncType aclrtCreateContextFunc = nullptr; + static aclmdlLoadFromFileFuncType aclmdlLoadFromFileFunc = nullptr; + static aclmdlCreateDescFuncType aclmdlCreateDescFunc = nullptr; + static aclmdlGetDescFuncType aclmdlGetDescFunc = nullptr; + static aclmdlGetInputSizeByIndexFuncType aclmdlGetInputSizeByIndexFunc = nullptr; + static aclmdlGetOutputSizeByIndexFuncType aclmdlGetOutputSizeByIndexFunc = nullptr; + static aclmdlGetNumInputsFuncType aclmdlGetNumInputsFunc = nullptr; + static aclmdlGetNumOutputsFuncType aclmdlGetNumOutputsFunc = nullptr; + static aclmdlGetInputNameByIndexFuncType aclmdlGetInputNameByIndexFunc = nullptr; + static aclmdlGetInputDataTypeFuncType aclmdlGetInputDataTypeFunc = nullptr; + static aclmdlGetInputDimsFuncType aclmdlGetInputDimsFunc = nullptr; + static aclrtMallocFuncType aclrtMallocFunc = nullptr; + static 
aclmdlCreateDatasetFuncType aclmdlCreateDatasetFunc = nullptr; + static aclCreateDataBufferFuncType aclCreateDataBufferFunc = nullptr; + static aclmdlAddDatasetBufferFuncType aclmdlAddDatasetBufferFunc = nullptr; + static aclmdlExecuteFuncType aclmdlExecuteFunc = nullptr; + static aclrtMemcpyFuncType aclrtMemcpyFunc = nullptr; + static aclmdlGetDatasetNumBuffersFuncType aclmdlGetDatasetNumBuffersFunc = nullptr; + static aclmdlGetDatasetBufferFuncType aclmdlGetDatasetBufferFunc = nullptr; + static aclDestroyDataBufferFuncType aclDestroyDataBufferFunc = nullptr; + static aclmdlDestroyDatasetFuncType aclmdlDestroyDatasetFunc = nullptr; + static aclrtFreeFuncType aclrtFreeFunc = nullptr; + static aclFinalizeFuncType aclFinalizeFunc = nullptr; + static aclmdlUnloadFuncType aclmdlUnloadFunc = nullptr; + static aclmdlDestroyDescFuncType aclmdlDestroyDescFunc = nullptr; + static aclrtDestroyContextFuncType aclrtDestroyContextFunc = nullptr; + static aclrtResetDeviceFuncType aclrtResetDeviceFunc = nullptr; + static aclmdlInitDumpFuncType aclmdlInitDumpFunc = nullptr; + static aclmdlSetDumpFuncType aclmdlSetDumpFunc = nullptr; + static aclmdlFinalizeDumpFuncType aclmdlFinalizeDumpFunc = nullptr; + static acldumpRegCallbackFuncType acldumpRegCallbackFunc = nullptr; + static acldumpUnregCallbackFuncType acldumpUnregCallbackFunc = nullptr; + + const std::map functionMap = { + {"aclInit", reinterpret_cast(&aclInitFunc)}, + {"aclrtSetDevice", reinterpret_cast(&aclrtSetDeviceFunc)}, + {"aclrtCreateContext", reinterpret_cast(&aclrtCreateContextFunc)}, + {"aclmdlLoadFromFile", reinterpret_cast(&aclmdlLoadFromFileFunc)}, + {"aclmdlCreateDesc", reinterpret_cast(&aclmdlCreateDescFunc)}, + {"aclmdlGetDesc", reinterpret_cast(&aclmdlGetDescFunc)}, + {"aclmdlGetNumInputs", reinterpret_cast(&aclmdlGetNumInputsFunc)}, + {"aclmdlGetNumOutputs", reinterpret_cast(&aclmdlGetNumOutputsFunc)}, + {"aclmdlGetInputNameByIndex", reinterpret_cast(&aclmdlGetInputNameByIndexFunc)}, + 
{"aclmdlGetInputSizeByIndex", reinterpret_cast(&aclmdlGetInputSizeByIndexFunc)}, + {"aclmdlGetInputDataType", reinterpret_cast(&aclmdlGetInputDataTypeFunc)}, + {"aclmdlGetInputDims", reinterpret_cast(&aclmdlGetInputDimsFunc)}, + {"aclmdlGetOutputSizeByIndex", reinterpret_cast(&aclmdlGetOutputSizeByIndexFunc)}, + {"aclrtMalloc", reinterpret_cast(&aclrtMallocFunc)}, + {"aclmdlCreateDataset", reinterpret_cast(&aclmdlCreateDatasetFunc)}, + {"aclCreateDataBuffer", reinterpret_cast(&aclCreateDataBufferFunc)}, + {"aclmdlAddDatasetBuffer", reinterpret_cast(&aclmdlAddDatasetBufferFunc)}, + {"aclmdlExecute", reinterpret_cast(&aclmdlExecuteFunc)}, + {"aclrtMemcpy", reinterpret_cast(&aclrtMemcpyFunc)}, + {"aclmdlGetDatasetNumBuffers", reinterpret_cast(&aclmdlGetDatasetNumBuffersFunc)}, + {"aclmdlGetDatasetBuffer", reinterpret_cast(&aclmdlGetDatasetBufferFunc)}, + {"aclDestroyDataBuffer", reinterpret_cast(&aclDestroyDataBufferFunc)}, + {"aclmdlDestroyDataset", reinterpret_cast(&aclmdlDestroyDatasetFunc)}, + {"aclrtFree", reinterpret_cast(&aclrtFreeFunc)}, + {"aclFinalize", reinterpret_cast(&aclFinalizeFunc)}, + {"aclmdlUnload", reinterpret_cast(&aclmdlUnloadFunc)}, + {"aclmdlDestroyDesc", reinterpret_cast(&aclmdlDestroyDescFunc)}, + {"aclrtDestroyContext", reinterpret_cast(&aclrtDestroyContextFunc)}, + {"aclrtResetDevice", reinterpret_cast(&aclrtResetDeviceFunc)}, + {"aclmdlInitDump", reinterpret_cast(&aclmdlInitDumpFunc)}, + {"aclmdlSetDump", reinterpret_cast(&aclmdlSetDumpFunc)}, + {"aclmdlFinalizeDump", reinterpret_cast(&aclmdlFinalizeDumpFunc)}, + {"acldumpRegCallback", reinterpret_cast(&acldumpRegCallbackFunc)}, + {"acldumpUnregCallback", reinterpret_cast(&acldumpUnregCallbackFunc)}, + }; + + AclApi &AclApi::GetInstance() { + static AclApi instance; + return instance; + } + + AclApi::AclApi() { + LoadAclApi(); + } + + void AclApi::LoadAclApi() { + static void *libAscendcl = nullptr; + + if (libAscendcl != nullptr) { + LOG_ERROR << "No need to load acl api again."; + 
return; + } + libAscendcl = dlopen(ascendAclName, RTLD_LAZY); + if (libAscendcl == nullptr) { + LOG_ERROR << "Failed to search libascendcl.so. " << dlerror(); + return; + } + for (auto &iter : functionMap) { + if (*(iter.second) != nullptr) { + continue; + } + *(iter.second) = dlsym(libAscendcl, iter.first); + if (*(iter.second) == nullptr) { + LOG_ERROR << "Failed to load function " << iter.first << " from libascendcl.so. " << dlerror(); + dlclose(libAscendcl); + libAscendcl = nullptr; + return; + } + LOG_DEBUG << "Load function " << iter.first << " from libascendcl.so."; + } + } + + aclError AclApi::ACLAPI_AclInit(const char *cfg) { + if (aclInitFunc == nullptr) { + throw Utility::MsprobeException("API aclInit does not have a definition."); + } + return aclInitFunc(cfg); + } + + aclError AclApi::ACLAPI_AclRtSetDevice(int32_t deviceId) { + if (aclrtSetDeviceFunc == nullptr) { + throw Utility::MsprobeException("API aclrtSetDevice does not have a definition."); + } + return aclrtSetDeviceFunc(deviceId); + } + + aclError AclApi::ACLAPI_AclRtCreateContext(aclrtContext *context, int32_t deviceId) { + if (aclrtCreateContextFunc == nullptr) { + throw Utility::MsprobeException("API aclrtCreateContext does not have a definition."); + } + return aclrtCreateContextFunc(context, deviceId); + } + + LoadFileResult AclApi::ACLAPI_AclMdlLoadFromFile(const char *modelPath) { + uint32_t modelId; + if (aclmdlLoadFromFileFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlLoadFromFile does not have a definition."); + } + aclError ret = aclmdlLoadFromFileFunc(modelPath, &modelId); + LoadFileResult loadFileResult; + loadFileResult.modelId = modelId; + loadFileResult.ret = ret; + return loadFileResult; + } + + aclmdlDesc *AclApi::ACLAPI_AclMdlCreateDesc() { + if (aclmdlCreateDescFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlCreateDesc does not have a definition."); + } + return aclmdlCreateDescFunc(); + } + + aclError AclApi::ACLAPI_AclMdlGetDesc(aclmdlDesc 
*modelDesc, uint32_t modelId) { + if (aclmdlGetDescFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetDesc does not have a definition."); + } + return aclmdlGetDescFunc(modelDesc, modelId); + } + + size_t AclApi::ACLAPI_AclMdlGetNumInputs(aclmdlDesc *modelDesc) { + if (aclmdlGetNumInputsFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetNumInputs does not have a definition."); + } + return aclmdlGetNumInputsFunc(modelDesc); + } + + size_t AclApi::ACLAPI_AclMdlGetNumOutputs(aclmdlDesc *modelDesc) { + if (aclmdlGetNumOutputsFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetNumOutputs does not have a definition."); + } + return aclmdlGetNumOutputsFunc(modelDesc); + } + + const char *AclApi::ACLAPI_AclMdlGetInputNameByIndex(const aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetInputNameByIndexFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputNameByIndex does not have a definition."); + } + return aclmdlGetInputNameByIndexFunc(modelDesc, index); + } + + size_t AclApi::ACLAPI_AclMdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetInputSizeByIndexFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputSizeByIndex does not have a definition."); + } + return aclmdlGetInputSizeByIndexFunc(modelDesc, index); + } + + size_t AclApi::ACLAPI_AclMdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetOutputSizeByIndexFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetOutputSizeByIndex does not have a definition."); + } + return aclmdlGetOutputSizeByIndexFunc(modelDesc, index); + } + + aclDataType AclApi::ACLAPI_AclMdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetInputDataTypeFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputDataType does not have a definition."); + } + return aclmdlGetInputDataTypeFunc(modelDesc, index); + } + + aclError AclApi::ACLAPI_AclMdlGetInputDims(const
aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + if (aclmdlGetInputDimsFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputDims does not have a definition."); + } + return aclmdlGetInputDimsFunc(modelDesc, index, dims); + } + + aclError AclApi::ACLAPI_AclRtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) { + if (aclrtMallocFunc == nullptr) { + throw Utility::MsprobeException("API aclrtMalloc does not have a definition."); + } + return aclrtMallocFunc(devPtr, size, policy); + } + + aclmdlDataset *AclApi::ACLAPI_AclMdlCreateDataset() { + if (aclmdlCreateDatasetFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlCreateDataset does not have a definition."); + } + return aclmdlCreateDatasetFunc(); + } + + aclDataBuffer *AclApi::ACLAPI_AclCreateDataBuffer(void *data, size_t size) { + if (aclCreateDataBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclCreateDataBuffer does not have a definition."); + } + return aclCreateDataBufferFunc(data, size); + } + + aclError AclApi::ACLAPI_AclMdlAddDatasetBuffer(aclmdlDataset *dataset, aclDataBuffer *databuffer) { + if (aclmdlAddDatasetBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlAddDatasetBuffer does not have a definition."); + } + return aclmdlAddDatasetBufferFunc(dataset, databuffer); + } + + aclError AclApi::ACLAPI_AclMdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) { + if (aclmdlExecuteFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlExecute does not have a definition."); + } + return aclmdlExecuteFunc(modelId, input, output); + } + + aclError + AclApi::ACLAPI_AclRtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind) { + if (aclrtMemcpyFunc == nullptr) { + throw Utility::MsprobeException("API aclrtMemcpy does not have a definition."); + } + return aclrtMemcpyFunc(dst, destMax, src, count, kind); + } + + size_t 
AclApi::ACLAPI_AclMdlGetDatasetNumBuffers(const aclmdlDataset *dataset) { + if (aclmdlGetDatasetNumBuffersFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetDatasetNumBuffers does not have a definition."); + } + return aclmdlGetDatasetNumBuffersFunc(dataset); + } + + aclDataBuffer *AclApi::ACLAPI_AclMdlGetDatasetBuffer(const aclmdlDataset *dataset, size_t index) { + if (aclmdlGetDatasetBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetDatasetBuffer does not have a definition."); + } + return aclmdlGetDatasetBufferFunc(dataset, index); + } + + aclError AclApi::ACLAPI_AclDestroyDataBuffer(const aclDataBuffer *dataBuffer) { + if (aclDestroyDataBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclDestroyDataBuffer does not have a definition."); + } + return aclDestroyDataBufferFunc(dataBuffer); + } + + aclError AclApi::ACLAPI_AclMdlDestroyDataset(const aclmdlDataset *dataset) { + if (aclmdlDestroyDatasetFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlDestroyDataset does not have a definition."); + } + return aclmdlDestroyDatasetFunc(dataset); + } + + aclError AclApi::ACLAPI_AclRtFree(void *devPtr) { + if (aclrtFreeFunc == nullptr) { + throw Utility::MsprobeException("API aclrtFree does not have a definition."); + } + return aclrtFreeFunc(devPtr); + } + + aclError AclApi::ACLAPI_AclFinalize() { + if (aclFinalizeFunc == nullptr) { + throw Utility::MsprobeException("API aclFinalize does not have a definition."); + } + return aclFinalizeFunc(); + } + + aclError AclApi::ACLAPI_AclMdlUnload(uint32_t modelId) { + if (aclmdlUnloadFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlUnload does not have a definition."); + } + return aclmdlUnloadFunc(modelId); + } + + aclError AclApi::ACLAPI_AclMdlDestroyDesc(aclmdlDesc *modelDesc) { + if (aclmdlDestroyDescFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlDestroyDesc does not have a definition."); + } + return 
aclmdlDestroyDescFunc(modelDesc); + } + + aclError AclApi::ACLAPI_AclRtDestroyContext(aclrtContext context) { + if (aclrtDestroyContextFunc == nullptr) { + throw Utility::MsprobeException("API aclrtDestroyContext does not have a definition."); + } + return aclrtDestroyContextFunc(context); + } + + aclError AclApi::ACLAPI_AclRtResetDevice(int32_t deviceId) { + if (aclrtResetDeviceFunc == nullptr) { + throw Utility::MsprobeException("API aclrtResetDevice does not have a definition."); + } + return aclrtResetDeviceFunc(deviceId); + } + + aclError AclApi::ACLAPI_AclInitDump() { + if (aclmdlInitDumpFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlInitDump does not have a definition."); + } + return aclmdlInitDumpFunc(); + } + + aclError AclApi::ACLAPI_AclSetDump(const char *dumpCfgPath) { + if (aclmdlSetDumpFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlSetDump does not have a definition."); + } + return aclmdlSetDumpFunc(dumpCfgPath); + } + + aclError AclApi::ACLAPI_AclFinalizeDump() { + if (aclmdlFinalizeDumpFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlFinalizeDump does not have a definition."); + } + return aclmdlFinalizeDumpFunc(); + } + + aclError AclApi::ACLAPI_AclDumpRegCallBack(AclDumpCallbackFuncType messageCallback, int32_t flag) { + if (acldumpRegCallbackFunc == nullptr) { + throw Utility::MsprobeException("API acldumpRegCallback does not have a definition."); + } + return acldumpRegCallbackFunc(messageCallback, flag); + } + + void AclApi::ACLAPI_AclDumpUnregCallBack() { + if (acldumpUnregCallbackFunc == nullptr) { + throw Utility::MsprobeException("API acldumpUnregCallback does not have a definition."); + } + acldumpUnregCallbackFunc(); + } +} // namespace Ascendcl diff --git a/accuracy_tools/msprobe/csrc/acl/include/AclApi.h b/accuracy_tools/msprobe/csrc/acl/include/AclApi.h new file mode 100644 index 00000000000..967733ac322 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/acl/include/AclApi.h @@ -0,0 
+1,152 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACLDUMP_ACLAPI_H +#define ACLDUMP_ACLAPI_H + +#include +#include +#include +#include +#include + +#include "utils/Path.h" + +#define ACL_MAX_DIM_CNT 128 +#define ACL_MAX_TENSOR_NAME_LEN 128 + +extern "C" { +using aclError = int; +using aclrtContext = void *; +using aclmdlDesc = struct aclmdlDesc; +using aclmdlDataset = struct aclmdlDataset; +using aclDataBuffer = struct aclDataBuffer; + +typedef enum { + ACL_DT_UNDEFINED = -1, // 未知数据类型,默认值 + ACL_FLOAT = 0, + ACL_FLOAT16 = 1, + ACL_INT8 = 2, + ACL_INT32 = 3, + ACL_UINT8 = 4, + ACL_INT16 = 6, + ACL_UINT16 = 7, + ACL_UINT32 = 8, + ACL_INT64 = 9, + ACL_UINT64 = 10, + ACL_DOUBLE = 11, + ACL_BOOL = 12, + ACL_STRING = 13, + ACL_COMPLEX64 = 16, + ACL_COMPLEX128 = 17, + ACL_BF16 = 27, + ACL_INT4 = 29, + ACL_UINT1 = 30, + ACL_COMPLEX32 = 33, +} aclDataType; + +typedef enum aclrtMemMallocPolicy { + ACL_MEM_MALLOC_HUGE_FIRST, + ACL_MEM_MALLOC_HUGE_ONLY, + ACL_MEM_MALLOC_NORMAL_ONLY, + ACL_MEM_MALLOC_HUGE_FIRST_P2P, + ACL_MEM_MALLOC_HUGE_ONLY_P2P, + ACL_MEM_MALLOC_NORMAL_ONLY_P2P, + ACL_MEM_TYPE_LOW_BAND_WIDTH = 0x0100, + ACL_MEM_TYPE_HIGH_BAND_WIDTH = 0x1000 +} aclrtMemMallocPolicy; + +typedef enum aclrtMemcpyKind { + ACL_MEMCPY_HOST_TO_HOST, // Host内的内存复制 + ACL_MEMCPY_HOST_TO_DEVICE, // Host到Device的内存复制 + ACL_MEMCPY_DEVICE_TO_HOST, // Device到Host的内存复制 + 
ACL_MEMCPY_DEVICE_TO_DEVICE, // Device内或Device间的内存复制 + ACL_MEMCPY_DEFAULT, // 由系统根据源、目的内存地址自行判断拷贝方向 +} aclrtMemcpyKind; + +typedef struct aclmdlIODims { + char name[ACL_MAX_TENSOR_NAME_LEN]; /**< tensor name */ + size_t dimCount; /**Shape中的维度个数,如果为标量,则维度个数为0*/ + int64_t dims[ACL_MAX_DIM_CNT]; /**< 维度信息 */ +} aclmdlIODims; + +typedef struct acldumpChunk { + char fileName[Utility::SafePath::MAX_PATH_LENGTH]; // 待落盘的Dump数据文件名 + uint32_t bufLen; // dataBuf数据长度,单位Byte + uint32_t isLastChunk; // 标识Dump数据是否为最后一个分片,0表示不是最后一个分片,1表示最后一个分片 + int64_t offset; // Dump数据文件内容的偏移,其中-1表示文件追加内容 + int32_t flag; // 预留Dump数据标识,当前数据无标识 + uint8_t dataBuf[0]; // Dump数据的内存地址 +} acldumpChunk; +} + +using AclDumpCallbackFuncType = int32_t (*)(const acldumpChunk *, int32_t); + +namespace Ascendcl { + typedef struct LoadFileResult { + uint32_t modelId; + aclError ret; + } LoadFileResult; + + class AclApi { + public: + static AclApi &GetInstance(); + AclApi(); + aclError ACLAPI_AclInit(const char *cfg = nullptr); + aclError ACLAPI_AclRtSetDevice(int32_t deviceId); + aclError ACLAPI_AclRtCreateContext(aclrtContext *context, int32_t deviceId); + LoadFileResult ACLAPI_AclMdlLoadFromFile(const char *modelPath); + aclmdlDesc *ACLAPI_AclMdlCreateDesc(); + aclError ACLAPI_AclMdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId); + aclDataBuffer *ACLAPI_AclCreateDataBuffer(void *data, size_t size); + size_t ACLAPI_AclMdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index); + size_t ACLAPI_AclMdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index); + size_t ACLAPI_AclMdlGetNumInputs(aclmdlDesc *modelDesc); + size_t ACLAPI_AclMdlGetNumOutputs(aclmdlDesc *modelDesc); + const char *ACLAPI_AclMdlGetInputNameByIndex(const aclmdlDesc *modelDesc, size_t index); + aclDataType ACLAPI_AclMdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index); + aclError ACLAPI_AclMdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims); + aclError + ACLAPI_AclRtMalloc(void **devPtr, size_t 
size, aclrtMemMallocPolicy policy = ACL_MEM_MALLOC_HUGE_FIRST); + aclmdlDataset *ACLAPI_AclMdlCreateDataset(); + aclError ACLAPI_AclMdlAddDatasetBuffer(aclmdlDataset *dataset, aclDataBuffer *databuffer); + aclError ACLAPI_AclMdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output); + aclError ACLAPI_AclRtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind); + size_t ACLAPI_AclMdlGetDatasetNumBuffers(const aclmdlDataset *dataset); + aclDataBuffer *ACLAPI_AclMdlGetDatasetBuffer(const aclmdlDataset *dataset, size_t index); + aclError ACLAPI_AclDestroyDataBuffer(const aclDataBuffer *dataBuffer); + aclError ACLAPI_AclMdlDestroyDataset(const aclmdlDataset *dataset); + aclError ACLAPI_AclRtFree(void *devPtr); + aclError ACLAPI_AclFinalize(); + aclError ACLAPI_AclMdlUnload(uint32_t modelId); + aclError ACLAPI_AclMdlDestroyDesc(aclmdlDesc *modelDesc); + aclError ACLAPI_AclRtDestroyContext(aclrtContext context); + aclError ACLAPI_AclRtResetDevice(int32_t deviceId); + aclError ACLAPI_AclInitDump(); + aclError ACLAPI_AclSetDump(const char *dumpCfgPath); + aclError ACLAPI_AclFinalizeDump(); + aclError ACLAPI_AclDumpRegCallBack(AclDumpCallbackFuncType messageCallback, int32_t flag); + void ACLAPI_AclDumpUnregCallBack(); + + private: + void LoadAclApi(); + }; + +#define CALL_ACL_API(func, ...) Ascendcl::AclApi::GetInstance().ACLAPI_##func(__VA_ARGS__) + +} // namespace Ascendcl + +#endif diff --git a/accuracy_tools/msprobe/csrc/atb_probe/Override.cpp b/accuracy_tools/msprobe/csrc/atb_probe/Override.cpp new file mode 100644 index 00000000000..490bf16f65d --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/Override.cpp @@ -0,0 +1,252 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_probe/Override.h" + +#include +#include + +#include "atb_probe/include/Helper.h" +#include "atb_probe/include/SaveExtra.h" +#include "atb_probe/include/SaveGraph.h" +#include "atb_probe/include/SaveTensor.h" +#include "atb_probe/include/Stat.h" +#include "common/Toolkit.h" +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/IO.h" +#include "utils/Log.h" +#include "utils/Path.h" +#include "utils/Str.h" + +namespace atb { + // overflow + bool atb::Probe::IsOverflowCheck() { + // ⬇ Priority handling logic. + Kit::SetLogLevel(); + // ⬆ The first function to be called, sets the global log level. + const std::string taskEnvVar = StrSpace::GetEnvVar(Cst::LINK_DUMP_TASK); + return taskEnvVar == Cst::TASK_OVERFLOW; + } + + bool atb::Probe::ReportOperationGraphEnable() { + return true; + } + + void atb::Probe::ReportOperationGraph(const std::string &opName, const std::string &graph) { + Utility::ordered_json graphJson = StrSpace::Str2Json(graph); + GraphSpace::CheckInputValid(opName, graphJson); + GraphSpace::RegisterLayer(opName, graph); + Utility::ordered_json constructedGraph = GraphSpace::Build(graphJson); + Utility::fs::path graphPath = Utility::GetMsprobeDir() / "model" / (opName + ".json"); + Utility::SafePath::MakeParentDir(graphPath); + Utility::SaveJson(constructedGraph, graphPath.string(), std::ios::out); + LOG_INFO << "Graph structure of " << opName << " is already built. 
Path: " << graphPath.c_str(); + } + + bool atb::Probe::IsTensorNeedSave(const std::vector &ids, const std::string &optype) { + const std::string opidEnvVar = StrSpace::GetEnvVar(Cst::LINK_SAVE_TENSOR_IDS); // 2_1_3,1,5_2 + const std::string opNameEnvVar = StrSpace::GetEnvVar(Cst::LINK_SAVE_TENSOR_RUNNER); // Lin,SelfAttention + if (opidEnvVar.empty() && opNameEnvVar.empty()) { + return true; + } + if (!opidEnvVar.empty() && TensorSpace::IsOpidMatch(ids, opidEnvVar)) { + return true; + } + if (!opNameEnvVar.empty() && TensorSpace::IsOpNameMatch(optype, opNameEnvVar)) { + return true; + } + return false; + } + + bool atb::Probe::IsSaveTensorDesc() { + return true; + } + + bool atb::Probe::IsExecuteCountInRange(const uint64_t executeCount) { + return StrSpace::IsValueInGoal(Cst::LINK_STEP, executeCount); + } + + bool atb::Probe::ReportOperationStatisticEnable() { + const std::string cpuProfFlag = StrSpace::GetEnvVar(Cst::LINK_SAVE_CPU_PROFILING); + if (cpuProfFlag.empty()) { + return false; + } + return StrSpace::Str2Int(cpuProfFlag.c_str(), 0, Cst::LINK_SAVE_CPU_PROFILING) != 0; + } + + void atb::Probe::ReportOperationSetupStatistic(const uint64_t executeCount, + const std::string &opname, + const std::string &st) { + const uint64_t realStep = executeCount - 1; + bool flag = atb::Probe::IsExecuteCountInRange(realStep); + if (!flag) { + return; + } + ExtraSpace::SaveCpuProf(realStep, opname, st); + } + + bool atb::Probe::ReportKernelIOTensorEnable() { + const std::string kernelFlag = StrSpace::GetEnvVar(Cst::LINK_SAVE_KERNEL_INFO); + if (kernelFlag.empty()) { + return false; + } + return StrSpace::Str2Int(kernelFlag.c_str(), 0, Cst::LINK_SAVE_KERNEL_INFO) != 0; + } + + void atb::Probe::ReportKernelIOTensor(const size_t executeCount, + const std::string &opName, + const std::string &opParam, + const std::vector &inTensors, + const std::vector &outTensors) { + bool flag = atb::Probe::IsExecuteCountInRange(executeCount); + if (!flag) { + return; + } + atb::Probe::OpInfo 
opInfo{opName, opParam, inTensors, outTensors}; + ExtraSpace::SaveInfo(executeCount, opInfo, "kernel_io_info.txt"); + } + + bool atb::Probe::ReportOperationIOTensorEnable() { + const std::string opFlag = StrSpace::GetEnvVar(Cst::LINK_SAVE_OP_INFO); + if (opFlag.empty()) { + return false; + } + return StrSpace::Str2Int(opFlag.c_str(), 0, Cst::LINK_SAVE_OP_INFO) != 0; + } + + void atb::Probe::ReportOperationIOTensor(const size_t executeCount, + const std::string &opName, + const std::string &opParam, + const std::vector &inTensors, + const std::vector &outTensors) { + bool flag = atb::Probe::IsExecuteCountInRange(executeCount); + if (!flag) { + return; + } + atb::Probe::OpInfo opInfo{opName, opParam, inTensors, outTensors}; + ExtraSpace::SaveInfo(executeCount, opInfo, "operation_io_info.txt"); + } + + bool atb::Probe::IsSaveTiling() { + return true; + } + + void atb::Probe::SaveTiling(const uint8_t *data, uint64_t dataSize, const std::string &filePath) { + // ⬇ Mandatory steps for providing dump information: 1. PID and rank 2. PID and dump.json. + Kit::PidTieRank::Add(filePath); + bool hasDumpJson = Kit::DumpJsonManager::Instance().IsHas(filePath); + if (!hasDumpJson) { + Types::ArgsDumpJsonInit args; + std::string bufferSize = StrSpace::GetEnvVar(Cst::LINK_BUFFER_SIZE); + args.bufferSize = StrSpace::Str2Int(bufferSize.c_str(), Cst::BUFFER_SIZE, Cst::LINK_BUFFER_SIZE); + args.task = StrSpace::GetEnvVar(Cst::LINK_DUMP_TASK); + args.level = StrSpace::GetEnvVar(Cst::LINK_DUMP_LEVEL); + args.framework = Cst::FRAMEWORK_MINDIELLM; + args.outputDir = Kit::GetRankDir(filePath); + Kit::DumpJsonManager::Instance().Create(filePath, args); + } + // ⬆ The above parts are unrelated to SaveTiling. 
+ bool saveFlag = ExtraSpace::IsSaveTiling(); + if (!saveFlag) { + return; + } + bool validFlag = ExtraSpace::IsValidParam(data, dataSize, filePath); + if (!validFlag) { + return; + } + ExtraSpace::SaveTiling(data, dataSize, filePath); + } + + bool atb::Probe::IsSaveTensorBefore() { + return true; + } + + bool atb::Probe::IsSaveTensorData() { + return true; + } + + void atb::Probe::SaveTensor(const std::string &format, + const std::string &dtype, + const std::string &dims, + const void *hostData, + uint64_t dataSize, + const std::string &filePath) { + Types::TensorInfo inputFile{format, dtype, dims, hostData, dataSize, filePath}; + if (!TensorSpace::IsExpectedTensor(inputFile)) { + return; + } + const std::string taskFlag = StrSpace::GetEnvVar(Cst::LINK_DUMP_TASK); + if (taskFlag.empty()) { + return; + } + Types::TensorStats stat = Stat::Compute(inputFile); + Types::PathInfo pInfo = Kit::GetPathInfo(filePath); + Kit::DumpJson *dumpJson = Kit::DumpJsonManager::Instance().Get(filePath); + if (dumpJson != nullptr) { + dumpJson->UpdateStat(pInfo.nodeName, pInfo.inOut, pInfo.argsName, stat); + } else { + LOG_ERROR << "DumpJson is nullptr, can not update stat."; + } + if (taskFlag == Cst::TASK_TENSOR) { + if (dumpJson != nullptr) { dumpJson->AddTensorDir(inputFile.filePath); } + TensorSpace::Save(inputFile); + } + } + + bool atb::Probe::IsSaveTensorAfter() { + return true; + } + + bool atb::Probe::IsSaveParam() { + const std::string paramFlag = StrSpace::GetEnvVar(Cst::LINK_SAVE_PARAM); + if (paramFlag.empty()) { + return false; + } + return StrSpace::Str2Int(paramFlag.c_str(), 0, Cst::LINK_SAVE_PARAM) != 0; + } + + void atb::Probe::SaveParam(const std::string &param, const std::string &filePath) { + // ⬇ Temporary solution: Wait until everything finishes running before checking the cache.
+ Kit::DumpJson *dumpJson = Kit::DumpJsonManager::Instance().Get(filePath); + if (dumpJson != nullptr) { + dumpJson->Flush(); + } else { + LOG_ERROR << "DumpJson is nullptr, can not flush cache."; + } + // ⬆ + bool flag = ExtraSpace::IsSaveParam() && Kit::IsPathInGoal(filePath); + if (!flag) { + return; + } + std::string newFilePath = Kit::GetNewPathForSave(filePath, Cst::SUBDIRNAME_DUMP_TENSOR); + Utility::SafePath::MakeParentDir(newFilePath); + Utility::SaveJson(StrSpace::Str2Json(param), newFilePath, std::ios::out); + } + + bool atb::Probe::IsOverflowStop() { + const std::string isExit = StrSpace::GetEnvVar(Cst::LINK_STOP); + if (isExit.empty()) { + return false; + } + return StrSpace::Str2Int(isExit.c_str(), 0, Cst::LINK_STOP) != 0; + } + + void atb::Probe::ReportOverflowKernel(const std::string &kernelPath) { + return; + } + +} // namespace atb diff --git a/accuracy_tools/msprobe/csrc/atb_probe/Override.h b/accuracy_tools/msprobe/csrc/atb_probe/Override.h new file mode 100644 index 00000000000..312258bfb85 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/Override.h @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef ATB_PROBE_OVERRIDE_H +#define ATB_PROBE_OVERRIDE_H + +#include +#include + +#define EXPORT_LLM __attribute__((visibility("default"))) + +namespace atb { + class Probe { + public: + struct Tensor { + std::string dtype; + std::string format; + std::string shape; + std::string filePath; + }; + + struct OpInfo { + const std::string &opName; + const std::string &opParam; + const std::vector &inTensors; + const std::vector &outTensors; + }; + + public: + EXPORT_LLM static bool IsOverflowCheck(); + EXPORT_LLM static bool ReportOperationGraphEnable(); + EXPORT_LLM static void ReportOperationGraph(const std::string &opName, const std::string &graph); + EXPORT_LLM static bool IsTensorNeedSave(const std::vector &ids, const std::string &optype); + EXPORT_LLM static bool IsSaveTensorDesc(); + EXPORT_LLM static bool IsExecuteCountInRange(const uint64_t executeCount); + EXPORT_LLM static bool IsSaveTiling(); + EXPORT_LLM static void SaveTiling(const uint8_t *data, uint64_t dataSize, const std::string &filePath); + EXPORT_LLM static bool ReportOperationStatisticEnable(); + EXPORT_LLM static void + ReportOperationSetupStatistic(const uint64_t executeCount, const std::string &opname, const std::string &st); + EXPORT_LLM static bool IsSaveTensorBefore(); + EXPORT_LLM static bool IsSaveTensorData(); + EXPORT_LLM static void SaveTensor(const std::string &format, + const std::string &dtype, + const std::string &dims, + const void *hostData, + uint64_t dataSize, + const std::string &filePath); + EXPORT_LLM static bool IsSaveTensorAfter(); + EXPORT_LLM static bool ReportOperationIOTensorEnable(); + EXPORT_LLM static void ReportOperationIOTensor(const size_t executeCount, + const std::string &opName, + const std::string &opParam, + const std::vector &inTensors, + const std::vector &outTensors); + EXPORT_LLM static bool ReportKernelIOTensorEnable(); + EXPORT_LLM static void ReportKernelIOTensor(const size_t executeCount, + const std::string &opName, + const std::string 
&opParam, + const std::vector &inTensors, + const std::vector &outTensors); + + EXPORT_LLM static bool IsSaveParam(); + EXPORT_LLM static void SaveParam(const std::string ¶m, const std::string &filePath); + + EXPORT_LLM static bool IsOverflowStop(); + EXPORT_LLM static void ReportOverflowKernel(const std::string &kernelPath); + }; +} // namespace atb + +#endif diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/Helper.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/Helper.cpp new file mode 100644 index 00000000000..0f858079e48 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/Helper.cpp @@ -0,0 +1,287 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "atb_probe/include/Helper.h" + +#include +#include +#include +#include + +#include "utils/Constant.h" +#include "utils/Exception.h" +#include "utils/IO.h" +#include "utils/Log.h" +#include "utils/Path.h" +#include "utils/Str.h" + +namespace Kit { + inline const uint8_t STEP_LOCATION = 2; + + static std::string ExtractStepFromFilePath(const std::string &filePath) { + // filePath: 3_2035814/5/2_PositionalEmbeddingGather/xxxxxxxx + std::vector dirVec = StrSpace::Split(filePath, "/"); + if (dirVec.size() >= STEP_LOCATION) { + return dirVec[1]; + } else { + LOG_ERROR << "Invalid file path: " << filePath; + return ""; + } + } + + static bool IsValidStep(const std::string &filePath) { + std::string step = ExtractStepFromFilePath(filePath); + return StrSpace::IsValueInGoal(Cst::LINK_STEP, step); + } + + static std::string ExtractRankFromFilePath(const std::string &filePath) { + // filePath: 3_2035814/5/2_PositionalEmbeddingGather/xxxxxxxx + size_t pos = filePath.find('_'); + if (pos != std::string::npos) { + return filePath.substr(0, pos); + } else { + LOG_ERROR << "Invalid file path: " << filePath; + return ""; + } + } + + static std::string GetNodeName(const std::vector &dirVec) { + std::vector nameVec; + for (size_t i = STEP_LOCATION; i < dirVec.size(); ++i) { + if (dirVec[i] == Cst::BEFORE || dirVec[i] == Cst::AFTER) + break; + nameVec.push_back(dirVec[i]); + } + return StrSpace::Join(nameVec, "/"); + } + + bool IsValidRank(const std::string &filePath) { + std::string rank = ExtractRankFromFilePath(filePath); + return StrSpace::IsValueInGoal(Cst::LINK_RANK, rank); + } + + bool IsPathInGoal(const std::string &filePath) { + return IsValidStep(filePath) && IsValidRank(filePath); + } + + Utility::fs::path GetRankDir(const std::string &filePath) { + // filePath: 3_2035814/5/2_PositionalEmbeddingGather/before/intensor0.bin + Utility::fs::path newPath; + newPath /= "step" + ExtractStepFromFilePath(filePath); + newPath /= "rank" + 
ExtractRankFromFilePath(filePath); + return Utility::GetMsprobeDir() / newPath; + } + + std::unordered_map PidTieRank::g_pidWithRankMap; + + void PidTieRank::Add(const std::string &filePath) { + // filePath: 3_2035814/5/2_PositionalEmbeddingGather/xxxxxxxx + std::string rank = ExtractRankFromFilePath(filePath); + pid_t pid = getpid(); + auto it = g_pidWithRankMap.find(pid); + if (it == g_pidWithRankMap.end()) { + g_pidWithRankMap[pid] = rank; + } else { + if (it->second != rank) { + LOG_WARNING << "The PID: " << pid << " is already associated with a different rank. " + << "Existing rank: " << it->second << ", New rank: " << rank; + } + } + } + + std::string PidTieRank::Get(const pid_t &pid) { + if (g_pidWithRankMap.find(pid) != g_pidWithRankMap.end()) { + return g_pidWithRankMap[pid]; + } else { + LOG_ERROR << "No association between PID and Rank has been established."; + return ""; + } + } + + Types::PathInfo GetPathInfo(const std::string &filePath) { + Types::PathInfo res; + std::vector dirVec = StrSpace::Split(filePath, "/"); + + for (size_t i = 0; i < dirVec.size(); ++i) { + const std::string &part = dirVec[i]; + if (part != Cst::BEFORE && part != Cst::AFTER) { + continue; + } + res.nodeName = (i > 0) ? GetNodeName(dirVec) : ""; + if (i + 1 >= dirVec.size()) { + LOG_ERROR << "Missing file after '" << part << "' in path: " << filePath; + break; + } + std::string fileName = dirVec[i + 1]; + std::vector fileVec = StrSpace::Split(fileName, "."); + const std::string argName = (fileVec.size() == Cst::ARGS_LEN_2) ? 
fileVec[0] : ""; + if (argName.find(Cst::INTENSOR) != std::string::npos) { + res.inOut = "input_args"; + } else if (argName.find(Cst::OUTTENSOR) != std::string::npos) { + res.inOut = "output_args"; + } else { + LOG_ERROR << "Unknown tensor direction in file name: " << filePath; + } + res.argsName = argName; + return res; + } + LOG_ERROR << "Invalid file path: " << filePath; + return Types::PathInfo{"", "", ""}; + } + + Utility::fs::path GetNewPathForSave(const Utility::fs::path &originalPath, const std::string &subDirName) { + // originalPath: 3_2035814/5/2_PositionalEmbeddingGather/xxxxxxxx + auto iter = originalPath.begin(); + std::string rankStr = iter->string(); // 3_2035814 + std::string rank = rankStr.substr(0, rankStr.find('_')); // extract 3 as rank + ++iter; + std::string step = iter->string(); // 5 + ++iter; + Utility::fs::path newPath; + newPath /= "step" + step; + newPath /= "rank" + rank; + newPath /= subDirName; + while (iter != originalPath.end()) { + newPath /= *iter; + ++iter; + } + return Utility::GetMsprobeDir() / newPath; + } + + std::vector GetColumns(const std::unordered_map &kvMap) { + std::vector columns; + for (const auto &pair : kvMap) { + columns.push_back(pair.first); + } + return columns; + } + + std::vector GetElement(const std::unordered_map &kvMap, + const std::vector &columns) { + std::vector formattedLine; + for (const auto &column : columns) { + auto it = kvMap.find(column); + if (it != kvMap.end()) { + formattedLine.push_back(it->second); + } else { + formattedLine.push_back(""); + } + } + return formattedLine; + } + + std::mutex DumpJson::mutex_; + + DumpJson::DumpJson(const size_t &bufferSize, + const std::string &task, + const std::string &level, + const std::string &framework, + const std::string &outputDir) + : formalBufferSize(bufferSize), formalCurrentSize(0), formalOutputDir(outputDir) { + formalCache["task"] = task; + formalCache["level"] = level; + formalCache["framework"] = framework; + formalCache["dump_data_dir"] = 
""; + formalCache["data"] = nlohmann::json::object(); + } + + void DumpJson::AddTensorDir(const std::string &filePath) { + auto it = formalCache.find("dump_data_dir"); + if (it == formalCache.end() || !it.value().is_string() || it.value().get().empty()) { + std::string tensorDir = Kit::GetRankDir(filePath) / Cst::SUBDIRNAME_DUMP_TENSOR; + formalCache["dump_data_dir"] = tensorDir; + } + } + + void DumpJson::UpdateStat(const std::string &nodeName, + const std::string &inOut, + const std::string &dataName, + const Types::TensorStats &stats) { + Utility::ordered_json statJson = { + {"data_name", dataName}, + {"type", stats.type}, + {"dtype", stats.dtype}, + {"shape", stats.shape}, + {"Max", stats.max}, + {"Min", stats.min}, + {"Mean", stats.mean}, + {"Norm", stats.norm}, + }; + if (!stats.crc32.empty()) { + statJson["crc32"] = stats.crc32; + } + + std::lock_guard lock(mutex_); + if (!formalCache["data"].contains(nodeName)) { + formalCache["data"][nodeName] = Utility::json::object(); + } + if (!formalCache["data"][nodeName].contains(inOut)) { + formalCache["data"][nodeName][inOut] = Utility::json::array(); + } + formalCurrentSize += statJson.dump().size(); + formalCache["data"][nodeName][inOut].emplace_back(std::move(statJson)); + Flush(formalBufferSize); + } + + void DumpJson::Flush(size_t flushThreshold) { + if (formalCurrentSize > flushThreshold) { + std::string dumpJsonPath = formalOutputDir + "/dump.json"; + Utility::SafePath::MakeParentDir(dumpJsonPath); + Utility::SaveJson(formalCache, dumpJsonPath, std::ios_base::out); + formalCurrentSize = 0; + } + } + + DumpJsonManager &DumpJsonManager::Instance() { + static DumpJsonManager instance; + return instance; + } + + std::unordered_map> DumpJsonManager::g_pidWithDumpJsonMap; + std::mutex DumpJsonManager::mutex_; + + bool DumpJsonManager::IsHas(const std::string &filePath) { + std::lock_guard lock(mutex_); + std::string key; + key += "step" + ExtractStepFromFilePath(filePath); + key += "rank" + 
ExtractStepFromFilePath(filePath); + LOG_DEBUG << "DumpJson for StepRank " << key; + return g_pidWithDumpJsonMap.find(key) != g_pidWithDumpJsonMap.end(); + } + + void DumpJsonManager::Create(const std::string &filePath, const Types::ArgsDumpJsonInit &args) { + std::lock_guard lock(mutex_); + std::string key; + key += "step" + ExtractStepFromFilePath(filePath); + key += "rank" + ExtractStepFromFilePath(filePath); + if (g_pidWithDumpJsonMap.find(key) == g_pidWithDumpJsonMap.end()) { + g_pidWithDumpJsonMap[key] = + std::make_unique(args.bufferSize, args.task, args.level, args.framework, args.outputDir); + } else { + LOG_WARNING << "DumpJson for StepRank " << key << " already exists."; + } + } + + DumpJson *DumpJsonManager::Get(const std::string &filePath) { + std::lock_guard lock(mutex_); + std::string key; + key += "step" + ExtractStepFromFilePath(filePath); + key += "rank" + ExtractStepFromFilePath(filePath); + auto it = g_pidWithDumpJsonMap.find(key); + return (it != g_pidWithDumpJsonMap.end()) ? it->second.get() : nullptr; + } + +} // namespace Kit diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/SaveExtra.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveExtra.cpp new file mode 100644 index 00000000000..139019c3c41 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveExtra.cpp @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "atb_probe/include/SaveExtra.h" + +#include +#include +#include + +#include "atb_probe/Override.h" +#include "atb_probe/include/Helper.h" +#include "common/Toolkit.h" +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Log.h" +#include "utils/Path.h" +#include "utils/Str.h" + +namespace atb { + namespace ExtraSpace { + const std::vector TENSOR_INFO_HEADER = { + "CaseNum", "CaseName", "OpName", "OpParam", "InNum", "InDType", "InFormat", + "InShape", "OutNum", "OutDType", "OutFormat", "OutShape", "DataGenType", "DataGenRange", + "InTensorFile", "OutTensorFile", "TestType", "TestLevel", "FromModel", "SocVersion", "ExpectedError", + }; + + static void AddTensorAttributes(const std::vector &tensors, + Types::TensorInfo &info, + std::string *dataGenTypes = nullptr) { + bool first = true; + for (const auto &tensor : tensors) { + if (!first) { + info.dtype += ";"; + info.format += ";"; + info.dims += ";"; + info.filePath += ";"; + if (dataGenTypes) { + *dataGenTypes += ";"; + } + } + info.dtype += tensor.dtype; + info.format += tensor.format; + info.dims += tensor.shape; + info.filePath += Kit::GetNewPathForSave(tensor.filePath, Cst::SUBDIRNAME_DUMP_TENSOR).c_str(); + first = false; + if (dataGenTypes) { + *dataGenTypes += "customize"; + } + } + } + + static std::vector GetIOInfo(const uint16_t &caseNum, const atb::Probe::OpInfo &opInfo) { + const std::string caseName = opInfo.opName + std::to_string(caseNum); + Types::TensorInfo inInfo; + std::string dataGenType; + Types::TensorInfo outInfo; + AddTensorAttributes(opInfo.inTensors, inInfo, &dataGenType); + AddTensorAttributes(opInfo.outTensors, outInfo); + + const std::vector inputStringParts = { + caseName, + opInfo.opName, + opInfo.opParam, + inInfo.dtype, + inInfo.format, + inInfo.dims, + inInfo.filePath, + outInfo.dtype, + outInfo.format, + outInfo.dims, + outInfo.filePath, + dataGenType, + }; + for (const std::string &value : inputStringParts) { + if 
(Kit::IsSpecialCharInjected(value)) { + return {}; + } + } + + std::vector fields = { + std::to_string(caseNum), + caseName, + opInfo.opName, + opInfo.opParam, + std::to_string(opInfo.inTensors.size()), + inInfo.dtype, + inInfo.format, + inInfo.dims, + std::to_string(opInfo.outTensors.size()), + outInfo.dtype, + outInfo.format, + outInfo.dims, + dataGenType, + "", // DataGenRange + inInfo.filePath, + outInfo.filePath, + "", // TestType + "", // TestLevel + "", // FromModel + "", // SocVersion + "NO_ERROR", + }; + return fields; + } + + bool IsSaveTiling() { + const std::string tilingFlag = StrSpace::GetEnvVar(Cst::LINK_SAVE_TILING); + if (tilingFlag.empty()) { + return false; + } + return StrSpace::Str2Int(tilingFlag.c_str(), 0, Cst::LINK_SAVE_TILING) != 0; + } + + bool IsValidParam(const uint8_t *data, const uint64_t &dataSize, const std::string &filePath) { + bool dataFlag = data != nullptr; + bool dataSizeFlag = (dataSize <= Utility::SafePath::SIZE_10G) && (dataSize > 0); + bool pathFlag = Kit::IsPathInGoal(filePath); + return dataFlag && dataSizeFlag && pathFlag; + } + + void SaveTiling(const uint8_t *data, const uint64_t &dataSize, const std::string &filePath) { + Utility::fs::path tilingFilePath = Kit::GetNewPathForSave(filePath, Cst::SUBDIRNAME_TILING); + Utility::SafePath::MakeParentDir(tilingFilePath); + Utility::SaveBytes(data, tilingFilePath, dataSize, std::ios::out | std::ios::binary); + } + + void SaveCpuProf(const uint64_t &executeCount, const std::string &opName, const std::string &st) { + Utility::fs::path cpuFilePath = Utility::GetMsprobeDir() / ("step" + std::to_string(executeCount)) / + ("rank" + Kit::PidTieRank::Get(getpid())) / "cpu_profiling_info.txt"; + Utility::SafePath::MakeParentDir(cpuFilePath); + + std::unordered_map nameNumMap = StrSpace::Str2Map(st, ", ", ":"); + std::vector header = Kit::GetColumns(nameNumMap); + std::vector content = Kit::GetElement(nameNumMap, header); + header.insert(header.begin(), "opName"); + 
content.insert(content.begin(), opName); + if (!Utility::fs::exists(cpuFilePath)) { + const std::string newHeader = StrSpace::Join(header, "\t"); + Utility::SaveTxt(newHeader, cpuFilePath, std::ios_base::app); + } + const std::string newContent = StrSpace::Join(content, "\t"); + Utility::SaveTxt(newContent, cpuFilePath, std::ios_base::app); + } + + void SaveInfo(const uint64_t &executeCount, const atb::Probe::OpInfo &opInfo, const std::string &fileName) { + Utility::fs::path opPath = Utility::GetMsprobeDir() / ("step" + std::to_string(executeCount)) / + ("rank" + Kit::PidTieRank::Get(getpid())) / fileName; + Utility::SafePath::MakeParentDir(opPath); + uint16_t caseNum = Utility::fs::exists(opPath) ? Kit::GetLineNum(opPath.c_str()) : 1; + if (!Utility::fs::exists(opPath)) { + const std::string newHeader = StrSpace::Join(TENSOR_INFO_HEADER, "\t"); + Utility::SaveTxt(newHeader, opPath, std::ios_base::app); + } + std::vector content = GetIOInfo(caseNum, opInfo); + std::string newContent = StrSpace::Join(content, "\t"); + Utility::SaveTxt(newContent, opPath, std::ios_base::app); + } + + bool IsSaveParam() { + const std::string taskFlag = StrSpace::GetEnvVar(Cst::LINK_DUMP_TASK); + if (taskFlag == Cst::TASK_TENSOR) { + return true; + } + return false; + } + + } // namespace ExtraSpace +} // namespace atb diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/SaveGraph.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveGraph.cpp new file mode 100644 index 00000000000..48cb5ef95d5 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveGraph.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_probe/include/SaveGraph.h" + +#include + +#include "common/Toolkit.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/IO.h" +#include "utils/Log.h" + +namespace GraphSpace { + uint32_t g_maxDeep = 1000; + static Types::LayerGraphMap g_layerGraphMap; + static Types::ModelGraphMap g_modelGraphMap; + + static void AddTensors(Utility::ordered_json &newJson, + std::vector &rootNames, + const std::string &opName, + const std::string &tensorType, + const std::string &jsonKey, + size_t count) { + for (size_t i = 0; i < count; ++i) { + std::string tensorName = opName + "_" + tensorType + "_" + std::to_string(i); + newJson[jsonKey].emplace_back(tensorName); + rootNames.emplace_back(tensorName); + } + } + + static void AddByFields(const std::vector &name, + const Utility::ordered_json &originJson, + Utility::ordered_json &newJson) { + for (const auto &fieldName : name) { + if (originJson.contains(fieldName)) { + newJson[fieldName] = originJson[fieldName]; + } else { + LOG_ERROR << fieldName << " not found in graph."; + return; + } + } + } + + static void ModifyRootNodes(Utility::ordered_json &newJson, + std::vector &rootNames, + const Utility::ordered_json &originJson) { + std::string opName = originJson["opName"].get(); + uint32_t inTensorNum = originJson["inTensorNum"].get(); + uint32_t outTensorNum = originJson["outTensorNum"].get(); + uint32_t internalTensorNum = originJson.value("internalTensorNum", 0); + + AddTensors(newJson, rootNames, opName, "input", "inTensors", inTensorNum); + LOG_DEBUG << "Last 
inTensors: " << (rootNames.empty() ? "" : rootNames.back()); + AddTensors(newJson, rootNames, opName, "output", "outTensors", outTensorNum); + LOG_DEBUG << "Last outTensors: " << (rootNames.empty() ? "" : rootNames.back()); + AddTensors(newJson, rootNames, opName, "internal", "internalTensors", internalTensorNum); + LOG_DEBUG << "Last internalTensors: " << (rootNames.empty() ? "" : rootNames.back()); + } + + static void ProcessTensorIds(const Utility::ordered_json &tensorIds, + const std::vector &sourceTensorList, + std::vector &tensorNameList, + Utility::ordered_json &tensorContainer, + const std::string &opName, + const std::string &field) { + for (const auto &item : tensorIds) { + uint32_t index = item.get(); + if (index >= sourceTensorList.size()) { + LOG_ERROR << field << " index out of range for op: " << opName; + return; + } + const auto &tensorName = sourceTensorList[index]; + tensorContainer[field].emplace_back(tensorName); + tensorNameList.emplace_back(tensorName); + } + } + + static void DfsToModifyGraphTensors(Utility::ordered_json &curNodeToSave, + const std::vector &fatherNodeTensorNameList, + const Utility::ordered_json &curNodeInput, + uint32_t curDeep = 0) { + if (curDeep >= g_maxDeep) { + LOG_ERROR << "Function " + std::string(__func__) + " has been terminated due to max depth."; + throw Utility::MsprobeException("Exceeded maximum recursion depth of " + std::to_string(g_maxDeep)); + } + + const std::string opName = curNodeInput["opName"].get(); + curNodeToSave["opName"] = opName; + curNodeToSave["opType"] = curNodeInput["opType"]; + curNodeToSave["param"] = curNodeInput["param"]; + + std::vector curNodeTensorNameList; + // Process the input tensor. + ProcessTensorIds(curNodeInput["inTensorIds"], + fatherNodeTensorNameList, + curNodeTensorNameList, + curNodeToSave, + opName, + "inTensors"); + // Process the output tensor. 
+ ProcessTensorIds(curNodeInput["outTensorIds"], + fatherNodeTensorNameList, + curNodeTensorNameList, + curNodeToSave, + opName, + "outTensors"); + + uint32_t internalTensorNum = curNodeInput.value("internalTensorNum", 0U); + for (uint32_t i = 0; i < internalTensorNum; ++i) { + std::string tensorName = opName + "_internal_" + std::to_string(i); + curNodeToSave["internalTensors"].emplace_back(tensorName); + curNodeTensorNameList.emplace_back(tensorName); + } + + if (curNodeInput.contains("nodes")) { + for (const auto &childNodeInput : curNodeInput["nodes"]) { + Utility::ordered_json childNodeToSave; + DfsToModifyGraphTensors(childNodeToSave, curNodeTensorNameList, childNodeInput, curDeep + 1); + curNodeToSave["nodes"].emplace_back(childNodeToSave); + } + } + } + + static void ProcessChildNodes(Utility::ordered_json &newJson, + const std::vector &tensorNameList, + const Utility::ordered_json &graphJson) { + if (graphJson.find("nodes") != graphJson.end()) { + for (const auto &childNodeInput : graphJson["nodes"]) { + Utility::ordered_json childNodeToSave; + try { + DfsToModifyGraphTensors(childNodeToSave, tensorNameList, childNodeInput); + } catch (const std::exception &e) { + LOG_ERROR << "An unexpected error occurred: " << e.what(); + return; + } + newJson["nodes"].emplace_back(childNodeToSave); + } + } + } + + void CheckInputValid(const std::string &opName, const Utility::ordered_json &graphJson) { + bool flag = Kit::IsValidKeys({"opName", "opType", "inTensorNum", "outTensorNum"}, graphJson); + if (!flag) { + return; + } + std::string opNameInJson = graphJson["opName"].get(); + if (opNameInJson != opName) { + LOG_ERROR << "opName is not in json. 
Currently: " << opName << " vs " << opNameInJson; + return; + } + } + + void RegisterLayer(const std::string &opName, const std::string &graph) { + g_layerGraphMap.RegisterLayerGraph(opName, graph); + } + + Utility::ordered_json Build(const Utility::ordered_json &graphJson) { + Utility::ordered_json graphNodeJsonToSave; + AddByFields({"opName", "opType", "param"}, graphJson, graphNodeJsonToSave); + std::vector tensorNameList; + ModifyRootNodes(graphNodeJsonToSave, tensorNameList, graphJson); + ProcessChildNodes(graphNodeJsonToSave, tensorNameList, graphJson); + return graphNodeJsonToSave; + } + + bool IsRegisterModel(const std::string &modelName) { + return g_modelGraphMap.IsRegisterModelGraph(modelName); + } + + void RegisterModel(const std::string &modelName, const std::string &graph) { + g_modelGraphMap.RegisterModelGraph(modelName, graph); + } + +} // namespace GraphSpace diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/SaveTensor.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveTensor.cpp new file mode 100644 index 00000000000..3fc11ce5870 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveTensor.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "atb_probe/include/SaveTensor.h" + +#include +#include +#include + +#include "atb_probe/include/Helper.h" +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/IO.h" +#include "utils/Log.h" +#include "utils/Str.h" + +namespace TensorSpace { + const std::unordered_map levelMap = { + {"L0", "Module"}, + {"L1", "Operation"}, + {"L2", "Kernel"}, + }; + + const std::unordered_map> levelKeywords = { + {"Module", {"Operation/before", "Operation/after", "Kernel/before", "Kernel/after"}}, // exclusion item + {"Operation", {"Operation/before", "Operation/after"}}, + {"Kernel", {"Kernel/before", "Kernel/after"}}, + }; + + static std::string BuildQueryFromIds(const std::vector &ids) { + std::string query = std::to_string(ids[0]); + for (size_t i = 1; i < ids.size(); ++i) { + query += "_" + std::to_string(ids[i]); + } + return query; + } + + static bool IsPathInModeGoal(const std::string &filePath, const std::vector &keywords) { + for (const auto &keyword : keywords) { + if (filePath.find(keyword) == std::string::npos) { + return false; + } + } + return true; + } + + static bool IsPathInLevelGoal(const std::string &filePath, const std::vector &levels) { + for (const auto &level : levels) { + auto it = levelKeywords.find(level); + const auto &keywords = it->second; + if (level == "Module") { + // No excluded keywords in the path. + bool hasExcluded = false; + for (const auto &kw : keywords) { + if (filePath.find(kw) != std::string::npos) { + LOG_DEBUG << "IsPathInLevelGoal of " << level << ": true"; + hasExcluded = true; + break; + } + } + if (!hasExcluded) { + return true; + } + } else { + // Any keyword in the path is acceptable. 
+ for (const auto &kw : keywords) { + if (filePath.find(kw) != std::string::npos) { + LOG_DEBUG << "IsPathInLevelGoal of " << level << ": true"; + return true; + } + } + } + } + return false; + } + + static bool IsDumpMode(const std::string &label) { + return StrSpace::IsValueInGoal(Cst::LINK_DATA_MODE, label) || + StrSpace::IsValueInGoal(Cst::LINK_DATA_MODE, Cst::MODE_ALL); + } + + static std::vector GetDumpLevel() { + const std::string levelStr = StrSpace::GetEnvVar(Cst::LINK_DUMP_LEVEL); // "L1" + if (levelStr.empty()) { + return {}; + } + const std::vector levelVec = StrSpace::Split(levelStr, ","); + if (levelVec.empty()) { + return {}; + } + std::vector res; + for (const auto &level : levelVec) { + auto it = levelMap.find(level); + if (it != levelMap.end()) { + res.push_back(it->second); + } else { + LOG_ERROR << "Unsupported level: " << level; + } + } + return res; + } + + bool IsOpidMatch(const std::vector &ids, const std::string &opidStr) { + std::vector opidVec = StrSpace::Split(opidStr, ","); // 1_1_1,1_2_3 + std::string query = BuildQueryFromIds(ids); + for (const auto &indice : opidVec) { + LOG_DEBUG << "op_id from model: " << query << ", op_id from user: " << indice; + bool isChildMatch = indice.find('_') != std::string::npos; + if (isChildMatch) { + if (StrSpace::IsPrefix(query, indice) && + (query == indice || (query.length() > indice.length() && query[indice.length()] == '_'))) { + return true; + } + } else { + if (indice == query) { + return true; + } + } + } + return false; + } + + bool IsOpNameMatch(const std::string &optype, const std::string &opNameStr) { + std::vector opNameVec = StrSpace::Split(opNameStr, ","); + std::string lowerCaseOptype = StrSpace::ToLower(optype); + for (const auto &indice : opNameVec) { + LOG_DEBUG << "op_name from model: " << lowerCaseOptype << ", op_name from user: " << indice; + if (StrSpace::IsPrefix(lowerCaseOptype, indice)) { + return true; + } + } + return false; + } + + bool IsExpectedTensor(const 
Types::TensorInfo &inputFile) { + LOG_DEBUG << "Check filePath: " << inputFile.filePath; + bool modeFlag = + (IsDumpMode(Cst::MODE_INPUT) && IsPathInModeGoal(inputFile.filePath, {Cst::BEFORE, Cst::INTENSOR})) || + (IsDumpMode(Cst::MODE_OUTPUT) && IsPathInModeGoal(inputFile.filePath, {Cst::AFTER, Cst::OUTTENSOR})); + bool levelFlag = IsPathInLevelGoal(inputFile.filePath, GetDumpLevel()); + bool rankFlag = Kit::IsValidRank(inputFile.filePath); + bool dataSizeFlag = (inputFile.dataSize <= Utility::SafePath::SIZE_10G) && (inputFile.dataSize > 0); + bool hostDataFlag = (inputFile.hostData != nullptr); + bool headLengthFlag = StrSpace::IsStringLengthSafety({inputFile.format, inputFile.dtype, inputFile.dims}); + bool flag = modeFlag && levelFlag && rankFlag && dataSizeFlag && hostDataFlag && headLengthFlag; + LOG_DEBUG << "IsExpectedTensor: " << flag; + return flag; + } + + void Save(const Types::TensorInfo &inputFile) { + Utility::fs::path newPath = + Kit::GetNewPathForSave(inputFile.filePath, Cst::SUBDIRNAME_DUMP_TENSOR).replace_extension(".npy"); + Utility::SafePath::MakeParentDir(newPath); + std::vector shapeVec = StrSpace::SplitToInt(inputFile.dims, ","); + Utility::SaveNpy(inputFile.dtype, shapeVec, inputFile.hostData, inputFile.dataSize, newPath); + } +} // namespace TensorSpace diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/Stat.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/Stat.cpp new file mode 100644 index 00000000000..0b167a4e47b --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/Stat.cpp @@ -0,0 +1,229 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_probe/include/Stat.h" + +#include +#include + +#include + +#include "atb_probe/include/Helper.h" +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/IO.h" +#include "utils/Log.h" +#include "utils/Str.h" + +namespace Stat { + inline constexpr int K_CRC_HEX_WIDTH = 8; + inline constexpr int K_FLOAT16_EXP_BITS = 5; + inline constexpr int K_FLOAT16_MANT_BITS = 10; + inline constexpr int K_BFLOAT16_EXP_BITS = 8; + inline constexpr int K_BFLOAT16_MANT_BITS = 7; + + template std::string ComputeCRC32(const T *data, size_t count) { + const unsigned char *bytePtr = static_cast(static_cast(data)); + size_t byteSize = count * sizeof(T); + + uLong crc = crc32(0L, Z_NULL, 0); + crc = crc32(crc, bytePtr, byteSize); + + std::stringstream ss; + ss << std::hex << std::setw(K_CRC_HEX_WIDTH) << std::setfill('0') << crc; + return ss.str(); + } + + static uint64_t GetCountFromDType(const std::string &dtypeKey, const uint64_t &dataSize) { + auto it = Utility::dtypeMap.find(dtypeKey); + if (it != Utility::dtypeMap.end()) { + uint64_t elemSize = it->second.elemSize; + uint64_t count = dataSize / elemSize; + return count; + } else { + throw Utility::MsprobeException("Invalid dtype: " + dtypeKey); + } + } + + static std::string GetDtype(const std::string &dtypeKey) { + auto it = Utility::dtypeMap.find(dtypeKey); + if (it != Utility::dtypeMap.end()) { + return it->second.dtypeName; + } else { + LOG_WARNING << "Unsupported dtype: " << dtypeKey; + return "Unknown"; + } + } + + static float 
ConvertToFloat32(const uint16_t value, const size_t exponentBits, const size_t mantissaBits) { + // Determine the bias of the semi-precision type + int32_t exponentBias = (1 << (exponentBits - 1)) - 1; + // Obtain the mask + uint16_t exponentMask = ((1 << exponentBits) - 1) << mantissaBits; + uint16_t mantissaMask = (1 << mantissaBits) - 1; + uint16_t signMask = 1 << (exponentBits + mantissaBits); + // Extract symbol bits + int sign = (value & signMask) ? -1 : 1; + // Extract index and mantissa + int32_t rawExponent = (value & exponentMask) >> mantissaBits; + uint32_t mantissa = value & mantissaMask; + // Handle special values + if (rawExponent == (1 << exponentBits) - 1) { // All 1s represent NaN or infinity + if (mantissa != 0) { + return std::numeric_limits::quiet_NaN(); // NaN + } else { + return sign * std::numeric_limits::infinity(); // Infinity + } + } else if (rawExponent == 0) { // Exponents with all zeros indicate non-normalized numbers or zeros + if (mantissa == 0) { + return sign * 0.0f; // Zero + } else { + // Unnormalized number + float result = sign * std::ldexp(static_cast(mantissa), + 1 - static_cast(exponentBias) - static_cast(mantissaBits)); + return result; + } + } + // Normalized number + float normalizedMantissa = 1.0f + static_cast(mantissa) / (1 << mantissaBits); + float result = sign * std::ldexp(normalizedMantissa, rawExponent - exponentBias); + return result; + } + + template Types::TensorStats ComputeStats(Types::TensorInfo info) { + const T *ptr = static_cast(info.hostData); + uint64_t count = GetCountFromDType(info.dtype, info.dataSize); + + double mean = 0.0; + double m2 = 0.0; // sum of squares of differences from the current mean + double minVal = static_cast(ptr[0]); + double maxVal = static_cast(ptr[0]); + + for (uint64_t i = 0; i < count; ++i) { + double val = static_cast(ptr[i]); + // Update min/max + if (val > maxVal) { + maxVal = val; + } + if (val < minVal) { + minVal = val; + } + // Welford's online update + double delta = 
val - mean; + mean += delta / (i + 1); + m2 += delta * (val - mean); // (val - new_mean) + } + + Types::TensorStats stats; + stats.type = "Tensor"; + stats.dtype = GetDtype(info.dtype); + stats.shape = StrSpace::SplitToInt(info.dims, ","); + stats.mean = mean; + stats.min = minVal; + stats.max = maxVal; + stats.norm = std::sqrt(m2 + count * mean * mean); // Equivalent to sqrt(sum(x^2)) + std::string summary_mode = StrSpace::GetEnvVar(Cst::LINK_SUMMARY_MODE); + stats.crc32 = (summary_mode == "md5") ? ComputeCRC32(ptr, count * sizeof(T)) : ""; + return stats; + } + + Types::TensorStats ComputeStatsFloat16(const Types::TensorInfo &info) { + const uint16_t *ptr = static_cast(info.hostData); + uint64_t count = GetCountFromDType(info.dtype, info.dataSize); + + size_t exponentBits, mantissaBits; + if (info.dtype == "1") { // float16 + exponentBits = K_FLOAT16_EXP_BITS; + mantissaBits = K_FLOAT16_MANT_BITS; + } else if (info.dtype == "27") { // bfloat16 + exponentBits = K_BFLOAT16_EXP_BITS; + mantissaBits = K_BFLOAT16_MANT_BITS; + } else { + throw Utility::MsprobeException("Unsupported dtype in float16-compatible handler."); + } + + // Initialize Welford variables + double mean = 0.0; + double m2 = 0.0; + float firstVal = ConvertToFloat32(ptr[0], exponentBits, mantissaBits); + double minVal = firstVal; + double maxVal = firstVal; + + for (uint64_t i = 0; i < count; ++i) { + float val = ConvertToFloat32(ptr[i], exponentBits, mantissaBits); + + if (val > maxVal) { + maxVal = val; + } + if (val < minVal) { + minVal = val; + } + // Welford update + double delta = val - mean; + mean += delta / (i + 1); + m2 += delta * (val - mean); + } + + Types::TensorStats stats; + stats.type = "Tensor"; + stats.dtype = GetDtype(info.dtype); + stats.shape = StrSpace::SplitToInt(info.dims, ","); + stats.mean = mean; + stats.min = minVal; + stats.max = maxVal; + stats.norm = std::sqrt(m2 + count * mean * mean); // Equivalent to sqrt(sum(val^2)) + + std::string summary_mode = 
StrSpace::GetEnvVar(Cst::LINK_SUMMARY_MODE); + stats.crc32 = (summary_mode == "md5") ? ComputeCRC32(ptr, count * sizeof(uint16_t)) : ""; + + return stats; + } + + Types::TensorStats Compute(const Types::TensorInfo &info) { + const std::string &dtype = info.dtype; + if (dtype == "1" || dtype == "27") { + // float16 or bfloat16 + return ComputeStatsFloat16(info); + } + if (dtype == "0") { + return ComputeStats(info); + } else if (dtype == "2") { + return ComputeStats(info); + } else if (dtype == "3") { + return ComputeStats(info); + } else if (dtype == "4") { + return ComputeStats(info); + } else if (dtype == "6") { + return ComputeStats(info); + } else if (dtype == "7") { + return ComputeStats(info); + } else if (dtype == "8") { + return ComputeStats(info); + } else if (dtype == "9") { + return ComputeStats(info); + } else if (dtype == "10") { + return ComputeStats(info); + } else if (dtype == "11") { + return ComputeStats(info); + } else if (dtype == "12") { + return ComputeStats(info); + } else { + throw Utility::MsprobeException("Unsupported dtype in ComputeByDType: " + dtype); + } + } + +} // namespace Stat diff --git a/accuracy_tools/msprobe/csrc/atb_probe/include/Helper.h b/accuracy_tools/msprobe/csrc/atb_probe/include/Helper.h new file mode 100644 index 00000000000..28f953f8bb7 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/include/Helper.h @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef KIT_HELPER_H +#define KIT_HELPER_H + +#include +#include +#include + +#include "utils/IO.h" +#include "utils/Path.h" + +namespace Kit { + bool IsValidRank(const std::string &filePath); + bool IsPathInGoal(const std::string &filePath); + Utility::fs::path GetRankDir(const std::string &filePath); + + class PidTieRank { + public: + static void Add(const std::string &filePath); + static std::string Get(const pid_t &pid); + + private: + static std::unordered_map g_pidWithRankMap; + }; + + Types::PathInfo GetPathInfo(const std::string &filePath); + Utility::fs::path GetNewPathForSave(const Utility::fs::path &original_path, const std::string &subDirName); + std::vector GetColumns(const std::unordered_map &kvMap); + std::vector GetElement(const std::unordered_map &kvMap, + const std::vector &columns); + + class DumpJson { + public: + DumpJson(const size_t &bufferSize, + const std::string &task, + const std::string &level, + const std::string &framework, + const std::string &outputDir); + void AddTensorDir(const std::string &filePath); + void UpdateStat(const std::string &nodeName, + const std::string &inOut, + const std::string &dataName, + const Types::TensorStats &stats); + void Flush(size_t flushThreshold = 0); + + private: + size_t formalBufferSize; + size_t formalCurrentSize; + std::string formalOutputDir; + Utility::ordered_json formalCache; + static std::mutex mutex_; + }; + + class DumpJsonManager { + public: + static DumpJsonManager &Instance(); + bool IsHas(const std::string &filePath); + void Create(const std::string &filePath, const Types::ArgsDumpJsonInit &args); + DumpJson *Get(const std::string &filePath); + + private: + DumpJsonManager() = default; + ~DumpJsonManager() = default; + DumpJsonManager(const DumpJsonManager &) = delete; + DumpJsonManager &operator=(const DumpJsonManager &) = delete; + static std::unordered_map> 
g_pidWithDumpJsonMap; + static std::mutex mutex_; + }; +} // namespace Kit + +#endif diff --git a/accuracy_tools/msprobe/csrc/atb_probe/include/SaveExtra.h b/accuracy_tools/msprobe/csrc/atb_probe/include/SaveExtra.h new file mode 100644 index 00000000000..28e5eb37cd6 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/include/SaveExtra.h @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SAVE_TILING_H +#define SAVE_TILING_H + +#include +#include + +#include "atb_probe/Override.h" + +namespace atb { + namespace ExtraSpace { + bool IsSaveTiling(); + bool IsValidParam(const uint8_t *data, const uint64_t &dataSize, const std::string &filePath); + void SaveTiling(const uint8_t *data, const uint64_t &dataSize, const std::string &filePath); + void SaveCpuProf(const uint64_t &executeCount, const std::string &opName, const std::string &st); + void SaveInfo(const uint64_t &executeCount, const atb::Probe::OpInfo &opInfo, const std::string &fileName); + bool IsSaveParam(); + } // namespace ExtraSpace +} // namespace atb + +#endif diff --git a/accuracy_tools/msprobe/csrc/atb_probe/include/SaveGraph.h b/accuracy_tools/msprobe/csrc/atb_probe/include/SaveGraph.h new file mode 100644 index 00000000000..edb60fb483c --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/include/SaveGraph.h @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2025-2025. 
Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SAVE_GRAPH_H +#define SAVE_GRAPH_H + +#include + +#include "utils/IO.h" + +namespace GraphSpace { + void CheckInputValid(const std::string &opName, const Utility::ordered_json &graphJson); + void RegisterLayer(const std::string &opName, const std::string &graph); + Utility::ordered_json Build(const Utility::ordered_json &graphJson); + bool IsRegisterModel(const std::string &modelName); + void RegisterModel(const std::string &modelName, const std::string &graph); +} // namespace GraphSpace + +#endif diff --git a/accuracy_tools/msprobe/csrc/atb_probe/include/SaveTensor.h b/accuracy_tools/msprobe/csrc/atb_probe/include/SaveTensor.h new file mode 100644 index 00000000000..0d92ba497b4 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/include/SaveTensor.h @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SAVE_TENSOR_H +#define SAVE_TENSOR_H + +#include +#include +#include + +#include "utils/DataType.h" + +namespace TensorSpace { + bool IsOpidMatch(const std::vector &ids, const std::string &opidStr); + bool IsOpNameMatch(const std::string &optype, const std::string &opNameStr); + bool IsExpectedTensor(const Types::TensorInfo &inputFile); + void Save(const Types::TensorInfo &inputFile); +} // namespace TensorSpace + +#endif diff --git a/accuracy_tools/msprobe/csrc/atb_probe/include/Stat.h b/accuracy_tools/msprobe/csrc/atb_probe/include/Stat.h new file mode 100644 index 00000000000..0da704227d0 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/include/Stat.h @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef STATS_H +#define STATS_H + +#include + +#include "utils/DataType.h" +#include "utils/IO.h" + +namespace Stat { + Types::TensorStats Compute(const Types::TensorInfo &info); +} // namespace Stat + +#endif diff --git a/accuracy_tools/msprobe/csrc/common/Toolkit.cpp b/accuracy_tools/msprobe/csrc/common/Toolkit.cpp new file mode 100644 index 00000000000..7633fe70a65 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/common/Toolkit.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. 
All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/Toolkit.h" + +#include +#include +#include +#include + +#include "utils/Constant.h" +#include "utils/IO.h" +#include "utils/Log.h" +#include "utils/Str.h" + +namespace Kit { + inline const uint16_t maxLines = 50000; + inline const std::regex MALICIOUS_CSV_PATTERN(R"(^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@])"); + + void SetLogLevel() { + static std::once_flag flag; + std::call_once(flag, []() { + const std::string logLevel = StrSpace::GetEnvVar(Cst::LINK_LOG_LEVEL); + uint8_t defaultValue = static_cast(Utility::Log::LogLevel::INFO); + Utility::Log::GetInstance().SetLogLevel( + StrSpace::Str2Int(logLevel.c_str(), defaultValue, Cst::LINK_LOG_LEVEL)); + }); + } + + bool IsValidKeys(const std::vector &keys, const Utility::ordered_json &json) { + bool allValid = true; + for (const auto &key : keys) { + if (json.find(key) == json.end()) { + LOG_ERROR << "Found invalid key: " << key; + allValid = false; + } + } + return allValid; + } + + uint16_t GetLineNum(const std::string &filePath) { + std::ifstream inFile(filePath); + if (!inFile.is_open()) { + LOG_ERROR << "Failed to open file for reading: " << filePath; + return 0; + } + std::string line; + uint16_t lineCount = 0; + while (std::getline(inFile, line) && lineCount < maxLines) { + ++lineCount; + } + if (lineCount >= maxLines) { + LOG_WARNING << "Contents in file [" << filePath << "] reached the maximum size of " << maxLines + 
<< " lines."; + } + return lineCount; + } + + bool IsSpecialCharInjected(const std::string &value) { + if (value.empty()) { + return false; + } + if (std::regex_search(value, MALICIOUS_CSV_PATTERN)) { + LOG_ERROR << "Found malicious characters in value: " << value << ". Cannot write file!"; + return true; + } + return false; + } +} // namespace Kit diff --git a/accuracy_tools/msprobe/csrc/common/Toolkit.h b/accuracy_tools/msprobe/csrc/common/Toolkit.h new file mode 100644 index 00000000000..806ae13108c --- /dev/null +++ b/accuracy_tools/msprobe/csrc/common/Toolkit.h @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TOOLKIT_H +#define TOOLKIT_H + +#include "utils/IO.h" + +namespace Kit { + void SetLogLevel(); + bool IsValidKeys(const std::vector &keys, const Utility::ordered_json &json); + uint16_t GetLineNum(const std::string &filePath); + bool IsSpecialCharInjected(const std::string &value); +} // namespace Kit + +#endif diff --git a/accuracy_tools/msprobe/csrc/python/PyACLActuator.cpp b/accuracy_tools/msprobe/csrc/python/PyACLActuator.cpp new file mode 100644 index 00000000000..de87fa3f247 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/python/PyACLActuator.cpp @@ -0,0 +1,591 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "python/PyACLActuator.h" + +#include "acl/include/AclApi.h" +#include "utils/Constant.h" + +namespace MSPROBE_C { + PyDoc_STRVAR(ACLInferfaceModuleDoc, "The part of the module acl actuator that is implemented in CXX."); + static PyObject *pyCallBack = nullptr; + + static PyObject *AclApi_AclInit(PyObject *module) { + int ret = CALL_ACL_API(AclInit); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclRtSetDevice(PyObject *module, PyObject *arg) { + int ret = -1; + if (!PyLong_Check(arg)) { + PyErr_SetString(PyExc_TypeError, " should be a integer."); + return PyLong_FromLong(ret); + } + int32_t num = static_cast(PyLong_AsLong(arg)); + if (PyErr_Occurred()) { + return PyLong_FromLong(ret); + } + ret = CALL_ACL_API(AclRtSetDevice, num); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclRtCreateContext(PyObject *module, PyObject *arg) { + int ret = -1; + aclrtContext context; + PyObject *pyContext = nullptr; + if (!PyLong_Check(arg)) { + PyErr_SetString(PyExc_TypeError, " should be a integer."); + return Py_BuildValue("(Oi)", Py_None, ret); + } + int32_t num = static_cast(PyLong_AsLong(arg)); + if (PyErr_Occurred()) { + return Py_BuildValue("(Oi)", Py_None, ret); + } + ret = CALL_ACL_API(AclRtCreateContext, &context, num); + pyContext = PyCapsule_New(context, "aclrtContext", nullptr); + if (pyContext == nullptr) { + return Py_BuildValue("(Oi)", Py_None, ret); + } + 
return Py_BuildValue("(Oi)", pyContext, ret); + } + + static PyObject *AclApi_AclMdlLoadFromFile(PyObject *module, PyObject *arg) { + int ret = -1; + if (!PyUnicode_Check(arg)) { + PyErr_SetString(PyExc_TypeError, " should be a string."); + return Py_BuildValue("(Oi)", Py_None, ret); + } + const char *path = PyUnicode_AsUTF8(arg); + if (path == nullptr) { + return Py_BuildValue("(Oi)", Py_None, ret); + } + Ascendcl::LoadFileResult cTuple = CALL_ACL_API(AclMdlLoadFromFile, path); + return Py_BuildValue("(Ii)", cTuple.modelId, cTuple.ret); + } + + static PyObject *AclApi_AclMdlCreateDesc(PyObject *module) { + aclmdlDesc *modelDesc = CALL_ACL_API(AclMdlCreateDesc); + PyObject *pyModelDesc = PyCapsule_New(modelDesc, "aclmdlDesc", nullptr); + if (pyModelDesc == nullptr) { + Py_RETURN_NONE; + } + return pyModelDesc; + } + + static PyObject *AclApi_AclMdlGetDesc(PyObject *module, PyObject *args) { + int ret = -1; + // Two parameters: modelDesc and modelId. + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + return PyLong_FromLong(ret); + } + PyObject *pyModelDesc = nullptr; + uint32_t modelId; + if (!PyArg_ParseTuple(args, "OI", &pyModelDesc, &modelId)) { + return PyLong_FromLong(ret); + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + ret = CALL_ACL_API(AclMdlGetDesc, modelDesc, modelId); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclMdlGetNumInputs(PyObject *module, PyObject *args) { + size_t inputSize; + PyObject *pyModelDesc = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyModelDesc)) { + Py_RETURN_NONE; + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + inputSize = CALL_ACL_API(AclMdlGetNumInputs, modelDesc); + return PyLong_FromSize_t(inputSize); + } + + static PyObject *AclApi_AclMdlGetNumOutputs(PyObject *module, PyObject *args) { + size_t outputSize; + PyObject 
*pyModelDesc = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyModelDesc)) { + Py_RETURN_NONE; + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + outputSize = CALL_ACL_API(AclMdlGetNumOutputs, modelDesc); + return PyLong_FromSize_t(outputSize); + } + + static PyObject *AclApi_AclMdlGetInputNameByIndex(PyObject *module, PyObject *args) { + const char *name; + // Two parameters: modelDesc and index. + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + Py_RETURN_NONE; + } + PyObject *pyModelDesc = nullptr; + size_t index; + if (!PyArg_ParseTuple(args, "OK", &pyModelDesc, &index)) { + Py_RETURN_NONE; + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + name = CALL_ACL_API(AclMdlGetInputNameByIndex, modelDesc, index); + return PyUnicode_FromString(name); + } + + static PyObject *AclApi_AclMdlGetInputSizeByIndex(PyObject *module, PyObject *args) { + size_t inputSize; + // Two parameters: modelDesc and index. + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + Py_RETURN_NONE; + } + PyObject *pyModelDesc = nullptr; + size_t index; + if (!PyArg_ParseTuple(args, "OK", &pyModelDesc, &index)) { + Py_RETURN_NONE; + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + inputSize = CALL_ACL_API(AclMdlGetInputSizeByIndex, modelDesc, index); + return PyLong_FromSize_t(inputSize); + } + + static PyObject *AclApi_AclMdlGetOutputSizeByIndex(PyObject *module, PyObject *args) { + size_t outputSize; + // Two parameters: modelDesc and index. 
+ if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + Py_RETURN_NONE; + } + PyObject *pyModelDesc = nullptr; + size_t index; + if (!PyArg_ParseTuple(args, "OK", &pyModelDesc, &index)) { + Py_RETURN_NONE; + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + outputSize = CALL_ACL_API(AclMdlGetOutputSizeByIndex, modelDesc, index); + return PyLong_FromSize_t(outputSize); + } + + static PyObject *AclApi_AclMdlGetInputDataType(PyObject *module, PyObject *args) { + // Two parameters: modelDesc and index. + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + Py_RETURN_NONE; + } + PyObject *pyModelDesc = nullptr; + size_t index; + if (!PyArg_ParseTuple(args, "OK", &pyModelDesc, &index)) { + Py_RETURN_NONE; + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + int type = CALL_ACL_API(AclMdlGetInputDataType, modelDesc, index); + return PyLong_FromLong(type); + } + + static PyObject *AclApi_AclMdlGetInputDims(PyObject *module, PyObject *args) { + int ret = -1; + // Two parameters: modelDesc and index. 
+ if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + return Py_BuildValue("(Oi)", Py_None, ret); + } + PyObject *pyModelDesc = nullptr; + size_t index; + if (!PyArg_ParseTuple(args, "OK", &pyModelDesc, &index)) { + return Py_BuildValue("(Oi)", Py_None, ret); + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + aclmdlIODims ioDims; + ret = CALL_ACL_API(AclMdlGetInputDims, modelDesc, index, &ioDims); + PyObject *dimsDict = PyDict_New(); + PyObject *pyDims = PyList_New(ioDims.dimCount); + for (size_t i = 0; i < ioDims.dimCount; ++i) { + PyList_SET_ITEM(pyDims, i, PyLong_FromSize_t(ioDims.dims[i])); + } + PyDict_SetItemString(dimsDict, "name", PyUnicode_FromString(ioDims.name)); + PyDict_SetItemString(dimsDict, "dimCount", PyLong_FromSize_t(ioDims.dimCount)); + PyDict_SetItemString(dimsDict, "dims", pyDims); + return Py_BuildValue("(Oi)", dimsDict, ret); + } + + static PyObject *AclApi_AclRtMalloc(PyObject *module, PyObject *args) { + int ret = -1; + void *ptr = nullptr; + size_t bufferSize; + PyObject *pyMallocPtr = nullptr; + if (!PyArg_ParseTuple(args, "K", &bufferSize)) { + return Py_BuildValue("(Oi)", Py_None, ret); + } + ret = CALL_ACL_API(AclRtMalloc, &ptr, bufferSize); + pyMallocPtr = PyCapsule_New(ptr, "aclrtMallocPtr", nullptr); + if (pyMallocPtr == nullptr) { + return Py_BuildValue("(Oi)", Py_None, ret); + } + return Py_BuildValue("(Oi)", pyMallocPtr, ret); + } + + static PyObject *AclApi_AclMdlCreateDataset(PyObject *module) { + aclmdlDataset *dataset = CALL_ACL_API(AclMdlCreateDataset); + PyObject *pyDataset = PyCapsule_New(dataset, "aclmdlDataset", nullptr); + if (pyDataset == nullptr) { + Py_RETURN_NONE; + } + return pyDataset; + } + + static PyObject *AclApi_AclCreateDataBuffer(PyObject *module, PyObject *args) { + // Two parameters: ptr and buffersize. 
+ if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + Py_RETURN_NONE; + } + PyObject *pyMallocPtr = nullptr; + size_t bufferSize; + if (!PyArg_ParseTuple(args, "OK", &pyMallocPtr, &bufferSize)) { + Py_RETURN_NONE; + } + void *ptr = reinterpret_cast(PyCapsule_GetPointer(pyMallocPtr, "aclrtMallocPtr")); + aclDataBuffer *dataBuffer = CALL_ACL_API(AclCreateDataBuffer, ptr, bufferSize); + PyObject *pyDataBuffer = PyCapsule_New(dataBuffer, "aclDataBuffer", nullptr); + if (pyDataBuffer == nullptr) { + Py_RETURN_NONE; + } + return pyDataBuffer; + } + + static PyObject *AclApi_AclMdlAddDatasetBuffer(PyObject *module, PyObject *args) { + int ret = -1; + // Two parameters: dataset and buffer. + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + return PyLong_FromLong(ret); + } + PyObject *pyDataset = nullptr; + PyObject *pyDataBuffer = nullptr; + if (!PyArg_ParseTuple(args, "OO", &pyDataset, &pyDataBuffer)) { + return PyLong_FromLong(ret); + } + aclDataBuffer *buffer = reinterpret_cast(PyCapsule_GetPointer(pyDataBuffer, "aclDataBuffer")); + aclmdlDataset *dataset = reinterpret_cast(PyCapsule_GetPointer(pyDataset, "aclmdlDataset")); + ret = CALL_ACL_API(AclMdlAddDatasetBuffer, dataset, buffer); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclMdlExecute(PyObject *module, PyObject *args) { + int ret = -1; + // Three parameters: modelId, inputDataset and outputDataset. 
+ if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_3) { + PyErr_SetString(PyExc_TypeError, " expects 3 arguments."); + return PyLong_FromLong(ret); + } + PyObject *pyInputDataset = nullptr; + PyObject *pyOutputDataset = nullptr; + uint32_t modelId; + if (!PyArg_ParseTuple(args, "IOO", &modelId, &pyInputDataset, &pyOutputDataset)) { + return PyLong_FromLong(ret); + } + aclmdlDataset *inputDataset = + reinterpret_cast(PyCapsule_GetPointer(pyInputDataset, "aclmdlDataset")); + aclmdlDataset *outputDataset = + reinterpret_cast(PyCapsule_GetPointer(pyOutputDataset, "aclmdlDataset")); + ret = CALL_ACL_API(AclMdlExecute, modelId, inputDataset, outputDataset); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclRtMemcpy(PyObject *module, PyObject *args) { + int ret = -1; + // Five parameters: dest, destMax, src, count and kind. + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_5) { + PyErr_SetString(PyExc_TypeError, " expects 5 arguments."); + return PyLong_FromLong(ret); + } + PyObject *pyDst = nullptr; + PyObject *pySrc = nullptr; + size_t destCount; + size_t srcCount; + int kind; + if (!PyArg_ParseTuple(args, "OKOKi", &pyDst, &destCount, &pySrc, &srcCount, &kind)) { + return PyLong_FromLong(ret); + } + const void *src = nullptr; + if (PyCapsule_CheckExact(pySrc)) { + src = reinterpret_cast(PyCapsule_GetPointer(pySrc, "aclrtMallocPtr")); + } else if (PyBytes_Check(pySrc)) { + src = reinterpret_cast(PyBytes_AsString(pySrc)); + } else { + PyErr_SetString(PyExc_TypeError, " invalid type."); + } + void *dst = reinterpret_cast(PyCapsule_GetPointer(pyDst, "aclrtMallocPtr")); + aclrtMemcpyKind cKind = static_cast(kind); + ret = CALL_ACL_API(AclRtMemcpy, dst, destCount, src, srcCount, cKind); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclMdlGetDatasetNumBuffers(PyObject *module, PyObject *args) { + size_t bufferNum; + PyObject *pyDataset = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyDataset)) { + Py_RETURN_NONE; + 
} + const aclmdlDataset *dataset = + reinterpret_cast(PyCapsule_GetPointer(pyDataset, "aclmdlDataset")); + bufferNum = CALL_ACL_API(AclMdlGetDatasetNumBuffers, dataset); + return PyLong_FromSize_t(bufferNum); + } + + static PyObject *AclApi_AclMdlGetDatasetBuffer(PyObject *module, PyObject *args) { + // Two parameters: dataset and index. + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + Py_RETURN_NONE; + } + PyObject *pyDataset = nullptr; + size_t index; + if (!PyArg_ParseTuple(args, "OK", &pyDataset, &index)) { + Py_RETURN_NONE; + } + const aclmdlDataset *dataset = + reinterpret_cast(PyCapsule_GetPointer(pyDataset, "aclmdlDataset")); + aclDataBuffer *dataBuffer = CALL_ACL_API(AclMdlGetDatasetBuffer, dataset, index); + PyObject *pyDataBuffer = PyCapsule_New(dataBuffer, "aclDataBuffer", nullptr); + if (pyDataBuffer == nullptr) { + Py_RETURN_NONE; + } + return pyDataBuffer; + } + + static PyObject *AclApi_AclDestroyDataBuffer(PyObject *module, PyObject *args) { + int ret = -1; + PyObject *pyDataBuffer = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyDataBuffer)) { + return PyLong_FromLong(ret); + } + const aclDataBuffer *buffer = + reinterpret_cast(PyCapsule_GetPointer(pyDataBuffer, "aclDataBuffer")); + ret = CALL_ACL_API(AclDestroyDataBuffer, buffer); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclMdlDestroyDataset(PyObject *module, PyObject *args) { + int ret = -1; + PyObject *pyDataset = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyDataset)) { + return PyLong_FromLong(ret); + } + const aclmdlDataset *dataset = + reinterpret_cast(PyCapsule_GetPointer(pyDataset, "aclmdlDataset")); + ret = CALL_ACL_API(AclMdlDestroyDataset, dataset); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclRtFree(PyObject *module, PyObject *args) { + int ret = -1; + PyObject *pyMallocPtr = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyMallocPtr)) { + return 
PyLong_FromLong(ret); + } + void *ptr = reinterpret_cast(PyCapsule_GetPointer(pyMallocPtr, "aclrtMallocPtr")); + ret = CALL_ACL_API(AclRtFree, ptr); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclFinalize(PyObject *module) { + int ret = CALL_ACL_API(AclFinalize); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclMdlUnload(PyObject *module, PyObject *arg) { + int ret = -1; + if (!PyLong_Check(arg)) { + PyErr_SetString(PyExc_TypeError, " should be a integer."); + return PyLong_FromLong(ret); + } + int32_t num = static_cast(PyLong_AsLong(arg)); + if (PyErr_Occurred()) { + return PyLong_FromLong(ret); + } + ret = CALL_ACL_API(AclMdlUnload, num); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclMdlDestroyDesc(PyObject *module, PyObject *args) { + int ret = -1; + PyObject *pyModelDesc = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyModelDesc)) { + return PyLong_FromLong(ret); + } + aclmdlDesc *modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyModelDesc, "aclmdlDesc")); + ret = CALL_ACL_API(AclMdlDestroyDesc, modelDesc); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclRtDestroyContext(PyObject *module, PyObject *args) { + int ret = -1; + PyObject *pyContext = nullptr; + if (!PyArg_ParseTuple(args, "O", &pyContext)) { + return PyLong_FromLong(ret); + } + aclrtContext modelDesc = reinterpret_cast(PyCapsule_GetPointer(pyContext, "aclrtContext")); + ret = CALL_ACL_API(AclRtDestroyContext, modelDesc); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclRtResetDevice(PyObject *module, PyObject *arg) { + int ret = -1; + if (!PyLong_Check(arg)) { + PyErr_SetString(PyExc_TypeError, " should be a integer."); + return PyLong_FromLong(ret); + } + int32_t num = static_cast(PyLong_AsLong(arg)); + if (PyErr_Occurred()) { + return PyLong_FromLong(ret); + } + ret = CALL_ACL_API(AclRtResetDevice, num); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclInitDump(PyObject *module) { + int 
ret = CALL_ACL_API(AclInitDump); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclSetDump(PyObject *module, PyObject *arg) { + int ret = -1; + if (!PyUnicode_Check(arg)) { + PyErr_SetString(PyExc_TypeError, "\"dump config path\" should be a string."); + return PyLong_FromLong(ret); + } + const char *path = PyUnicode_AsUTF8(arg); + if (path == nullptr) { + return PyLong_FromLong(ret); + } + ret = CALL_ACL_API(AclSetDump, path); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclFinalizeDump(PyObject *module) { + int ret = CALL_ACL_API(AclFinalizeDump); + return PyLong_FromLong(ret); + } + + static int32_t CppToPyCallBack(const acldumpChunk *chunk, int32_t len) { + if (!pyCallBack || !chunk) { + return 0; + } + if (!PyCallable_Check(pyCallBack)) { + PyErr_SetString(PyExc_TypeError, "callback func must can be call."); + return 0; + } + PyObject *pyChunk = PyDict_New(); + PyDict_SetItemString(pyChunk, "file_name", PyUnicode_FromString(chunk->fileName)); + PyDict_SetItemString(pyChunk, "buf_len", PyLong_FromUnsignedLong(chunk->bufLen)); + PyDict_SetItemString(pyChunk, "is_last_chunk", PyLong_FromUnsignedLong(chunk->isLastChunk)); + PyDict_SetItemString(pyChunk, "offset", PyLong_FromLongLong(chunk->offset)); + PyDict_SetItemString(pyChunk, "flag", PyLong_FromLong(chunk->flag)); + PyDict_SetItemString(pyChunk, + "data_buf", + PyBytes_FromStringAndSize(reinterpret_cast(chunk->dataBuf), chunk->bufLen)); + PyObject *args = Py_BuildValue("(Oi)", pyChunk, PyLong_FromLong(len)); + PyObject *result = PyObject_CallObject(pyCallBack, args); + Py_XDECREF(args); + Py_XDECREF(result); + Py_XDECREF(pyChunk); + return 0; + } + + static PyObject *AclApi_AclDumpRegCallBack(PyObject *module, PyObject *args) { + PyObject *callback; + int flag; + int ret = -1; + if (!PyArg_ParseTuple(args, "Oi", &callback, &flag)) { + Py_RETURN_NONE; + } + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "callback func must can be call."); + 
Py_RETURN_NONE; + } + Py_XDECREF(pyCallBack); + pyCallBack = callback; + Py_INCREF(pyCallBack); + ret = CALL_ACL_API(AclDumpRegCallBack, CppToPyCallBack, flag); + return PyLong_FromLong(ret); + } + + static PyObject *AclApi_AclDumpUnregCallBack(PyObject *module) { + CALL_ACL_API(AclDumpUnregCallBack); + Py_INCREF(pyCallBack); + pyCallBack = nullptr; + Py_RETURN_NONE; + } + + static PyMethodDef ACLActuatorMethods[] = { + {"init", reinterpret_cast(AclApi_AclInit), METH_NOARGS, nullptr}, + {"rt_set_device", reinterpret_cast(AclApi_AclRtSetDevice), METH_O, nullptr}, + {"rt_create_context", reinterpret_cast(AclApi_AclRtCreateContext), METH_O, nullptr}, + {"load_from_file", reinterpret_cast(AclApi_AclMdlLoadFromFile), METH_O, nullptr}, + {"create_desc", reinterpret_cast(AclApi_AclMdlCreateDesc), METH_NOARGS, nullptr}, + {"get_desc", reinterpret_cast(AclApi_AclMdlGetDesc), METH_VARARGS, nullptr}, + {"get_num_inputs", reinterpret_cast(AclApi_AclMdlGetNumInputs), METH_VARARGS, nullptr}, + {"get_num_outputs", reinterpret_cast(AclApi_AclMdlGetNumOutputs), METH_VARARGS, nullptr}, + {"get_input_name_by_index", + reinterpret_cast(AclApi_AclMdlGetInputNameByIndex), + METH_VARARGS, + nullptr}, + {"get_input_size_by_index", + reinterpret_cast(AclApi_AclMdlGetInputSizeByIndex), + METH_VARARGS, + nullptr}, + {"get_output_size_by_index", + reinterpret_cast(AclApi_AclMdlGetOutputSizeByIndex), + METH_VARARGS, + nullptr}, + {"get_input_data_type", reinterpret_cast(AclApi_AclMdlGetInputDataType), METH_VARARGS, nullptr}, + {"get_input_dims", reinterpret_cast(AclApi_AclMdlGetInputDims), METH_VARARGS, nullptr}, + {"rt_malloc", reinterpret_cast(AclApi_AclRtMalloc), METH_VARARGS, nullptr}, + {"create_dataset", reinterpret_cast(AclApi_AclMdlCreateDataset), METH_NOARGS, nullptr}, + {"create_databuffer", reinterpret_cast(AclApi_AclCreateDataBuffer), METH_VARARGS, nullptr}, + {"add_dataset_buffer", reinterpret_cast(AclApi_AclMdlAddDatasetBuffer), METH_VARARGS, nullptr}, + {"execute", 
reinterpret_cast(AclApi_AclMdlExecute), METH_VARARGS, nullptr}, + {"rt_memcpy", reinterpret_cast(AclApi_AclRtMemcpy), METH_VARARGS, nullptr}, + {"get_dataset_num_buffers", + reinterpret_cast(AclApi_AclMdlGetDatasetNumBuffers), + METH_VARARGS, + nullptr}, + {"get_dataset_buffer", reinterpret_cast(AclApi_AclMdlGetDatasetBuffer), METH_VARARGS, nullptr}, + {"destroy_databuffer", reinterpret_cast(AclApi_AclDestroyDataBuffer), METH_VARARGS, nullptr}, + {"destroy_dataset", reinterpret_cast(AclApi_AclMdlDestroyDataset), METH_VARARGS, nullptr}, + {"rt_free", reinterpret_cast(AclApi_AclRtFree), METH_VARARGS, nullptr}, + {"finalize", reinterpret_cast(AclApi_AclFinalize), METH_NOARGS, nullptr}, + {"unload", reinterpret_cast(AclApi_AclMdlUnload), METH_O, nullptr}, + {"destroy_desc", reinterpret_cast(AclApi_AclMdlDestroyDesc), METH_VARARGS, nullptr}, + {"rt_destroy_context", reinterpret_cast(AclApi_AclRtDestroyContext), METH_VARARGS, nullptr}, + {"rt_reset_device", reinterpret_cast(AclApi_AclRtResetDevice), METH_O, nullptr}, + {"init_dump", reinterpret_cast(AclApi_AclInitDump), METH_NOARGS, nullptr}, + {"set_dump", reinterpret_cast(AclApi_AclSetDump), METH_O, nullptr}, + {"finalize_dump", reinterpret_cast(AclApi_AclFinalizeDump), METH_NOARGS, nullptr}, + {"dump_reg_callback", reinterpret_cast(AclApi_AclDumpRegCallBack), METH_VARARGS, nullptr}, + {"dump_unreg_callback", reinterpret_cast(AclApi_AclDumpUnregCallBack), METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}, + }; + + static struct PyModuleDef g_ACLActuatorModule = { + PyModuleDef_HEAD_INIT, + "msprobe_c.acl", // m_name + ACLInferfaceModuleDoc, // m_doc + -1, // m_size + ACLActuatorMethods, // m_methods + }; + + PyObject *GetACLActuatorModule() { + return PyModule_Create(&g_ACLActuatorModule); + } +} // namespace MSPROBE_C diff --git a/accuracy_tools/msprobe/csrc/python/PyACLActuator.h b/accuracy_tools/msprobe/csrc/python/PyACLActuator.h new file mode 100644 index 00000000000..a47d402123d --- /dev/null +++ 
b/accuracy_tools/msprobe/csrc/python/PyACLActuator.h @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PYACL_ACTUATOR_H +#define PYACL_ACTUATOR_H + +#include + +namespace MSPROBE_C { + PyObject *GetACLActuatorModule(); +} + +#endif diff --git a/accuracy_tools/msprobe/csrc/python/PyInterface.cpp b/accuracy_tools/msprobe/csrc/python/PyInterface.cpp new file mode 100644 index 00000000000..c6aa54ad7de --- /dev/null +++ b/accuracy_tools/msprobe/csrc/python/PyInterface.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "python/PyACLActuator.h" +#include "python/PyLog.h" + +#include + +namespace MSPROBE_C { + PyDoc_STRVAR(InferfaceModuleDoc, "The part of the msprobe module that is implemented in CXX."); + + static struct PyModuleDef g_InterfaceModule = { + PyModuleDef_HEAD_INIT, + "msprobe_c", // m_name + InferfaceModuleDoc, // m_doc + -1, // m_size + nullptr, // m_methods + }; +} // namespace MSPROBE_C + +PyMODINIT_FUNC PyInit_msprobe_c(void) { + PyObject *m = PyModule_Create(&MSPROBE_C::g_InterfaceModule); + if (m == nullptr) { + return nullptr; + } + + PyObject *cpyACLActuator = MSPROBE_C::GetACLActuatorModule(); + if (cpyACLActuator == nullptr) { + PyErr_SetString(PyExc_ImportError, "Failed to create submodule ACLActuatorModule."); + Py_DECREF(m); + return nullptr; + } + if (PyModule_AddObject(m, "acl", cpyACLActuator) < 0) { + PyErr_SetString(PyExc_ImportError, "Failed to bind submodule ACLActuatorModule."); + Py_DECREF(m); + return nullptr; + } + + PyObject *cpyLog = MSPROBE_C::GetLogModule(); + if (cpyLog == nullptr) { + PyErr_SetString(PyExc_ImportError, "Failed to create submodule LogModule."); + Py_DECREF(m); + return nullptr; + } + if (PyModule_AddObject(m, "log", cpyLog) < 0) { + PyErr_SetString(PyExc_ImportError, "Failed to bind submodule LogModule."); + Py_DECREF(m); + return nullptr; + } + return m; +} diff --git a/accuracy_tools/msprobe/csrc/python/PyLog.cpp b/accuracy_tools/msprobe/csrc/python/PyLog.cpp new file mode 100644 index 00000000000..9b7fe4b54db --- /dev/null +++ b/accuracy_tools/msprobe/csrc/python/PyLog.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "python/PyLog.h" + +#include "utils/Constant.h" +#include "utils/Log.h" + +namespace MSPROBE_C { + PyDoc_STRVAR(LogModuleDoc, "The part of the module log that is implemented in CXX."); + + void LogDebugPy(const std::string &msg) { + Utility::Log::GetInstance().DebugStream() << msg; + } + + void LogInfoPy(const std::string &msg) { + Utility::Log::GetInstance().InfoStream() << msg; + } + + void LogWarningPy(const std::string &msg) { + Utility::Log::GetInstance().WarningStream() << msg; + } + + void LogErrorPy(const std::string &msg) { + Utility::Log::GetInstance().ErrorStream() << msg; + } + + static PyObject *Print_Log(PyObject *module, PyObject *args) { + int logLevel; + const char *msg; + if (args == nullptr || PyTuple_GET_SIZE(args) != Cst::ARGS_LEN_2) { + PyErr_SetString(PyExc_TypeError, " expects 2 arguments."); + Py_RETURN_NONE; + } + if (!PyArg_ParseTuple(args, "is", &logLevel, &msg)) { + PyErr_SetString(PyExc_TypeError, " should input a integer and a string."); + Py_RETURN_NONE; + } + switch (logLevel) { + case static_cast(Utility::Log::LogLevel::DEBUG): + LogDebugPy(msg); + break; + case static_cast(Utility::Log::LogLevel::INFO): + LogInfoPy(msg); + break; + case static_cast(Utility::Log::LogLevel::WARNING): + LogWarningPy(msg); + break; + case static_cast(Utility::Log::LogLevel::ERROR): + LogErrorPy(msg); + break; + default: + break; + } + Py_RETURN_NONE; + } + + static PyObject *Set_Log_Level(PyObject *module, PyObject *arg) { + if (!PyLong_Check(arg)) { + PyErr_SetString(PyExc_TypeError, " should be a integer."); + 
Py_RETURN_NONE; + } + int logLevel = static_cast(PyLong_AsLong(arg)); + if (PyErr_Occurred()) { + Py_RETURN_NONE; + } + Utility::Log::GetInstance().SetLogLevel(logLevel); + Py_RETURN_NONE; + } + + static PyObject *Get_Log_Level(PyObject *module) { + int logLevel = Utility::Log::GetInstance().GetLogLevel(); + return PyLong_FromLong(logLevel); + } + + static PyMethodDef LogMethods[] = { + {"print_log", reinterpret_cast(Print_Log), METH_VARARGS, nullptr}, + {"set_log_level", reinterpret_cast(Set_Log_Level), METH_O, nullptr}, + {"get_log_level", reinterpret_cast(Get_Log_Level), METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}, + }; + + static struct PyModuleDef g_LogModule = { + PyModuleDef_HEAD_INIT, + "msprobe_c.log", // m_name + LogModuleDoc, // m_doc + -1, // m_size + LogMethods, // m_methods + }; + + PyObject *GetLogModule() { + return PyModule_Create(&g_LogModule); + } +} // namespace MSPROBE_C diff --git a/accuracy_tools/msprobe/csrc/python/PyLog.h b/accuracy_tools/msprobe/csrc/python/PyLog.h new file mode 100644 index 00000000000..d768e4a99b6 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/python/PyLog.h @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PY_LOG_H +#define PY_LOG_H + +#include + +#include "utils/Log.h" + +namespace MSPROBE_C { + void LogDebugPy(const std::string &msg); + void LogInfoPy(const std::string &msg); + void LogWarningPy(const std::string &msg); + void LogErrorPy(const std::string &msg); + PyObject *GetLogModule(); +} // namespace MSPROBE_C + +#endif diff --git a/accuracy_tools/msprobe/csrc/utils/Constant.h b/accuracy_tools/msprobe/csrc/utils/Constant.h new file mode 100644 index 00000000000..2365ce198de --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/Constant.h @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef CONST_H +#define CONST_H + +#include + +namespace Cst { + inline const size_t BUFFER_SIZE = 1024 * 1024; // 1MB + inline const uint8_t INDENT_WIDTH = 4; + inline const uint8_t ARGS_LEN_2 = 2; + inline const uint8_t ARGS_LEN_3 = 3; + inline const uint8_t ARGS_LEN_5 = 5; + + inline const char *LINK_LOG_LEVEL = "LINK_LOG_LEVEL"; + inline const char *LINK_DUMP_PATH = "LINK_DUMP_PATH"; + inline const char *LINK_STEP = "LINK_STEP"; + inline const char *LINK_RANK = "LINK_RANK"; + inline const char *LINK_SAVE_TENSOR_IDS = "LINK_SAVE_TENSOR_IDS"; + inline const char *LINK_SAVE_TENSOR_RUNNER = "LINK_SAVE_TENSOR_RUNNER"; + inline const char *LINK_SAVE_TILING = "LINK_SAVE_TILING"; + inline const char *LINK_SAVE_CPU_PROFILING = "LINK_SAVE_CPU_PROFILING"; + inline const char *LINK_SAVE_KERNEL_INFO = "LINK_SAVE_KERNEL_INFO"; + inline const char *LINK_SAVE_OP_INFO = "LINK_SAVE_OP_INFO"; + inline const char *LINK_SAVE_PARAM = "LINK_SAVE_PARAM"; + inline const char *LINK_DUMP_TASK = "LINK_DUMP_TASK"; + inline const char *LINK_SUMMARY_MODE = "LINK_SUMMARY_MODE"; + inline const char *LINK_BUFFER_SIZE = "LINK_BUFFER_SIZE"; + inline const char *LINK_DATA_MODE = "LINK_DATA_MODE"; + inline const char *LINK_DUMP_LEVEL = "LINK_DUMP_LEVEL"; + inline const char *LINK_STOP = "LINK_STOP"; + + const std::string SUBDIRNAME_DUMP_TENSOR = "dump_tensor_data"; + const std::string SUBDIRNAME_TILING = "dump_tiling_data"; + const std::string TASK_STAT = "statistics"; + const std::string TASK_TENSOR = "tensor"; + const std::string TASK_OVERFLOW = "overflow"; + const std::string INTENSOR = "intensor"; + const std::string OUTTENSOR = "outtensor"; + const std::string BEFORE = "before"; + const std::string AFTER = "after"; + const std::string MODE_INPUT = "input"; + const std::string MODE_OUTPUT = "output"; + const std::string MODE_ALL = "all"; + const std::string SUMMARY_MD5 = "md5"; + const std::string FRAMEWORK_MINDIELLM = "MindIE_LLM"; +} // namespace Cst + +#endif diff --git 
a/accuracy_tools/msprobe/csrc/utils/DataType.h b/accuracy_tools/msprobe/csrc/utils/DataType.h new file mode 100644 index 00000000000..f1398e2d44a --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/DataType.h @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATA_TYPE_H +#define DATA_TYPE_H + +#include +#include +#include + +namespace Types { + struct LayerGraphMap { + std::map LayerGraphMap_; + void RegisterLayerGraph(const std::string &opName, const std::string &graph) { + LayerGraphMap_[opName] = graph; + }; + std::string GetLayerGraph(const std::string &opName) { + auto it = LayerGraphMap_.find(opName); + return (it != LayerGraphMap_.end()) ? it->second : ""; + }; + }; + + struct ModelGraphMap { + std::map modelGraphMap_; + bool IsRegisterModelGraph(const std::string &modelName) { + auto it = modelGraphMap_.find(modelName); + return (it == modelGraphMap_.end()) ? 
true : false; + }; + void RegisterModelGraph(const std::string &modelName, const std::string &graph) { + modelGraphMap_[modelName] = graph; + }; + }; + + struct TensorInfo { + std::string format; + std::string dtype; + std::string dims; + const void *hostData; + uint64_t dataSize; + std::string filePath; + }; + + struct TensorStats { + std::string type; + std::string dtype; + std::vector shape; + double max; + double min; + double mean; + double norm; + std::string crc32; + }; + + struct DTypeInfo { + uint64_t elemSize; + std::string descrDtype; + std::string dtypeName; + }; + + struct PathInfo { + std::string inOut; + std::string nodeName; + std::string argsName; + }; + + struct ArgsDumpJsonInit { + uint32_t bufferSize; + std::string task; + std::string level; + std::string framework; + std::string outputDir; + }; +} // namespace Types + +#endif diff --git a/accuracy_tools/msprobe/csrc/utils/Exception.h b/accuracy_tools/msprobe/csrc/utils/Exception.h new file mode 100644 index 00000000000..4eb65006b6d --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/Exception.h @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef EXCEPTION_H +#define EXCEPTION_H + +#include + +namespace Utility { + class MsprobeException : public std::runtime_error { + public: + explicit MsprobeException(const std::string &msg) : std::runtime_error("msprobe [ERROR]: " + msg) { + } + }; +} // namespace Utility + +#endif diff --git a/accuracy_tools/msprobe/csrc/utils/IO.cpp b/accuracy_tools/msprobe/csrc/utils/IO.cpp new file mode 100644 index 00000000000..4a109339e54 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/IO.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "utils/IO.h" + +#include +#include +#include + +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/Path.h" +#include "utils/Str.h" + +namespace Utility { + static void PrintSaveFailMsg(const std::exception &e, const std::string &filePath, const uint64_t dataSize) { + LOG_ERROR << "Exception during write: " << e.what(); + bool tagSpaceFree = SafePath::IsDiskSpaceValid(filePath, dataSize); + LOG_ERROR << " Possible check 1: Disk space - 1 (sufficient), 0 (insufficient), current: " << tagSpaceFree; + } + + template + void SaveRawNpy(const std::vector &data, + const std::string &filePath, + const std::vector &shape, + const std::string &descrDtype) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + std::ofstream out(filePath, std::ios::binary); + if (!out) { + throw MsprobeException("Cannot open file: " + filePath); + } + + std::string magic = "\x93NUMPY"; + uint8_t major = 1; + uint8_t minor = 0; + + std::ostringstream headerStream; + headerStream << "{'descr': '" << descrDtype << "', 'fortran_order': False, 'shape': ("; + for (size_t i = 0; i < shape.size(); ++i) { + headerStream << shape[i]; + if (i != shape.size() - 1) + headerStream << ", "; + } + if (shape.size() == 1) + headerStream << ","; + headerStream << "), }"; + + std::string header = headerStream.str(); + size_t padding = 16 - ((magic.size() + 2 + 2 + header.size()) % 16); + header.append(padding, ' '); + header += '\n'; + + out.write(magic.c_str(), magic.size()); + out.put(static_cast(major)); + out.put(static_cast(minor)); + + uint16_t hlen = static_cast(header.size()); + char buf[sizeof(hlen)]; + std::memcpy(buf, &hlen, sizeof(hlen)); + out.write(buf, sizeof(hlen)); + + out.write(header.c_str(), header.size()); + + try { + std::vector rawBytes(data.size() * sizeof(T)); + std::memcpy(rawBytes.data(), data.data(), rawBytes.size()); + out.write(rawBytes.data(), 
rawBytes.size()); + } catch (const std::exception &e) { + PrintSaveFailMsg(e, filePath, data.size() * sizeof(T)); + } + } + + static void SaveStringNpy(const std::vector &strings, + const std::string &filePath, + const std::vector &shape) { + if (strings.empty()) { + throw MsprobeException("Invalid or empty string array."); + } + size_t maxStrLen = 0; + for (const auto &s : strings) { + maxStrLen = std::max(maxStrLen, s.size()); + } + std::vector flatData(strings.size() * maxStrLen, '\0'); + for (size_t i = 0; i < strings.size(); ++i) { + std::memcpy(&flatData[i * maxStrLen], strings[i].c_str(), std::min(strings[i].size(), maxStrLen)); + } + std::string descrDtype = "|S" + std::to_string(maxStrLen); + SaveRawNpy(flatData, filePath, shape, descrDtype); + } + + void SaveJson(const ordered_json &data, const std::string &path, const std::ios_base::openmode &mode) { + SafePath checker(path, SafePath::PathType::FILE, SafePath::Mode::WRITE, 0, ".json"); + std::string validatedPath = checker.Check(false); + std::ofstream outfile(validatedPath, mode); + if (outfile.is_open()) { + outfile << data.dump(Cst::INDENT_WIDTH) << std::endl; + outfile.close(); + SafePath::ChangePermission(validatedPath, SafePath::PERM_640); + } else { + LOG_ERROR << "Unable to open file! File path: " << validatedPath; + } + } + + void SaveBytes(const uint8_t *data, + const fs::path &path, + const uint64_t &dataSize, + const std::ios_base::openmode &mode) { + SafePath checker(path, SafePath::PathType::FILE, SafePath::Mode::WRITE, 0, ".bin"); + std::string validatedPath = checker.Check(false); + std::ofstream outfile(validatedPath, mode); + if (outfile.is_open()) { + try { + outfile.write(static_cast(static_cast(data)), dataSize); + outfile.close(); + SafePath::ChangePermission(validatedPath, SafePath::PERM_640); + } catch (const std::exception &e) { + PrintSaveFailMsg(e, validatedPath, dataSize); + } + } else { + LOG_ERROR << "Unable to open file. 
File path: " << validatedPath; + } + } + + const std::unordered_map dtypeMap = { + {"0", {sizeof(float), " &shape, + const void *data, + const uint64_t &dataSize, + const std::string &path) { + SafePath checker(path, SafePath::PathType::FILE, SafePath::Mode::WRITE, 0, ".npy"); + std::string validatedPath = checker.Check(false); + + if (dtypeKey != "13") { + auto it = dtypeMap.find(dtypeKey); + if (it != dtypeMap.end()) { + const Types::DTypeInfo &info = it->second; + const auto *bytePtr = static_cast(data); + std::vector byteVec(bytePtr, bytePtr + dataSize); + SaveRawNpy(byteVec, validatedPath, shape, info.descrDtype); + } else { + LOG_ERROR << "Unsupported dtype: " << dtypeKey; + } + } else { + auto stringVecPtr = static_cast *>(data); + SaveStringNpy(*stringVecPtr, validatedPath, shape); + } + } + + void SaveTxt(const std::string &data, const std::string &path, const std::ios_base::openmode &mode) { + SafePath checker(path, SafePath::PathType::FILE, SafePath::Mode::WRITE, 0, ".txt"); + std::string validatedPath = checker.Check(false); + std::ofstream outfile(validatedPath, mode); + if (outfile.is_open()) { + outfile << data << std::endl; + outfile.close(); + SafePath::ChangePermission(validatedPath, SafePath::PERM_640); + } else { + LOG_ERROR << "Unable to open file. File path: " << validatedPath; + } + } +} // namespace Utility diff --git a/accuracy_tools/msprobe/csrc/utils/IO.h b/accuracy_tools/msprobe/csrc/utils/IO.h new file mode 100644 index 00000000000..01ea29f6069 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/IO.h @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef IO_H +#define IO_H + +#include + +#include "nlohmann/json.hpp" + +#include "utils/DataType.h" +#include "utils/Path.h" + +namespace Utility { + using json = nlohmann::json; + using ordered_json = nlohmann::ordered_json; + + extern const std::unordered_map dtypeMap; + + void SaveJson(const ordered_json &data, const std::string &path, const std::ios_base::openmode &mode); + void + SaveBytes(const uint8_t *data, const fs::path &path, const uint64_t &dataSize, const std::ios_base::openmode &mode); + void SaveNpy(const std::string &dtypeKey, + const std::vector &shape, + const void *data, + const uint64_t &dataSize, + const std::string &path); + void SaveTxt(const std::string &data, const std::string &path, const std::ios_base::openmode &mode); +} // namespace Utility + +#endif diff --git a/accuracy_tools/msprobe/csrc/utils/Log.h b/accuracy_tools/msprobe/csrc/utils/Log.h new file mode 100644 index 00000000000..6f230e2d3d5 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/Log.h @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LOG_H +#define LOG_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Utility { + class Log { + public: + static constexpr uint8_t BUF_SIZE = 32; + enum class LogLevel { DEBUG = 0, INFO, WARNING, ERROR }; + class LogStream { + public: + LogStream(LogLevel level, Log &logger) : level_(level), logger_(logger) { + } + + ~LogStream() { + if (logger_.CheckLogLevel(level_)) { + logger_.PrintLog(buffer_.str(), level_); + } + } + + template LogStream &operator<<(T &&arg) { + buffer_ << (std::forward(arg)); + return *this; + } + + private: + std::ostringstream buffer_; + LogLevel level_; + Log &logger_; + }; + + static Log &GetInstance() { + static Log instance; + return instance; + } + + LogStream DebugStream() { + return LogStream(LogLevel::DEBUG, *this); + } + LogStream InfoStream() { + return LogStream(LogLevel::INFO, *this); + } + LogStream WarningStream() { + return LogStream(LogLevel::WARNING, *this); + } + LogStream ErrorStream() { + return LogStream(LogLevel::ERROR, *this); + } + + void SetLogLevel(int level) { + logLv_ = level; + } + + int GetLogLevel() const { + return logLv_; + } + + bool CheckLogLevel(LogLevel level) const { + return static_cast(level) >= logLv_; + } + + void PrintLog(const std::string &msg, LogLevel lv) { + std::lock_guard lock(logMutex_); + std::string filteredMsg = msg; + filterSpecialChar(filteredMsg); + char buf[BUF_SIZE]; + auto now = std::chrono::system_clock::now(); + std::time_t time = std::chrono::system_clock::to_time_t(now); + std::tm *tm = std::localtime(&time); + std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", tm); + printf("%s (msprobe) (PID %ld) [%s] %s\n", buf, GetPid(), LogLevelString.at(lv), filteredMsg.c_str()); + fflush(stdout); + } + + private: + Log() = default; + ~Log() = default; + void filterSpecialChar(std::string 
&msg) { + for (const auto &s : SpecialChar) { + size_t pos = 0; + while ((pos = msg.find(s, pos)) != std::string::npos) { + msg.replace(pos, s.length(), "_"); + pos += 1; + } + } + } + + static uint64_t GetPid() { + return static_cast(getpid()); + } + + std::mutex logMutex_; + int logLv_ = 1; + const std::map LogLevelString = { + {LogLevel::DEBUG, "DEBUG"}, + {LogLevel::INFO, "INFO"}, + {LogLevel::WARNING, "WARNING"}, + {LogLevel::ERROR, "ERROR"}, + }; + const std::unordered_set SpecialChar = { + "\n", + "\r", + "\u007f", + "\b", + "\f", + "\t", + "\v", + "\u000b", + "%08", + "%09", + "%0a", + "%0b", + "%0c", + "%0d", + "%7f", + "//", + "\\", + "&", + }; + }; + +#define LOG_DEBUG Utility::Log::GetInstance().DebugStream() +#define LOG_INFO Utility::Log::GetInstance().InfoStream() +#define LOG_WARNING Utility::Log::GetInstance().WarningStream() +#define LOG_ERROR Utility::Log::GetInstance().ErrorStream() + +} // namespace Utility + +#endif diff --git a/accuracy_tools/msprobe/csrc/utils/Path.cpp b/accuracy_tools/msprobe/csrc/utils/Path.cpp new file mode 100644 index 00000000000..1ee2d1a6b6e --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/Path.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "utils/Path.h" + +#include +#include +#include +#include +#include + +#include "utils/Constant.h" +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/Str.h" + +namespace Utility { + const int BYTE_TO_MB_SHIFT = 20; + + static int GetFreeSpace(const std::string &path, uint64_t *freeSpace) { + struct statvfs diskInfo; + if (statvfs(path.c_str(), &diskInfo) == -1) { + LOG_ERROR << "statvfs() error: " << errno; + return 1; + } + *freeSpace = diskInfo.f_bavail * diskInfo.f_bsize; + return 0; + } + + SafePath::SafePath( + std::string path, PathType pathType, Mode mode, uint64_t sizeLimit, std::string suffix, int maxDepth) + : formalPath(path), formalPathType(pathType), formalMode(mode), formalSizeLimit(sizeLimit), + formalSuffix(suffix), formalMaxDirDepth(maxDepth) { + } + + std::string SafePath::Check(bool pathExist, SoftLinkLevel linkLevel) { + formalPath = fs::absolute(fs::path(formalPath)).string(); + CheckSpecialChars(); + CheckPathLength(); + if (formalMode == Mode::WRITE && !pathExist) { + std::string parent = fs::absolute(fs::path(formalPath).parent_path()).string(); + CheckExists(parent); // The current directory doesn't exist, but the parent directory does. 
+ parent = CheckSoftLink(parent, linkLevel); + EnsureDirectory(parent); + CheckPermissions(parent); + } else { + CheckExists(formalPath); + formalPath = CheckSoftLink(formalPath, linkLevel); + if (formalPathType == PathType::FILE) { + EnsureFile(formalPath); + CheckFileSuffix(); + CheckFileSize(); + } else if (formalPathType == PathType::DIRECTORY) { + EnsureDirectory(formalPath); + CheckDirectorySize(); + } + CheckPermissions(formalPath); + } + if (formalPathType == PathType::DIRECTORY && !StrSpace::IsSuffix(formalPath, "/")) { + formalPath += "/"; + } + return formalPath; + } + + void SafePath::CheckExists(const std::string &p) const { + if (!fs::exists(p)) { + throw MsprobeException("Path not found: " + p); + } + } + + void SafePath::EnsureDirectory(const std::string &p) const { + if (!fs::is_directory(p)) { + throw MsprobeException("Path is not a directory: " + p); + } + } + + void SafePath::EnsureFile(const std::string &p) const { + if (!fs::is_regular_file(p)) { + throw MsprobeException("Path is not a file: " + p); + } + } + + std::string SafePath::CheckSoftLink(const std::string &p, SoftLinkLevel level) const { + if (!fs::is_symlink(p)) + return p; + std::string realPath = fs::read_symlink(p).string(); + switch (level) { + case SoftLinkLevel::STRICT: + throw MsprobeException("Symlinks are prohibited: " + p); + case SoftLinkLevel::WARNING: + LOG_WARNING << "Symlink detected: " << p << " -> " << realPath; + break; + case SoftLinkLevel::IGNORE: + break; + } + return realPath; + } + + void SafePath::CheckPermissions(const std::string &p) const { + struct stat statbuf{}; + if (stat(p.c_str(), &statbuf) != 0) { + throw MsprobeException("Cannot stat path: " + p); + } + uid_t uid = geteuid(); + if (statbuf.st_uid != uid && statbuf.st_uid != 0) { + throw MsprobeException("Path must be owned by current user or root: " + p); + } + mode_t perm = statbuf.st_mode; + if ((perm & S_IWGRP) || (perm & S_IWOTH)) { + throw MsprobeException("Path writable by group/others: " + 
p); + } + if (formalMode == Mode::READ && !(perm & S_IRUSR)) { + throw MsprobeException("No read permission: " + p); + } + if (formalMode == Mode::WRITE && !(perm & S_IWUSR)) { + throw MsprobeException("No write permission: " + p); + } + if (formalMode == Mode::EXECUTE && !(perm & S_IXUSR)) { + throw MsprobeException("No execute permission: " + p); + } + } + + void SafePath::CheckSpecialChars() const { + std::regex valid(SPECIAL_CHAR_WHITE_LIST); + if (!std::regex_match(formalPath, valid)) { + throw MsprobeException("Path contains invalid characters: " + formalPath); + } + } + + void SafePath::CheckPathLength() const { + if (formalPath.length() > MAX_PATH_LENGTH) { + throw MsprobeException("Path length exceeds limit: " + std::to_string(formalPath.length())); + } + int depth = 0; + for (const auto &part : fs::path(formalPath)) { + if (++depth > formalMaxDirDepth) { + throw MsprobeException("Exceeded max directory depth: " + std::to_string(formalMaxDirDepth)); + } + if (part.string().length() > MAX_LAST_NAME_LENGTH) { + throw MsprobeException("Directory entry too long: " + part.string()); + } + } + } + + void SafePath::CheckFileSuffix() const { + if (!formalSuffix.empty() && !StrSpace::IsSuffix(formalPath, formalSuffix)) { + throw MsprobeException("File does not have expected suffix: " + formalSuffix); + } + } + + void SafePath::CheckFileSize() const { + if (formalSizeLimit > 0 && fs::file_size(formalPath) > static_cast(formalSizeLimit)) { + throw MsprobeException("File size exceeds limit."); + } + } + + void SafePath::CheckDirectorySize() const { + if (formalSizeLimit <= 0) + return; + uint64_t totalSize = 0; + for (const auto &entry : fs::recursive_directory_iterator(formalPath)) { + if (fs::is_regular_file(entry)) { + totalSize += fs::file_size(entry); + } + } + if (totalSize > static_cast(formalSizeLimit)) { + throw MsprobeException("Directory size exceeds limit."); + } + } + + bool SafePath::IsDiskSpaceValid(const std::string &path, const uint64_t &dataSize) { + 
uint64_t freeSpace = 0; + int ret = GetFreeSpace(path, &freeSpace); + if (ret == 0) { + if (freeSpace <= SafePath::SIZE_2G || freeSpace < dataSize * BYTE_TO_MB_SHIFT) { + LOG_ERROR << "Disk space is not enough, it's must more than 2G. Now free size(MB): " + << (freeSpace >> BYTE_TO_MB_SHIFT); + return false; + } + } else { + LOG_ERROR << "Failed to get disk space for path: " << path; + return false; + } + return true; + } + + void SafePath::ChangePermission(const fs::path &path, const fs::perms &permission) { + if (!fs::exists(path) || fs::is_symlink(path)) { + return; + } + try { + fs::permissions(path, permission, fs::perm_options::replace); + } catch (const fs::filesystem_error &e) { + throw MsprobeException("Failed to set permissions for " + path.string() + ": " + e.what()); + } + } + + void SafePath::MakeParentDir(const fs::path &path) { + fs::path parentPath = path.parent_path(); + if (!fs::exists(parentPath)) { + int depth = 0; + for (const auto &part : parentPath) { + if (!part.empty()) { + ++depth; + } + } + if (depth > MAX_DIR_DEPTH) { + throw MsprobeException("Exceeded max directory depth: " + std::to_string(MAX_DIR_DEPTH)); + } + fs::create_directories(parentPath); + } + SafePath::ChangePermission(parentPath, SafePath::PERM_750); + } + + fs::path GetMsprobeDir() { + std::string pathEnvVar = StrSpace::GetEnvVar(Cst::LINK_DUMP_PATH); + if (!pathEnvVar.empty()) { + return pathEnvVar; + } else { + throw MsprobeException("Dump dir has not been set."); + } + } + +} // namespace Utility diff --git a/accuracy_tools/msprobe/csrc/utils/Path.h b/accuracy_tools/msprobe/csrc/utils/Path.h new file mode 100644 index 00000000000..1ae3bd070c0 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/Path.h @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SAFE_PATH_H +#define SAFE_PATH_H + +#include +#include + +namespace Utility { + namespace fs = std::filesystem; + + class SafePath { + public: + static inline fs::perms PERM_750 = fs::perms::owner_all | fs::perms::group_read | fs::perms::group_exec; + static inline fs::perms PERM_640 = fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read; + static constexpr size_t SIZE_2G = 2ULL * 1024 * 1024 * 1024; + static constexpr size_t SIZE_10G = 10ULL * 1024 * 1024 * 1024; + static constexpr size_t SIZE_30G = 30ULL * 1024 * 1024 * 1024; + + static constexpr int MAX_PATH_LENGTH = 4096; + static constexpr int MAX_DIR_DEPTH = 32; + static constexpr int MAX_LAST_NAME_LENGTH = 255; + static inline const std::string SPECIAL_CHAR_WHITE_LIST = "^[a-zA-Z0-9_./-]+$"; + + enum class PathType { FILE, DIRECTORY }; + enum class Mode { READ, WRITE, EXECUTE }; + enum class SoftLinkLevel { IGNORE, WARNING, STRICT }; + + public: + static bool IsDiskSpaceValid(const std::string &path, const uint64_t &dataSize); + static void ChangePermission(const fs::path &path, const fs::perms &permission); + static void MakeParentDir(const fs::path &path); + SafePath(std::string path, + PathType pathType, + Mode mode, + uint64_t sizeLimit = 0, + std::string suffix = "", + int maxDepth = MAX_DIR_DEPTH); + std::string Check(bool pathExist = true, SoftLinkLevel linkLevel = SoftLinkLevel::STRICT); + + private: + std::string formalPath; + PathType formalPathType; + Mode formalMode; + int formalSizeLimit; + std::string formalSuffix; + int formalMaxDirDepth; + + 
void CheckExists(const std::string &p) const; + void EnsureDirectory(const std::string &p) const; + void EnsureFile(const std::string &p) const; + std::string CheckSoftLink(const std::string &p, SoftLinkLevel level) const; + void CheckPermissions(const std::string &p) const; + void CheckSpecialChars() const; + void CheckPathLength() const; + void CheckFileSuffix() const; + void CheckFileSize() const; + void CheckDirectorySize() const; + }; + + fs::path GetMsprobeDir(); +} // namespace Utility + +#endif diff --git a/accuracy_tools/msprobe/csrc/utils/Str.cpp b/accuracy_tools/msprobe/csrc/utils/Str.cpp new file mode 100644 index 00000000000..63df927900b --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/Str.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/Str.h" + +#include +#include +#include +#include + +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/Log.h" + +namespace StrSpace { + const uint16_t MAX_LENGTH = 512; + + std::string GetEnvVar(const char *varName) { + const char *envVar = std::getenv(varName); + return (envVar != nullptr) ? 
std::string(envVar) : ""; + } + + std::vector Split(const std::string &inputString, const std::string &delimiter) { + if (delimiter.empty()) { + throw Utility::MsprobeException("delimiter is null."); + } + std::vector tokens; + size_t start = 0, end = 0; + while ((end = inputString.find(delimiter, start)) != std::string::npos) { + if (end != start) { + tokens.push_back(inputString.substr(start, end - start)); + } + start = end + delimiter.length(); + } + if (start < inputString.length()) { + tokens.push_back(inputString.substr(start)); + } + return tokens; + } + + uint64_t Str2Int(const char *inputString, const uint64_t &defaultValue, const char *tag) { + std::string tagStr = (tag != nullptr) ? std::string(tag) : "null"; + std::string tagInfo = "<" + tagStr + ">"; + if (inputString == nullptr) { + LOG_WARNING << "Input string is null. " << tagInfo << " Using default value: " << defaultValue; + return defaultValue; + } + std::string str(inputString); + if (str.empty()) { + LOG_INFO << "Input string is empty. " << tagInfo << " Using default value: " << defaultValue; + return defaultValue; + } + try { + size_t idx = 0; + int value = std::stoi(str, &idx); + if (idx != str.length()) { + LOG_WARNING << "Partial conversion: '" << str << "' -> " << value << ", unparsed: '" << str.substr(idx) + << "'. " << tagInfo; + } + return value; + } catch (const std::invalid_argument &e) { + LOG_ERROR << "Invalid argument: cannot convert '" << str << "' to int. " << tagInfo + << " Using default value: " << defaultValue; + } catch (const std::out_of_range &e) { + LOG_ERROR << "Out of range: cannot convert '" << str << "' to int. " << tagInfo + << " Using default value: " << defaultValue; + } catch (...) { + LOG_ERROR << "Unknown exception when converting '" << str << "' to int. 
" << tagInfo; + } + return defaultValue; + } + + std::vector SplitToInt(const std::string &inputString, const std::string &delimiter) { + auto str_vec = Split(inputString, delimiter); + std::vector int_vec; + for (const std::string &s : str_vec) { + uint64_t val = Str2Int(s.c_str(), 0, nullptr); + int_vec.push_back(val); + } + return int_vec; + } + + std::vector + Str2IntForVector(const std::vector &strVec, const int &defaultValue, const char *tag) { + std::vector IntVec; + IntVec.reserve(strVec.size()); + for (const auto &strEle : strVec) { + uint64_t val = Str2Int(strEle.c_str(), defaultValue, tag); + IntVec.push_back(static_cast(val)); + } + return IntVec; + } + + bool IsValueInGoal(const char *varName, const uint64_t &query) { + const std::string vvEnvVar = GetEnvVar(varName); // 5,6,7,8 + LOG_DEBUG << "IsValueInGoal: " << vvEnvVar << " for " << varName; + if (vvEnvVar.empty()) { + return true; + } + const std::vector vvString = Split(vvEnvVar, ","); + if (vvString.empty()) { + return false; + } + std::vector vvInt = Str2IntForVector(vvString, 0, varName); + if (vvInt.empty()) { + return false; + } + return std::find(vvInt.begin(), vvInt.end(), query) != vvInt.end(); + } + + bool IsValueInGoal(const char *varName, const std::string &query) { + const std::string vvEnvVar = GetEnvVar(varName); // "L0,L1,L2" + LOG_DEBUG << "IsValueInGoal: " << vvEnvVar << " for " << varName; + if (vvEnvVar.empty()) { + return true; + } + const std::vector vvString = Split(vvEnvVar, ","); + if (vvString.empty()) { + return false; + } + return std::find(vvString.begin(), vvString.end(), query) != vvString.end(); + } + + bool IsPrefix(const std::string &str, const std::string &prefix) { + return str.compare(0, prefix.length(), prefix) == 0; + } + + bool IsSuffix(const std::string &str, const std::string &suffix) { + if (str.length() < suffix.length()) { + return false; + } + return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0; + } + + std::string 
ToLower(const std::string &input) { + std::string result = input; + for (char &c : result) { + c = std::tolower(c); + } + return result; + } + + bool IsStringLengthSafety(const std::string &ss) { + return (ss.size() <= MAX_LENGTH) && (ss.size() > 0); + } + + bool IsStringLengthSafety(const std::vector &vec) { + for (const auto &s : vec) { + if (!IsStringLengthSafety(s)) { + return false; + } + } + return true; + } + + ordered_json Str2Json(const std::string &value) { + try { + ordered_json value2dict; + value2dict = ordered_json::parse(value); + return value2dict; + } catch (const ordered_json::parse_error &ex) { + LOG_ERROR << ex.what() << ". Exception id: " << ex.id << ". Byte position of error: " << ex.byte; + return ordered_json(); + } + } + + std::unordered_map + Str2Map(const std::string &line, const std::string &delimiter, const std::string joint) { + // line: xxx1=xxx2;xxx3=xxx4;xxx5=xxx6 + std::vector pairs = Split(line, delimiter); + std::unordered_map kvs; + for (const auto &pair : pairs) { + size_t pos = pair.find(joint); + if (pos != std::string::npos) { + std::string key = pair.substr(0, pos); + std::string value = pair.substr(pos + 1); + kvs[key] = value; + } + } + return kvs; + } + +} // namespace StrSpace diff --git a/accuracy_tools/msprobe/csrc/utils/Str.h b/accuracy_tools/msprobe/csrc/utils/Str.h new file mode 100644 index 00000000000..7414e7603f9 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/utils/Str.h @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef STR_METHOD_H +#define STR_METHOD_H + +#include +#include +#include +#include +#include + +#include "nlohmann/json.hpp" + +namespace StrSpace { + using ordered_json = nlohmann::ordered_json; + + std::string GetEnvVar(const char *varName); + std::vector Split(const std::string &inputString, const std::string &delimiter); + uint64_t Str2Int(const char *inputString, const uint64_t &defaultValue, const char *tag); + std::vector SplitToInt(const std::string &inputString, const std::string &delimiter); + std::vector + Str2IntForVector(const std::vector &strVec, const int &defaultValue, const char *tag); + bool IsValueInGoal(const char *varName, const uint64_t &query); + bool IsValueInGoal(const char *varName, const std::string &query); + + bool IsPrefix(const std::string &str, const std::string &prefix); + bool IsSuffix(const std::string &str, const std::string &suffix); + std::string ToLower(const std::string &input); + bool IsStringLengthSafety(const std::string &ss); + bool IsStringLengthSafety(const std::vector &vec); + ordered_json Str2Json(const std::string &value); + std::unordered_map + Str2Map(const std::string &line, const std::string &delimiter, const std::string joint); + + template std::string Join(const Container &container, const std::string &delimiter) { + std::stringstream ss; + auto it = container.begin(); + if (it != container.end()) { + ss << *it; + ++it; + } + for (; it != container.end(); ++it) { + ss << delimiter; // 确保 delimiter 是 std::string + ss << *it; + } + return ss.str(); + } +} // namespace StrSpace + +#endif 
diff --git a/accuracy_tools/msprobe/utils/__init__.py b/accuracy_tools/msprobe/utils/__init__.py new file mode 100644 index 00000000000..53529bc8d31 --- /dev/null +++ b/accuracy_tools/msprobe/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accuracy_tools/msprobe/utils/constants.py b/accuracy_tools/msprobe/utils/constants.py new file mode 100644 index 00000000000..6872a8999b5 --- /dev/null +++ b/accuracy_tools/msprobe/utils/constants.py @@ -0,0 +1,201 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class CmdConst:
    """
    Class for command line const
    """

    DUMP = "dump"
    COMPARE = "compare"

    HELP_SERVICE_MAP = {DUMP: "Data collection for Ascend device.", COMPARE: "Accuracy compare for dump task."}
    HELP_TASK_MAP = {}


class PathConst:
    """
    Class for file or dir path const
    """

    FILE = "file"
    DIR = "dir"

    SIZE_20M = 20_971_520  # 20 * 1024 * 1024
    SIZE_500M = 524_288_000  # 500 * 1024 * 1024
    SIZE_2G = 2_147_483_648  # 2 * 1024 * 1024 * 1024
    SIZE_4G = 4_294_967_296  # 4 * 1024 * 1024 * 1024
    SIZE_10G = 10_737_418_240  # 10 * 1024 * 1024 * 1024
    SIZE_30G = 32_212_254_720  # 30 * 1024 * 1024 * 1024
    SIZE_50G = 53_687_091_200  # 50 * 1024 * 1024 * 1024

    SUFFIX_ONLINE_SCRIPT = (".py", ".sh")
    SUFFIX_OFFLINE_MODEL = (".pb", ".onnx", ".om", ".prototxt")


class MsgConst:
    """
    Class for log messages const
    """

    INVALID_ARGU = "[ERROR] invalid argument."
    INVALID_DATA_TYPE = "[ERROR] invalid data type."
    REQUIRED_ARGU_MISSING = "[ERROR] Required argument missing."
    RISK_ALERT = "[ERROR] Risk alert."
    NO_PERMISSION = "[ERROR] No permission."
    IO_FAILURE = "[ERROR] I/O failure."
    PATH_NOT_FOUND = "[ERROR] Path not found."
    VALUE_NOT_FOUND = "[ERROR] Value not found."
    PARSING_FAILED = "[ERROR] Parsing failed."
    CANN_FAILED = "[ERROR] CANN enabling failed."
    ATTRIBUTE_ERROR = "[ERROR] Attribute not found."
    CALL_FAILED = "[ERROR] Call failed."
    CONVERSION_FAILED = "[ERROR] Conversion failed."
    MAX_RECURSION_DEPTH = 5


class CompConst:
    """
    Class for component name const
    """

    DUMP_WRITER_COMP = "DumpWriterComp"
    ACL_DUMPER_COMP = "ACLDumperComp"
    ONNX_ACTUATOR_COMP = "OnnxActuatorComp"
    ONNX_DUMPER_COMP = "OnnxDumperComp"
    FROZEN_GRAPH_ACTUATOR_COMP_CPU = "FrozenGraphActuatorCompCPU"
    FROZEN_GRAPH_DUMPER_COMP_CPU = "FrozenGraphDumperCompCPU"
    FROZEN_GRAPH_ACTUATOR_COMP_NPU = "FrozenGraphActuatorCompNPU"
    FROZEN_GRAPH_SET_GE_COMP_NPU = "FrozenGraphSetGECompNPU"
    CAFFE_ACTUATOR_COMP = "CaffeActuatorComp"
    CAFFE_DUMPER_COMP = "CaffeDumperComp"
    ATB_ACTUATOR_COMP = "ATBActuatorComp"
    OM_ACTUATOR_COMP = "OmActuatorComp"
    ACL_COMPATIBLE_COMP = "ACLCompatibleComp"


class CfgConst:
    """
    Class for config items
    """

    CONFIG_PATH = "config_path"
    TASK = "task"
    TASK_STAT = "statistics"
    TASK_TENSOR = "tensor"
    ALL_TASK = {TASK_STAT, TASK_TENSOR}
    EXEC = "exec"
    FRAMEWORK = "framework"
    FRAMEWORK_MINDIE_LLM = "mindie_llm"
    FRAMEWORK_TORCH_AIR = "torch_air"
    FRAMEWORK_MINDIE_TORCH = "mindie_torch"
    FRAMEWORK_PT = "pytorch"
    FRAMEWORK_MS = "mindspore"
    FRAMEWORK_ONNX = "ONNX"
    FRAMEWORK_TF = "TensorFlow"
    FRAMEWORK_OM = "Ascend OM"
    FRAMEWORK_CAFFE = "Caffe"
    ALL_FRAMEWORK = {FRAMEWORK_MINDIE_LLM, FRAMEWORK_TORCH_AIR, FRAMEWORK_MINDIE_TORCH, FRAMEWORK_PT, FRAMEWORK_MS}
    RANK = "rank"
    STEP = "step"
    LEVEL = "level"
    LEVEL_MODULE = "L0"
    LEVEL_API = "L1"
    LEVEL_KERNEL = "L2"
    ALL_LEVEL = {LEVEL_MODULE, LEVEL_API, LEVEL_KERNEL}
    LOG_LEVEL = "log_level"
    SEED = "seed"
    BUFFER_SIZE = "buffer_size"


class DumpConst:
    """
    Class for dump const
    """

    DEVICE = "device"
    INPUT_ARGS = "input_args"
    OUTPUT_ARGS = "output_args"
    INPUT = "input"
    OUTPUT = "output"
    INPUT_ALL = [INPUT, "all"]
    OUTPUT_ALL = [OUTPUT, "all"]
    ALL_DATA_MODE = [INPUT, OUTPUT, "all"]

    DUMP_PATH = "dump_path"
    LIST = "list"
    DATA_MODE = "data_mode"
    SUMMARY_MODE = "summary_mode"
    SUMMARY_MD5 = "md5"
    ALL_SUMMARY_MODE = {CfgConst.TASK_STAT, SUMMARY_MD5}
    DUMP_EXTRA = "dump_extra"
    ALL_DUMP_EXTRA = {"tiling", "cpu_profiling", "kernel_info", "op_info"}
    OP_ID = "op_id"
    DUMP_LAST_LOGITS = "dump_last_logits"
    DUMP_WEIGHT = "dump_weight"
    DUMP_GE_GRAPH = "dump_ge_graph"
    ALL_DUMP_GE_GRAPH = {"1", "2", "3"}
    DUMP_GRAPH_LEVEL = "dump_graph_level"
    ALL_DUMP_GRAPH_LEVEL = {"1", "2", "3", "4"}
    FUSION_SWITCH_FILE = "fusion_switch_file"
    # NOTE(review): mixed-case name kept as-is for backward compatibility with
    # existing callers; consider adding an ONNX_FUSION_SWITCH alias later.
    ONNX_FUSION_switch = "onnx_fusion_switch"
    SAVED_MODEL_TAG = "saved_model_tag"
    SAVED_MODEL_SIGN = "saved_model_signature"
    WEIGHT_PATH = "weight_path"

    DUMP_DATA_DIR = "dump_data_dir"
    DATA = "data"
    DUMP_JSON = "dump.json"
    STACK_JSON = "stack.json"
    NPY_FORMAT = "npy_format"
    BIN_FORMAT = "bin_format"
    NET_OUTPUT_NODES_JSON = "net_output_nodes.json"

    ENVVAR_DUMP_GE_GRAPH = "DUMP_GE_GRAPH"
    ENVVAR_DUMP_GRAPH_LEVEL = "DUMP_GRAPH_LEVEL"
    ENVVAR_DUMP_GRAPH_PATH = "DUMP_GRAPH_PATH"
    ENVVAR_ASCEND_WORK_PATH = "ASCEND_WORK_PATH"

    ENVVAR_LINK_DUMP_PATH = "LINK_DUMP_PATH"
    ENVVAR_LINK_DUMP_TASK = "LINK_DUMP_TASK"
    ENVVAR_LINK_DUMP_LEVEL = "LINK_DUMP_LEVEL"
    ENVVAR_LINK_STEP = "LINK_STEP"
    ENVVAR_LINK_RANK = "LINK_RANK"
    ENVVAR_LINK_LOG_LEVEL = "LINK_LOG_LEVEL"
    ENVVAR_LINK_SUMMARY_MODE = "LINK_SUMMARY_MODE"
    ENVVAR_LINK_BUFFER_SIZE = "LINK_BUFFER_SIZE"
    ENVVAR_LINK_DATA_MODE = "LINK_DATA_MODE"
    ENVVAR_LINK_SAVE_TILING = "LINK_SAVE_TILING"
    ENVVAR_LINK_SAVE_CPU_PROFILING = "LINK_SAVE_CPU_PROFILING"
    ENVVAR_LINK_SAVE_ONNX = "LINK_SAVE_ONNX"
    ENVVAR_LINK_SAVE_KERNEL_INFO = "LINK_SAVE_KERNEL_INFO"
    ENVVAR_LINK_SAVE_OP_INFO = "LINK_SAVE_OP_INFO"
    ENVVAR_LINK_SAVE_PARAM = "LINK_SAVE_PARAM"
    ENVVAR_LINK_SAVE_TENSOR_IDS = "LINK_SAVE_TENSOR_IDS"
    ENVVAR_LINK_SAVE_TENSOR_RUNNER = "LINK_SAVE_TENSOR_RUNNER"


class ACLConst:
    """
    Class for Ascendcl const
    """

    SUCCESS = 0
    MEMCPY_HOST_TO_DEVICE = 1
    MEMCPY_DEVICE_TO_HOST = 2
    IS_LAST_CHUNK = "is_last_chunk"
    BUF_LEN = "buf_len"
    FILE_NAME = "file_name"
    DATA_BUF = "data_buf"
# diff --git a/accuracy_tools/msprobe/utils/dependencies.py b/accuracy_tools/msprobe/utils/dependencies.py
# new file mode 100644
# index 00000000000..1e790a8ec09
# --- /dev/null
# +++ b/accuracy_tools/msprobe/utils/dependencies.py
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import wraps
from importlib import import_module

from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger

# Package names we have already warned about, so each missing dependency is
# reported at most once per process.
import_warnings_shown = set()


def safely_import(func):
    """Decorator: swallow import failures, warn once per dependency, return None.

    Assumes the wrapped function is a method whose second positional argument
    (args[1]) is the dependency/package name being imported.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            dependency = args[1]
            if dependency not in import_warnings_shown:
                logger.warning(f"{dependency} is not installed. Please install it if needed.")
                import_warnings_shown.add(dependency)
            return None

    return wrapper


def temporary_tf_log_level(func):
    """Decorator: run *func* with TF_CPP_MIN_LOG_LEVEL=2 (warnings/errors only).

    Fixes vs. original: restores the previous value in a ``finally`` block so
    the environment is not left at level "2" when *func* raises, and removes
    the variable again (instead of forcing it to "0") when it was unset before.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        original_log_level = os.environ.get("TF_CPP_MIN_LOG_LEVEL")
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # print only warnings and errors
        try:
            return func(*args, **kwargs)
        finally:
            if original_log_level is None:
                os.environ.pop("TF_CPP_MIN_LOG_LEVEL", None)
            else:
                os.environ["TF_CPP_MIN_LOG_LEVEL"] = original_log_level

    return wrapper


class DependencyManager:
    """Process-wide singleton that lazily imports and caches optional packages."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(DependencyManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard so re-instantiating the singleton does not wipe the cache
        # (the original reset `_dependencies` on every DependencyManager() call).
        if not hasattr(self, "_dependencies"):
            self._dependencies = {}

    def get(self, package_name):
        """Return the cached module for *package_name*, importing it on first use."""
        # Check the cache explicitly; the original passed
        # `self._import_package(package_name)` as the dict.get() default, which
        # is evaluated eagerly and therefore ran the import path on every call.
        if package_name in self._dependencies:
            return self._dependencies[package_name]
        return self._import_package(package_name)

    def get_tensorflow(self):
        """Return (tensorflow module, RewriterConfig, convert_variables_to_constants)."""
        tf = self.get("tensorflow")
        re_writer_config = self.get("tensorflow/RewriterConfig")
        sm2pb = self.get("tensorflow/convert_variables_to_constants")
        return tf, re_writer_config, sm2pb

    @safely_import
    def _import_package(self, package_name):
        # Import one package, caching it; TensorFlow needs special handling.
        if package_name in self._dependencies:
            return self._dependencies[package_name]
        if package_name == "tensorflow":
            return self._import_tensorflow()
        module = import_module(package_name)
        self._dependencies[package_name] = module
        return module

    @temporary_tf_log_level
    def _import_tensorflow(self):
        # Import TensorFlow plus the two internal helpers the dumpers need,
        # pinning the only supported version.
        module = import_module("tensorflow")
        if module.__version__ != "2.6.5":
            raise MsprobeException("[ERROR] Incompatible versions. Currently only supports TensorFlow v2.6.5.")
        from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
        from tensorflow.python.framework.graph_util import convert_variables_to_constants

        self._dependencies["tensorflow/convert_variables_to_constants"] = convert_variables_to_constants
        self._dependencies["tensorflow/RewriterConfig"] = RewriterConfig
        self._dependencies["tensorflow"] = module
        return module


dependent = DependencyManager()
# diff --git a/accuracy_tools/msprobe/utils/env.py b/accuracy_tools/msprobe/utils/env.py
# new file mode 100644
# index 00000000000..a74859c4c3a
# --- /dev/null
# +++ b/accuracy_tools/msprobe/utils/env.py
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from msprobe.utils.constants import MsgConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger


class EnvVarManager:
    """Singleton wrapper around os.environ adding logging, casting and a prefix filter."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(EnvVarManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard so re-instantiating the singleton does not reset a prefix that
        # was configured earlier via set_prefix().
        if not hasattr(self, "prefix"):
            self.prefix = ""

    @staticmethod
    def _log(msg):
        # All accesses are traced at debug level only.
        logger.debug(msg)

    def set_prefix(self, prefix):
        """Set the prefix used by list_all() to filter environment variables."""
        self.prefix = prefix

    def get(self, key, default=None, cast_type=None, required=True):
        """Read *key* from the environment.

        Raises MsprobeException when *required* is True and the variable is
        unset, or when casting the value with *cast_type* fails.
        """
        value = os.environ.get(key, default)
        self._log(f"Accessed environment variable {key}, Value: {value}.")
        if required and value is None:
            raise MsprobeException(
                MsgConst.REQUIRED_ARGU_MISSING,
                f"Environment variable {key} is required but not set. "
                f"Please check the current environment configuration by `echo ${key}`.",
            )
        if value is not None and cast_type:
            try:
                value = cast_type(value)
                self._log(f"Casted {key} to {cast_type.__name__}, Result: {value}.")
            except Exception as e:
                raise MsprobeException(
                    MsgConst.INVALID_DATA_TYPE, f"Failed to cast environment variable {key} to {cast_type}."
                ) from e
        return value

    def set(self, key, value):
        """Set *key* to str(value) in the environment."""
        os.environ[key] = str(value)
        self._log(f"Set environment variable {key} to {value}.")

    def delete(self, key):
        """Delete *key* from the environment if present; log either way."""
        if key in os.environ:
            # Membership was just checked, so a plain del suffices (the original
            # redundantly combined the check with pop(key, None)).
            del os.environ[key]
            self._log(f"Deleted environment variable {key}.")
        else:
            self._log(f"{key} not found to delete.")

    def list_all(self):
        """Return a dict of environment variables, filtered by prefix when one is set."""
        if self.prefix:
            filtered_env = {k: v for k, v in os.environ.items() if k.startswith(self.prefix)}
            self._log(f"Listed environment variables with prefix {self.prefix}: {filtered_env}.")
            return filtered_env
        self._log(f"Listed all environment variables: {dict(os.environ)}.")
        return dict(os.environ)


evars = EnvVarManager()
# diff --git a/accuracy_tools/msprobe/utils/exceptions.py b/accuracy_tools/msprobe/utils/exceptions.py
# new file mode 100644
# index 00000000000..55c51b0a794
# --- /dev/null
# +++ b/accuracy_tools/msprobe/utils/exceptions.py
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ + +class MsprobeException(Exception, object): + def __init__(self, error_group, error_msg=""): + super().__init__() + self.error_msg = " ".join([error_group, error_msg]) + + def __str__(self): + return self.error_msg diff --git a/accuracy_tools/msprobe/utils/hijack.py b/accuracy_tools/msprobe/utils/hijack.py new file mode 100644 index 00000000000..91d31d8b842 --- /dev/null +++ b/accuracy_tools/msprobe/utils/hijack.py @@ -0,0 +1,405 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import Enum +from importlib.abc import Loader, MetaPathFinder +from importlib.util import module_from_spec, spec_from_loader +from uuid import uuid4 + +from msprobe.utils.constants import MsgConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.toolkits import check_int_border + + +class ActionType(Enum): + REPLACE = 0 + PRE_HOOK = 1 + POST_HOOK = 2 + + +class HijackHandler: + def __init__(self, unit): + self.unit = unit + self.call_count = 0 + self.call_data = defaultdict(dict) + self.released = False + + +def hijacker( + *, + stub: callable, + module: str, + cls: str = "", + function: str = "", + action: ActionType = ActionType.REPLACE, + priority: int = 100, +) -> str: + """ + Hijack module-import process or function execution process. 
def hijacker(
    *,
    stub: callable,
    module: str,
    cls: str = "",
    function: str = "",
    action: ActionType = ActionType.REPLACE,
    priority: int = 100,
) -> HijackHandler:
    """Hijack a module-import process or a function execution process.

    Target resolution:
      * only ``module`` set                  -> the module itself
      * ``module`` + ``function``            -> a module-level function
      * ``module`` + ``cls`` + ``function``  -> a method of a class

    Stub signature by target/action:
      * module,   PRE_HOOK : ``stub()`` runs before the module is imported
      * module,   POST_HOOK: ``stub(m)`` runs after import; ``m`` is the module
      * function, REPLACE  : ``ret = stub(*args, **kws)`` replaces the original
      * function, PRE_HOOK : ``args, kws = stub(*args, **kws)`` rewrites inputs
      * function, POST_HOOK: ``ret = stub(ret, *args, **kws)`` rewrites return

    .. warning::
        A module PRE_HOOK only takes effect if registered before the module is
        imported. If the module is modified inside its POST_HOOK, the change is
        not undone when the hijack is released.

    Parameters
    ----------
    stub : callable
        Hook or replacement, following the signature table above.
    module : str
        Full name of the target module.
    cls : str, optional
        Full dotted name of the target class.
    function : str, optional
        Name of the target function.
    action : ActionType, optional
        One of REPLACE, PRE_HOOK, POST_HOOK.
    priority : int, optional
        The smaller the value, the higher the priority; multiple hooks on the
        same target are executed in priority order.

    Returns
    -------
    HijackHandler
        Handle exposing ``unit``, ``call_count``, ``call_data`` and ``released``.
    """
    # Install the custom meta-path finder exactly once.
    HiJackerManager.initialize()
    unit = HijackerUnit(stub, module, cls, function, action, priority)
    handler = HijackHandler(unit)
    unit.handler = handler
    # Bug fix: HiJackerManager.remove_unit() is keyed by the uuid string that
    # add_unit() returns; the original discarded it, so release() could never
    # actually unregister a unit. Keep the key on the unit for release().
    unit.manager_key = HiJackerManager.add_unit(unit)
    return handler


def release(handler):
    """Cancel a hijacking. *handler* is the object returned by :func:`hijacker`.

    :raises MsprobeException: when *handler* is not a HijackHandler.
    """
    if isinstance(handler, HijackHandler):
        handler.released = True
        # Pass the manager's uuid key, not the unit object (see hijacker()).
        HiJackerManager.remove_unit(getattr(handler.unit, "manager_key", None))
    else:
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Handler must be an instance of HijackHandler.")
not self.function: + raise MsprobeException(MsgConst.REQUIRED_ARGU_MISSING, '"function" should be used when "cls" used.') + if self.function and not isinstance(self.function, str): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"function" should be a str.') + if not isinstance(self.action, ActionType): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"action" should be an ActionType.') + if not self.cls and not self.function and self.action == ActionType.REPLACE: + raise MsprobeException(MsgConst.INVALID_ARGU, "replacement of a module is not supported") + check_int_border(self.priority, tag="priority of HijackerUnit") + + +class HiJackerWrapperObj(ABC): + def __init__(self, name): + self.name = name + self.ori_obj = None + self.replacement = [] + self.pre_hooks = [] + self.post_hooks = [] + self.mod_name, self.class_name, self.func_name = name.split("-") + + @property + def is_empty(self): + return not self.replacement and not self.pre_hooks and not self.post_hooks + + @abstractmethod + def activate(self): + pass + + @abstractmethod + def deactivate(self): + pass + + def add_unit(self, unit): + if unit.action == ActionType.REPLACE: + self.replacement.append(unit) + self.replacement.sort(key=lambda x: x.priority) + elif unit.action == ActionType.PRE_HOOK: + self.pre_hooks.append(unit) + self.pre_hooks.sort(key=lambda x: x.priority) + else: + self.post_hooks.append(unit) + self.post_hooks.sort(key=lambda x: x.priority) + + def remove_unit(self, unit): + if unit.action == ActionType.REPLACE: + self.replacement.remove(unit) + elif unit.action == ActionType.PRE_HOOK: + self.pre_hooks.remove(unit) + else: + self.post_hooks.remove(unit) + + def set_ori_obj(self, obj): + self.ori_obj = obj + + +class HiJackerWrapperModule(HiJackerWrapperObj): + def __init__(self, name): + super().__init__(name) + + def exec_pre_hook(self): + for unit in self.pre_hooks: + unit.stub() + + def exec_post_hook(self, m): + self.set_ori_obj(m) + for unit in self.post_hooks: + 
class HiJackerWrapperFunction(HiJackerWrapperObj):
    """Hijack wrapper for a single function (or class method) inside a module.

    Instead of patching eagerly, activate() registers a module-level POST_HOOK
    (through ``hijacker``) whose stub rebinds the target attribute to a
    generated wrapper once the module object is available.
    """

    def __init__(self, name):
        super().__init__(name)
        # Handle of the auxiliary module-level hijack; kept so deactivate() can release it.
        self.mod_hijacker = None

    def activate(self):
        def replace_closure(class_name, func_name, wrapper):
            # The returned stub runs after the module is imported and swaps in the wrapper.
            def modify_module(m):
                parent_obj = m
                # Walk a dotted class path (e.g. "Outer.Inner") down from the module object.
                class_chain = class_name.split(".") if class_name else []
                for c in class_chain:
                    if not hasattr(parent_obj, c):
                        # Target class path not resolvable; silently do nothing.
                        return
                    parent_obj = getattr(parent_obj, c)
                if parent_obj and hasattr(parent_obj, func_name):
                    ori_obj = getattr(parent_obj, func_name)
                    # Remember the original so the wrapper can call it and deactivate() can restore it.
                    self.set_ori_obj(ori_obj)
                    setattr(parent_obj, func_name, wrapper)
                return

            return modify_module

        # priority=0: install the wrapper before any user-registered module post-hooks.
        self.mod_hijacker = hijacker(
            stub=replace_closure(self.class_name, self.func_name, self._get_wrapper()),
            module=self.mod_name,
            action=ActionType.POST_HOOK,
            priority=0,
        )
        return

    def deactivate(self):
        # Undo the auxiliary module hijack first.
        if self.mod_hijacker:
            release(self.mod_hijacker)
            self.mod_hijacker = None
        mod = sys.modules.get(self.mod_name)
        if mod and self.ori_obj:
            parent_obj = mod
            class_chain = self.class_name.split(".") if self.class_name else []
            for c in class_chain:
                if not hasattr(parent_obj, c):
                    # Class path no longer resolvable; drop the stored original and give up.
                    self.ori_obj = None
                    return
                parent_obj = getattr(parent_obj, c)
            if parent_obj and hasattr(parent_obj, self.func_name):
                # Restore the original attribute on the module/class.
                setattr(parent_obj, self.func_name, self.ori_obj)
            self.ori_obj = None
        return

    def _get_wrapper(self):
        # Build the callable that replaces the hijacked target.
        def wrapper(*args, **kws):
            if not self.ori_obj:
                raise MsprobeException(
                    MsgConst.VALUE_NOT_FOUND,
                    "Original function object not found. Ensure activate() was called successfully.",
                )
            call_index = None
            # Record the raw call arguments for every unit that exposes a handler.
            for unit in self.pre_hooks + self.replacement + self.post_hooks:
                if unit.handler:
                    unit.handler.call_count += 1
                    call_index = unit.handler.call_count
                    unit.handler.call_data[call_index] = {"args": args, "kwargs": kws}
            # Pre-hooks may rewrite the inputs; each must return (args, kws).
            for unit in self.pre_hooks:
                result = unit.stub(*args, **kws)
                if isinstance(result, tuple):
                    args, kws = result
                else:
                    raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Pre-hook must return a tuple of (args, kws)")
            # Highest-priority replacement wins; otherwise call the original.
            f = self.replacement[0].stub if self.replacement else self.ori_obj
            ret = f(*args, **kws)
            # Post-hooks may rewrite the return value, in priority order.
            for unit in self.post_hooks:
                ret = unit.stub(ret, *args, **kws)
            if call_index:
                for unit in self.pre_hooks + self.replacement + self.post_hooks:
                    if unit.handler:
                        unit.handler.call_data[call_index]["return"] = ret
            return ret

        return wrapper
f: + return HiJackerWrapperFunction(name) + else: + return HiJackerWrapperModule(name) + + +class HiJackerPathFinder(MetaPathFinder): + _modules_of_insterest = set() + + @classmethod + def add_mod(cls, name): + cls._modules_of_insterest.add(name) + + @classmethod + def remove_mod(cls, name): + cls._modules_of_insterest.discard(name) + + def find_spec(self, fullname, path, target=None): + if fullname not in self._modules_of_insterest: + return None + for finder in sys.meta_path: + if isinstance(finder, HiJackerPathFinder): + continue + spec = finder.find_spec(fullname, path, target) + if not spec: + continue + return spec_from_loader(fullname, HiJackerLoader(spec)) + return None + + def find_module(self, fullname, path=None): + if fullname not in self._modules_of_insterest: + return None + for finder in sys.meta_path: + if isinstance(finder, HiJackerPathFinder): + continue + loader = finder.find_module(fullname, path) + if not loader: + continue + return HiJackerLoader(spec_from_loader(fullname, loader)) + return None + + +class HiJackerLoader(Loader): + def __init__(self, ori_spec): + self.ori_spec = ori_spec + + def create_module(self, spec): + module = module_from_spec(self.ori_spec) + return module + + def load_module(self, fullname): + module = self.ori_spec.loader.load_module(fullname) + return module + + def exec_module(self, module): + wrapper = HiJackerManager.get_module_wrapper(module.__name__) + if wrapper: + wrapper.exec_pre_hook() + self.ori_spec.loader.exec_module(module) + if wrapper: + wrapper.exec_post_hook(module) diff --git a/accuracy_tools/msprobe/utils/io.py b/accuracy_tools/msprobe/utils/io.py new file mode 100644 index 00000000000..77533a5ecc6 --- /dev/null +++ b/accuracy_tools/msprobe/utils/io.py @@ -0,0 +1,347 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import json +import pickle +from functools import wraps + +import numpy as np +import pandas as pd +import yaml + +from msprobe.utils.constants import MsgConst, PathConst +from msprobe.utils.dependencies import dependent +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.path import ( + AUTHORITY_DIR, + AUTHORITY_FILE, + SafePath, + change_permission, + get_basename_from_path, + get_file_size, + join_path, +) +from msprobe.utils.toolkits import CsvCheckLevel, is_input_yes, sanitize_csv_value + +_LOAD_ERROR = 'Failed to load the path "{}" using <{}>.' +_SAVE_ERROR = 'Failed to save {} to "{}" using <{}>. Please check permissions or disk space.' 
+ + +class SafelyOpen: + def __init__(self, file_path, mode, file_size_limitation=None, suffix=None, path_exist=True, encoding="utf-8"): + self.file_path = SafePath(file_path, PathConst.FILE, mode, file_size_limitation, suffix).check( + path_exist=path_exist + ) + self.mode = mode + self.encoding = encoding + self._file = None + + def __enter__(self): + if "b" not in self.mode: + self._file = open(self.file_path, self.mode, encoding=self.encoding) + else: + self._file = open(self.file_path, self.mode) + return self._file + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + if self._file and not self._file.closed: + self._file.close() + + +def _load_file(mode, file_size, file_suffix, use_safely_open: bool, encoding="utf-8"): + def decorator(func): + @wraps(func) + def wrapper(path, *args, **kwargs): + try: + if use_safely_open: + with SafelyOpen(path, mode, file_size, file_suffix, encoding) as f: + return func(f) + else: + path = SafePath(path, PathConst.FILE, mode, file_size, file_suffix).check() + return func(path, *args, **kwargs) + except Exception as e: + raise MsprobeException(MsgConst.IO_FAILURE, _LOAD_ERROR.format(path, func.__name__)) from e + + return wrapper + + return decorator + + +def _load_dir(dir_size): + def decorator(func): + @wraps(func) + def wrapper(path, *args, **kwargs): + path = SafePath(path, PathConst.DIR, "r", dir_size).check() + try: + return func(path, *args, **kwargs) + except Exception as e: + raise MsprobeException(MsgConst.IO_FAILURE, _LOAD_ERROR.format(path, func.__name__)) from e + + return wrapper + + return decorator + + +def _save_file(mode, file_size, file_suffix, use_safely_open: bool): + def decorator(func): + @wraps(func) + def wrapper(data, path, *args, **kwargs): + try: + if use_safely_open: + with SafelyOpen(path, mode, file_size, file_suffix, path_exist=False) as f: + func(data, f, *args, **kwargs) + else: + path = SafePath(path, PathConst.FILE, mode, file_size, 
file_suffix).check(path_exist=False) + func(data, path, *args, **kwargs) + except Exception as e: + raise MsprobeException( + MsgConst.IO_FAILURE, _SAVE_ERROR.format(data.__class__.__name__, path, func.__name__) + ) from e + change_permission(path, AUTHORITY_FILE) + + return wrapper + + return decorator + + +def _save_dir(dir_size): + def decorator(func): + @wraps(func) + def wrapper(data, path, *args, **kwargs): + path = SafePath(path, PathConst.DIR, "w", dir_size).check(path_exist=False) + try: + func(data, path, *args, **kwargs) + except Exception as e: + raise MsprobeException( + MsgConst.IO_FAILURE, _SAVE_ERROR.format(data.__class__.__name__, path, func.__name__) + ) from e + change_permission(path, AUTHORITY_DIR) + + return wrapper + + return decorator + + +@_load_file("r", PathConst.SIZE_30G, ".onnx", use_safely_open=False) +def load_onnx_model(model_path): + onnx = dependent.get("onnx") + return onnx.load_model(model_path) + + +@_load_file("r", PathConst.SIZE_30G, ".onnx", use_safely_open=False) +def load_onnx_session(model_path, onnx_fusion_switch=True, provider="CPUExecutionProvider"): + ort = dependent.get("onnxruntime") + options = ort.SessionOptions() + if not onnx_fusion_switch: + options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + return ort.InferenceSession(model_path, sess_options=options, providers=[provider]) + + +@_load_file("r", PathConst.SIZE_30G, ".om", use_safely_open=False) +def load_om_model(model_path): + cmsprobe = dependent.get("msprobe.lib.msprobe_c") + model_id, ret = cmsprobe.acl.load_from_file(model_path) + if ret != 0: + raise MsprobeException(MsgConst.IO_FAILURE, f"Load model: {model_path} failed! 
ErrorCode = {ret}.") + logger.info(f"Load model: {model_path} success!") + return model_id + + +@_save_file("w", None, ".onnx", use_safely_open=False) +def save_onnx_model(onnx_model, save_path): + onnx = dependent.get("onnx") + model_size = onnx_model.ByteSize() + save_external_flag = model_size > PathConst.SIZE_2G + onnx.save_model(onnx_model, save_path, save_as_external_data=save_external_flag) + + +@_load_file("r", PathConst.SIZE_30G, ".prototxt", use_safely_open=False) +def load_caffe_model(model_path, weight_path): + caffe = dependent.get("caffe") + if caffe: + caffe.set_mode_cpu() + return caffe.Net(model_path, weight_path, caffe.TEST) + return None + + +@_load_file("r", PathConst.SIZE_10G, ".npy", use_safely_open=False) +def load_npy(npy_path): + return np.load(npy_path, allow_pickle=False) + + +def load_npy_from_buffer(raw_data, dtype, shape): + try: + return np.frombuffer(raw_data, dtype=dtype).reshape(shape) + except Exception as e: + raise MsprobeException(MsgConst.IO_FAILURE, "Failed to load npy data from buffer.") from e + + +@_save_file("w", None, ".npy", use_safely_open=False) +def save_npy(npy_data, save_path): + np.save(save_path, npy_data) + + +@_save_file("wb", None, ".bin", use_safely_open=False) +def save_bin_from_ndarray(numpy_data: np.ndarray, save_path): + numpy_data.tofile(save_path) + + +@_save_file("wb", None, ".bin", use_safely_open=True) +def save_bin_from_bytes(bytes_data, f): + f.write(bytes_data) + + +@_load_file("r", PathConst.SIZE_10G, ".bin", use_safely_open=False) +def load_bin_data(bin_path, dtype=np.float16, shape=None, is_byte_data=False): + if is_byte_data: + return np.fromfile(bin_path, dtype=np.int8) + if dtype == np.float32 and get_file_size(bin_path) == np.prod(shape) * 2: + return np.fromfile(bin_path, dtype=np.float16).astype(np.float32) + else: + return np.fromfile(bin_path, dtype=dtype) + + +@_load_dir(PathConst.SIZE_30G) +def load_saved_model(model_path, tag): + pons = dependent.get_tensorflow() + if None not in 
pons: + tf, _, _ = pons + tf.compat.v1.reset_default_graph() + graph = tf.compat.v1.Graph() + sess = tf.compat.v1.Session(graph=graph) + saved_model = tf.compat.v1.saved_model.loader.load(sess, set(tag), model_path) + return saved_model, sess + return None, None + + +@_load_file("rb", PathConst.SIZE_30G, ".pb", use_safely_open=False) +def load_pb_frozen_graph_model(model_path): + pons = dependent.get_tensorflow() + if None not in pons: + tf, _, _ = pons + data = tf.compat.v1.gfile.GFile(model_path, "rb").read() + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(data) + tf.compat.v1.import_graph_def(graph_def, name="") + return graph_def + return None + + +@_save_file("wb", PathConst.SIZE_30G, ".pb", use_safely_open=False) +def save_pb_frozen_graph_model(frozen_graph, model_path): + pons = dependent.get_tensorflow() + if None not in pons: + tf, _, _ = pons + with tf.io.gfile.GFile(model_path, "wb") as f: + f.write(frozen_graph) + + +def savedmodel2pb(model_path, tag, serve, pb_save_dir): + """ + Converts a TensorFlow 1.x SavedModel to a frozen PB file. 
+ + :param model_path: Path to the saved TensorFlow SavedModel directory + :param tag: Tag used for loading the model + :param serve: Signature key (e.g., "serving_default") + :param pb_save_dir: Directory to save the PB file + :return: Path to the converted PB file and net output nodes + """ + pons = dependent.get_tensorflow() + if None not in pons: + _, _, sm2pb = pons + meta_graph_def, sess = load_saved_model(model_path, tag) + signature_def = meta_graph_def.signature_def.get(serve) + if signature_def is None: + raise MsprobeException(MsgConst.VALUE_NOT_FOUND, f'Signature "{serve}" not found in the model.') + input_tensor_names = [t.name for t in signature_def.inputs.values()] + output_tensor_names = [t.name for t in signature_def.outputs.values()] + logger.info(f"Saved model input tensors: {input_tensor_names}.") + logger.info(f"Saved model output tensors: {output_tensor_names}.") + output_node_names = [t.split(":")[0] for t in output_tensor_names] + frozen_graph_def = sm2pb(sess, sess.graph.as_graph_def(), output_node_names) + pb_file_name = get_basename_from_path(model_path) + ".pb" + pb_file_path = join_path(pb_save_dir, pb_file_name) + save_pb_frozen_graph_model(frozen_graph_def.SerializeToString(), pb_file_path) + sess.close() + logger.info(f"SavedModel has been successfully converted to a frozen PB file at {pb_file_path}.") + return pb_file_path + return "" + + +@_load_file("r", PathConst.SIZE_500M, ".yaml", use_safely_open=True) +def load_yaml(f): + return yaml.safe_load(f) + + +@_save_file("w", None, ".yaml", use_safely_open=True) +def save_yaml(yaml_data, f): + yaml.dump(yaml_data, f) + + +@_load_file("r", PathConst.SIZE_2G, ".json", use_safely_open=True) +def load_json(f): + return json.load(f) + + +@_save_file("w", None, ".json", use_safely_open=True) +def save_json(json_data, f, indent: int = None): + json.dump(json_data, f, indent=indent, default=str) + + +@_load_file("r", PathConst.SIZE_500M, ".csv", use_safely_open=True, encoding="utf-8-sig") 
+def load_csv_by_builtin(f, sep=",", check=CsvCheckLevel.STRICT): + csv_reader = csv.reader(f, delimiter=sep) + sanitized_rows = [] + for row in csv_reader: + sanitized_row = [sanitize_csv_value(value, check) for value in row] + sanitized_rows.append(sanitized_row) + return sanitized_rows + + +@_load_file("r", PathConst.SIZE_500M, ".csv", use_safely_open=False) +def load_csv_by_pandas(csv_path, sep=",", check=CsvCheckLevel.STRICT): + df = pd.read_csv(csv_path, sep=sep, dtype=str) + df = df.applymap(lambda value: sanitize_csv_value(value, check)) + return df + + +@_save_file("w", None, ".csv", use_safely_open=False) +def save_csv_by_pandas(csv_data: pd.DataFrame, csv_path, sep=",", check=CsvCheckLevel.STRICT): + sanitized_data = csv_data.applymap(lambda value: sanitize_csv_value(value, check)) + sanitized_data.to_csv(csv_path, sep=sep, index=False) + + +@_load_file("r", PathConst.SIZE_30G, None, use_safely_open=False) +def load_torch_obj(path, **kwargs): + kwargs.setdefault("weights_only", True) + try: + torch = dependent.get("torch") + return torch.load(path, **kwargs) + except pickle.UnpicklingError: + if kwargs["weights_only"]: + prompt = """ + Weights only load failed. Re-running with `weights_only` set to `False` will likely succeed, + but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. \n + Please confirm your awareness of the risks associated with this action ([y]/n): """ + if not is_input_yes(prompt): + return None + kwargs["weights_only"] = False + return torch.load(path, **kwargs) + else: + return None diff --git a/accuracy_tools/msprobe/utils/log.py b/accuracy_tools/msprobe/utils/log.py new file mode 100644 index 00000000000..965f2295636 --- /dev/null +++ b/accuracy_tools/msprobe/utils/log.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.lib.msprobe_c import log + +_STAR = "*" +_DEBUG = "DEBUG" +_INFO = "INFO" +_WARNING = "WARNING" +_ERROR = "ERROR" +_TOTAL_CHAR_LENGTH = 80 +LOG_LEVEL = [_DEBUG, _INFO, _WARNING, _ERROR] + + +class Logger: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(Logger, cls).__new__(cls) + return cls._instance + + @staticmethod + def get_level_id(level: str): + if level.upper() in LOG_LEVEL: + return LOG_LEVEL.index(level.upper()) + else: + return LOG_LEVEL.index(LOG_LEVEL[1]) + + @staticmethod + def error(msg): + log.print_log(LOG_LEVEL.index(_ERROR), msg) + + @staticmethod + def warning(msg): + log.print_log(LOG_LEVEL.index(_WARNING), msg) + + @staticmethod + def info(msg): + log.print_log(LOG_LEVEL.index(_INFO), msg) + + @staticmethod + def debug(msg): + log.print_log(LOG_LEVEL.index(_DEBUG), msg) + + def set_level(self, level: str): + level_id = self.get_level_id(level) + log.set_log_level(level_id) + + +logger = Logger() + + +def print_log_with_star(info_message: str): + total_length = _TOTAL_CHAR_LENGTH + logger.info(_STAR * total_length) + logger.info(f"{_STAR}{info_message.center(total_length - 2)}{_STAR}") + logger.info(_STAR * total_length) diff --git a/accuracy_tools/msprobe/utils/path.py b/accuracy_tools/msprobe/utils/path.py new file mode 100644 index 00000000000..173e30a4ceb --- /dev/null +++ b/accuracy_tools/msprobe/utils/path.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from enum import Enum +from pathlib import Path +from shutil import disk_usage +from stat import S_IMODE, S_IRUSR, S_IWGRP, S_IWOTH, S_IWUSR, S_IXUSR + +from msprobe.utils.constants import MsgConst, PathConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.toolkits import check_int_border + +_MAX_PATH_LENGTH = 4096 +_MAX_LAST_NAME_LENGTH = 255 +_VALID_PATH_PATTERN = r"^(?!.*\.\.)[a-zA-Z0-9_./-]+$" + +_MODE_READ = {"r", "rb"} +_MODE_WRITE = {"w", "wb", "a", "ab", "a+"} +_MODE_EXEC = {"e"} +_MODE = _MODE_READ | _MODE_WRITE | _MODE_EXEC +_MAX_DIR_DEPTH = 32 +AUTHORITY_DIR = 0o750 +AUTHORITY_FILE = 0o640 + + +class SoftLinkLevel(Enum): + IGNORE = 0 + WARNING = 1 + STRICT = 2 + + +def is_file(path: str): + return os.path.isfile(path) + + +def is_dir(path: str): + return os.path.isdir(path) + + +def get_basename_from_path(path: str): + return os.path.basename(path.rstrip("/")) + + +def get_file_size(path: str): + return os.path.getsize(path) + + +def get_abs_path(path: str): + return os.path.abspath(path) + + +def get_name_and_ext(model_path): + basename = get_basename_from_path(model_path) + # Always returns (name, ext). 
+ return os.path.splitext(basename) + + +def join_path(*args, max_depth=_MAX_DIR_DEPTH): + check_int_border(max_depth, tag="max value of directory depth") + + def flatten(items, depth=0): + if depth > max_depth: + raise MsprobeException(MsgConst.RISK_ALERT, f"Maximum recursion depth {max_depth} exceeded") + for item in items: + if isinstance(item, str): + yield item + elif isinstance(item, (list, tuple)): + yield from flatten(item, depth + 1) + else: + pass + + return os.path.join(*flatten(args)) + + +def is_saved_model_scene(model_path): + saved_model_pb = join_path(model_path, "saved_model.pb") + if not is_file(saved_model_pb): + return False + variables_dir = join_path(model_path, "variables") + return is_dir(variables_dir) + + +def convert_bytes(bytes_size: int) -> str: + if bytes_size < 1024: + return f"{bytes_size} Bytes" + elif bytes_size < 1_048_576: # 1024 * 1024 + return f"{bytes_size / 1024:.2f} KB" + elif bytes_size < 1_073_741_824: # 1024 * 1024 * 1024 + return f"{bytes_size / (1_048_576):.2f} MB" + else: + return f"{bytes_size / (1_073_741_824):.2f} GB" + + +class SafePath: + def __init__( + self, + path: str, + path_type: str, + mode: str, + size_limitation: int = None, + suffix: str = None, + max_dir_depth: int = _MAX_DIR_DEPTH, + ): + self.path = self._check_path(path) + self.path_type = self._check_path_type(path_type) + self.mode = self._check_mode(mode) + self.size_limitation = self._check_int(size_limitation) if size_limitation else None + self.suffix = suffix + self.max_dir_depth = self._check_int(max_dir_depth) + + @staticmethod + def _check_path(path): + if not isinstance(path, str): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"path" must be string.') + return path + + @staticmethod + def _check_path_type(path_type): + if path_type not in [PathConst.FILE, PathConst.DIR]: + raise MsprobeException( + MsgConst.INVALID_ARGU, + f"The path type must be one of {[PathConst.FILE, PathConst.DIR]}, " f"currently: {path_type}.", + ) + return 
path_type + + @staticmethod + def _check_mode(mode): + if mode not in _MODE: + raise MsprobeException(MsgConst.INVALID_ARGU, f"Mode must be one of {_MODE}, currently: {mode}.") + return mode + + @staticmethod + def _check_int(value): + if not isinstance(value, int): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"Value must be an integer, currently: {value}.") + return value + + @staticmethod + def _check_path_exist(path): + if not os.path.exists(path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"Path not found: {path}.") + + @staticmethod + def _check_soft_link(path: str, level: SoftLinkLevel) -> str: + if not os.path.islink(path): + return path + real_path = os.path.realpath(path) + if not isinstance(level, SoftLinkLevel): + raise MsprobeException( + MsgConst.INVALID_ARGU, f"The validation level of symbolic links must be a SoftLinkLevel enum value." + ) + if level == SoftLinkLevel.STRICT: + raise MsprobeException(MsgConst.RISK_ALERT, f"Path {path} is a symlink. Usage prohibited.") + elif level == SoftLinkLevel.WARNING: + logger.warning(f"Found a symlink, path {path} -> {real_path}.") + else: + pass + return real_path + + @staticmethod + def _check_write_permission_for_group_others(path, permission): + if bool(permission & (S_IWGRP | S_IWOTH)): + raise MsprobeException( + MsgConst.RISK_ALERT, + f"The path {path} is writable by group and others. 
" + "Permissions for files (or directories) should not exceed 0o755 (rwxr-xr-x).", + ) + + @classmethod + def _check_permission(cls, path, mode): + path_stat = os.stat(path) + owner_id = path_stat.st_uid + current_uid = os.geteuid() + if owner_id not in {current_uid, 0}: + raise MsprobeException(MsgConst.RISK_ALERT, f"The owner of {path} must be root or the current user.") + permission = S_IMODE(path_stat.st_mode) + if current_uid == 0: + logger.warning(f"Running as root: Skipping permission checks for {path}, but this is a potential risk.") + else: + cls._check_write_permission_for_group_others(path, permission) + if mode in _MODE_READ and not (permission & S_IRUSR): + raise MsprobeException( + MsgConst.NO_PERMISSION, f"The current user is not authorized to read the path: {path}." + ) + if mode in _MODE_WRITE and not (permission & S_IWUSR): + raise MsprobeException( + MsgConst.NO_PERMISSION, f"The current user is not authorized to write the path: {path}." + ) + if mode in _MODE_EXEC and not (permission & S_IXUSR): + raise MsprobeException( + MsgConst.NO_PERMISSION, f"The current user is not authorized to execute the path: {path}." + ) + + def check(self, path_exist=True, soft_link_level=SoftLinkLevel.STRICT): + self.path = get_abs_path(os.path.normpath(self.path)) + if self.mode in _MODE_WRITE and not path_exist: + parent_dir = get_abs_path(join_path(self.path, os.pardir)) + self._check_path_exist(parent_dir) # The current path doesn't exist, but the parent directory does. 
+ parent_dir = self._check_soft_link(parent_dir, soft_link_level) + if not is_dir(parent_dir): + raise MsprobeException(MsgConst.INVALID_ARGU, f"The parent directory {parent_dir} is not valid.") + self._check_special_chars() + self._check_path_length() + self._check_permission(parent_dir, self.mode) + else: + self._check_path_exist(self.path) + self.path = self._check_soft_link(self.path, soft_link_level) + self._check_special_chars() + self._check_path_length() + if self.path_type == PathConst.FILE: + if not is_file(self.path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"The path {self.path} is not a file.") + self._check_file_suffix() + self._check_file_size() + elif self.path_type == PathConst.DIR: + if not is_dir(self.path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"The path {self.path} is not a directory.") + self._check_dir_size() + self._check_permission(self.path, self.mode) + if self.path_type == PathConst.DIR and not self.path.endswith("/"): + self.path += "/" + return self.path + + def _check_special_chars(self): + if not re.match(_VALID_PATH_PATTERN, self.path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"Path {self.path} contains special characters.") + + def _check_path_length(self): + if len(self.path) > _MAX_PATH_LENGTH: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Current path length ({len(self.path)}) exceeds the limit ({_MAX_PATH_LENGTH})." 
+ ) + dir_depth = 0 + for dir_name in self.path.split("/"): + dir_depth += 1 + if dir_depth > self.max_dir_depth: + raise MsprobeException(MsgConst.RISK_ALERT, f"Exceeded max directory depth ({self.max_dir_depth}).") + if len(dir_name) > _MAX_LAST_NAME_LENGTH: + raise MsprobeException( + MsgConst.RISK_ALERT, + f"Current {self.path_type} length ({len(dir_name)}) exceeds the limit ({_MAX_LAST_NAME_LENGTH}).", + ) + + def _check_file_suffix(self): + if self.suffix and not self.path.endswith(self.suffix): + raise MsprobeException(MsgConst.INVALID_ARGU, f"{self.path} is not a {self.suffix} file.") + + def _check_file_size(self): + if self.size_limitation and os.path.getsize(self.path) > self.size_limitation: + raise MsprobeException( + MsgConst.RISK_ALERT, f"File size exceeds the limit ({convert_bytes(self.size_limitation)})." + ) + + def _check_dir_size(self): + if self.size_limitation and get_dir_size(self.path, self.max_dir_depth) > self.size_limitation: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Directory size exceeds the limit ({convert_bytes(self.size_limitation)})." + ) + + +def get_dir_size(dir_path, max_dir_depth=_MAX_DIR_DEPTH): + total_size = 0 + for root, _, files in os.walk(dir_path): + # fmt: off + current_depth = root[len(dir_path):].count(os.sep) + # fmt: on + if current_depth > max_dir_depth: + raise MsprobeException( + MsgConst.RISK_ALERT, + f"Calculated size of {dir_path}, but exceeded max depth ({max_dir_depth}). Current size: {total_size}.", + ) + for file_name in files: + total_size += os.path.getsize(join_path(root, file_name)) + return total_size + + +def make_dirs(dir_path: str): + normalized_path = os.path.normpath(dir_path) + depth_parts = normalized_path.strip(os.sep).split(os.sep) + depth = len([p for p in depth_parts if p]) + if depth > _MAX_DIR_DEPTH: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Directory depth exceeds the limit of {_MAX_DIR_DEPTH}: {dir_path} has depth {depth}." 
+ ) + + try: + Path(dir_path).mkdir(mode=AUTHORITY_DIR, exist_ok=True, parents=True) + except OSError as e: + raise MsprobeException( + MsgConst.IO_FAILURE, + f"Failed to create {dir_path}, please Check if the parent directory of the current " + f"path exists, and verify permissions or disk space.", + ) from e + + +def change_permission(path, permission): + if not os.path.exists(path) or os.path.islink(path): + return + try: + os.chmod(path, permission) + except PermissionError as e: + raise MsprobeException(MsgConst.NO_PERMISSION, f"Failed to set permissions ({permission}) for {path}.") from e + + +def is_enough_disk_space(path, required_space): + return disk_usage(path).free >= required_space + + +class DirSafeHandler: + @staticmethod + def ensure_dir_exists(path: str): + if not is_dir(path): + make_dirs(path) + + @staticmethod + def get_or_raise(path: str, error_msg: str): + if path: + return path + else: + raise MsprobeException(MsgConst.PATH_NOT_FOUND, error_msg) + + @staticmethod + def join_and_create(*args): + path = join_path(args) + DirSafeHandler.ensure_dir_exists(path) + return path diff --git a/accuracy_tools/msprobe/utils/toolkits.py b/accuracy_tools/msprobe/utils/toolkits.py new file mode 100644 index 00000000000..e36a3efd60c --- /dev/null +++ b/accuracy_tools/msprobe/utils/toolkits.py @@ -0,0 +1,424 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import re +from enum import Enum +from functools import wraps +from random import seed +from subprocess import PIPE, CalledProcessError, Popen, run +from time import perf_counter, time + +import numpy as np + +from msprobe.utils.constants import MsgConst +from msprobe.utils.dependencies import dependent +from msprobe.utils.env import evars +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger + +_MALICIOUS_CSV_PATTERN = re.compile(r"^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]") + + +class CsvCheckLevel(Enum): + IGNORE = 0 + REPLACE = 1 + STRICT = 2 + + +_POSITIVE_INT_BORDER = [0, 4_294_967_295] # [0, 2**32 - 1] + + +def get_pid(): + return os.getpid() + + +def get_current_timestamp(microsecond=True): + if microsecond: + return round(perf_counter() * 1e6) % 10**10 + else: + timestamp = int(time()) + return timestamp + + +def filter_cmd(paras): + whitelist_pattern = re.compile(r"^[a-zA-Z0-9_\-./=:,\[\] ]+$") + filtered = [] + for arg in paras: + arg_str = str(arg) + if whitelist_pattern.fullmatch(arg_str): + filtered.append(arg_str) + else: + raise MsprobeException( + MsgConst.RISK_ALERT, + f'The command contains invalid characters. 
Only the "{whitelist_pattern}" pattern is allowed.',
+            )
+    return filtered
+
+
+def register(name, tmp_map):
+    """Build a class decorator that records the decorated class in ``tmp_map`` under key ``name``.
+
+    Fix: ``name`` is a registry key (a plain value, not a callable), so the previous
+    ``@wraps(name)`` on the inner wrapper misused ``functools.wraps`` -- it copied
+    metadata from a non-function and attached a bogus ``__wrapped__`` attribute.
+    The wrapper needs no metadata copying at all; it returns ``comp_type`` unchanged.
+    """
+
+    def wrapper(comp_type):
+        tmp_map[name] = comp_type
+        return comp_type
+
+    return wrapper
+
+
+def safely_compute(func):
+    """Decorator: run ``func`` and, on any exception, log a warning and return None."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            logger.warning(f"Calculation failed via {func.__name__}: {e}")
+            return None
+
+    return wrapper
+
+
+def get_valid_name(name: str):
+    """Normalize a node/tensor name into a filesystem-safe identifier.
+
+    Strips any leading slashes, then replaces '.', '/' and ':' with '_'.
+    """
+    if name and name[0] == "/":
+        name = name.lstrip("/")
+    return name.replace(".", "_").replace("/", "_").replace(":", "_")
+
+
+def run_subprocess(cmd: list, capture_output=False):
+    """Run a whitelisted command without a shell.
+
+    When ``capture_output`` is True, returns the captured stdout and raises
+    ``MsprobeException`` on a non-zero exit code; otherwise streams output and
+    returns None, converting ``CalledProcessError`` into ``MsprobeException``.
+    """
+    if not isinstance(cmd, list):
+        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "`cmd` must be a list of strings.")
+    cmd = filter_cmd(cmd)
+    logger.warning("Please ensure the executed command is correct.")
+    logger.info(f'Running command: {" ".join(cmd)}.')
+    if capture_output:
+        process = Popen(cmd, stdout=PIPE, stderr=PIPE, text=True, bufsize=1, shell=False)
+        stdout, stderr = process.communicate()
+        stderr_lines = stderr.splitlines()
+        if process.returncode != 0:
+            # Fix: communicate() has already waited for the child to exit, so the
+            # process.terminate() call that used to sit here was a no-op; removed.
+            logger.error(f"Sub-process failed with error: {stderr_lines}.")
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Failed to execute command: {' '.join(cmd)}.")
+        return stdout
+    else:
+        try:
+            run(cmd, text=True, shell=False, check=True)
+            return None
+        except CalledProcessError as e:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Command failed: {' '.join(cmd)}") from e
+
+
+class DistBackend:
+    torch = dependent.get("torch")
+    dist_map = {"cuda": "nccl", "npu": "hccl", "cpu": "gloo"}
+
+    @staticmethod
+    def _get_visible_device(device_type) -> int:
+        try:
+            return int(evars.get(device_type, "0").split(",")[0])
+        except Exception as e:
+            raise MsprobeException(
+                MsgConst.INVALID_DATA_TYPE,
+                f"Please check the value of the environment variable {device_type}, "
+                f'currently: {evars.get(device_type, "0")}.',
+            )
from e + + @classmethod + def get(cls): + return cls.dist_map.get(cls._get_global_device(), "cpu") + + @classmethod + def _is_device_available(cls, device_name, device_type): + if device_name == "npu" and hasattr(cls.torch, "npu") and cls.torch.npu.is_available(): + return cls._get_visible_device(device_type) >= 0 + elif device_name == "cuda" and hasattr(cls.torch, "cuda") and cls.torch.cuda.is_available(): + return cls._get_visible_device(device_type) >= 0 + elif device_name == "cpu": + return True + return False + + @classmethod + def _get_global_device(cls): + if cls._is_device_available("npu", "ASCEND_VISIBLE_DEVICES"): + return "npu" + elif cls._is_device_available("cuda", "CUDA_VISIBLE_DEVICES"): + return "cuda" + else: + return "cpu" + + +def timestamp_sync(timestamp: int): + torch = dependent.get("torch") + world_size = evars.get("LOCAL_WORLD_SIZE", "1", int) + if world_size < 2: + return timestamp + if torch: + timestamp = torch.tensor(timestamp) + if not torch.distributed.is_initialized(): + rank = evars.get("LOCAL_RANK", "0", int) + torch.distributed.init_process_group(backend=DistBackend.get(), rank=rank, world_size=world_size) + torch.distributed.all_reduce(timestamp, op=torch.distributed.ReduceOp.MAX) + return timestamp.item() + return timestamp + + +def get_current_rank() -> str: + torch = dependent.get("torch") + if torch and torch.distributed.is_initialized(): + return str(torch.distributed.get_rank()) + return "" + + +def check_int_border(*args, border: list = None, tag: str = None): + if not border: + border = _POSITIVE_INT_BORDER + if len(border) != 2: + raise MsprobeException(MsgConst.INVALID_ARGU, "The border must be a list of two integers.") + for num in args: + if not isinstance(num, int): + msg = f"Expected int type, but got {type(num).__name__}." + if tag: + msg += f" Context: {tag}." 
+ raise MsprobeException(MsgConst.INVALID_DATA_TYPE, msg) + if not (border[0] <= num <= border[1]): + msg = f"The integer range is limited to {border}, currently: {num}." + if tag: + msg += f" Context: {tag}." + raise MsprobeException(MsgConst.INVALID_ARGU, msg) + + +class DropoutHandler: + @staticmethod + def remove_for_pt(): + torch = dependent.get("torch") + if not torch or torch.__version__ <= "1.8": + return + logger.info("For precision comparison, the probability p in the dropout method is set to 0.") + _f = torch.nn.functional + _vf = torch._C._VariableFunctions + has_torch_function_unary = torch.overrides.has_torch_function_unary + handle_torch_function = torch.overrides.handle_torch_function + + def function_dropout(input_tensor, p: float = 0.5, training: bool = True, inplace: bool = False): + if has_torch_function_unary(input_tensor): + return handle_torch_function( + function_dropout, (input_tensor,), input_tensor, p=0.0, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise MsprobeException( + MsgConst.INVALID_ARGU, f"dropout probability has to be between 0 and 1, but got {p}." + ) + return _vf.dropout_(input_tensor, 0.0, training) if inplace else _vf.dropout(input_tensor, 0.0, training) + + def function_dropout2d(input_tensor, p: float = 0.5, training: bool = True, inplace: bool = False): + if has_torch_function_unary(input_tensor): + return handle_torch_function( + function_dropout2d, (input_tensor,), input_tensor, p=0.0, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise MsprobeException( + MsgConst.INVALID_ARGU, f"dropout probability has to be between 0 and 1, but got {p}." 
+ ) + return ( + _vf.feature_dropout_(input_tensor, 0.0, training) + if inplace + else _vf.feature_dropout(input_tensor, 0.0, training) + ) + + def function_dropout3d(input_tensor, p: float = 0.5, training: bool = True, inplace: bool = False): + if has_torch_function_unary(input_tensor): + return handle_torch_function( + function_dropout3d, (input_tensor,), input_tensor, p=0.0, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise MsprobeException( + MsgConst.INVALID_ARGU, f"dropout probability has to be between 0 and 1, but got {p}." + ) + return ( + _vf.feature_dropout_(input_tensor, 0.0, training) + if inplace + else _vf.feature_dropout(input_tensor, 0.0, training) + ) + + _f.dropout = function_dropout + _f.dropout2d = function_dropout2d + _f.dropout3d = function_dropout3d + + @staticmethod + def remove_for_ms(): + ms = dependent.get("mindspore") + if not ms: + return + ops = ms.ops + nn = ms.mint.nn + + class Dropout(ops.Dropout): + def __init__(self, keep_prob=0.5, seed0=0, seed1=1): + super().__init__(1.0, seed0, seed1) + + class Dropout2D(ops.Dropout2D): + def __init__(self, keep_prob=0.5): + super().__init__(1.0) + + class Dropout3D(ops.Dropout3D): + def __init__(self, keep_prob=0.5): + super().__init__(1.0) + + class DropoutExt(nn.Dropout): + def __init__(self, p=0.5): + super().__init__(0) + + def dropout_ext(input_tensor, p=0.5, training=True): + return input_tensor + + ops.Dropout = Dropout + ops.operations.Dropout = Dropout + ops.Dropout2D = Dropout2D + ops.operations.Dropout2D = Dropout2D + ops.Dropout3D = Dropout3D + ops.operations.Dropout3D = Dropout3D + nn.Dropout = DropoutExt + nn.functional.dropout = dropout_ext + + +class SetSeed: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(SetSeed, cls).__new__(cls) + return cls._instance + + def __init__(self, seed_num: int, mode: bool, rm_dropout: bool): + self.seed_num = seed_num + self.mode = mode + self.rm_dropout = rm_dropout 
+        self._check_param()
+
+    def all(self):
+        """Apply every determinism setting: native Python/NumPy, torch, torch_npu,
+        Ascend environment variables and MindSpore.
+
+        Fix: this was a zero-argument ``@classmethod`` that called the
+        ``_focus_on_*`` instance methods on ``cls`` (unbound, so every call
+        raised ``TypeError``), while ``seed_all`` invoked it with three
+        positional arguments. It is now a regular instance method driven by
+        the state validated in ``__init__``.
+        """
+        self._focus_on_native()
+        self._focus_on_torch()
+        self._focus_on_torch_npu()
+        self._focus_on_ascend()
+        self._focus_on_mindspore()
+
+    def _check_param(self):
+        # Validate constructor arguments before any framework state is touched.
+        check_int_border(self.seed_num)
+        if not isinstance(self.mode, bool):
+            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "`mode` must be a boolean.")
+        if not isinstance(self.rm_dropout, bool):
+            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "`rm_dropout` must be a boolean.")
+
+    def _focus_on_native(self):
+        # Seed the interpreter hash, the stdlib RNG and NumPy.
+        evars.set("PYTHONHASHSEED", str(self.seed_num))
+        seed(self.seed_num)
+        np.random.seed(self.seed_num)
+
+    def _focus_on_ascend(self):
+        # Ascend collective-communication / ATB determinism switches.
+        evars.set("LCCL_DETERMINISTIC", "1")
+        evars.set("HCCL_DETERMINISTIC", "true" if self.mode else "false")
+        evars.set("ATB_MATMUL_SHUFFLE_K_ENABLE", "0")
+        evars.set("ATB_LLM_LCOC_ENABLE", "0")
+
+    def _focus_on_torch(self):
+        torch = dependent.get("torch")
+        if not torch:
+            return
+        torch.manual_seed(self.seed_num)
+        torch.use_deterministic_algorithms(mode=self.mode)
+        if hasattr(torch, "cuda"):
+            torch.cuda.manual_seed(self.seed_num)
+            torch.cuda.manual_seed_all(self.seed_num)
+        if hasattr(torch, "backends"):
+            torch.backends.cudnn.deterministic = True
+            # Fix: the cuDNN switch is spelled `enabled`; assigning `enable`
+            # silently created an unused attribute and left cuDNN turned on.
+            torch.backends.cudnn.enabled = False
+            torch.backends.cudnn.benchmark = False
+        if hasattr(torch, "version"):
+            cuda_version = torch.version.cuda
+            if cuda_version:
+                major, minor = map(int, cuda_version.split(".")[:2])
+                if (major, minor) >= (10, 2):
+                    # Required by CUDA >= 10.2 for deterministic cuBLAS.
+                    evars.set("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
+        if self.rm_dropout:
+            DropoutHandler.remove_for_pt()
+
+    def _focus_on_torch_npu(self):
+        torch_npu = dependent.get("torch_npu")
+        if not torch_npu:
+            return
+        torch_npu.npu.manual_seed(self.seed_num)
+        torch_npu.npu.manual_seed_all(self.seed_num)
+
+    def _focus_on_mindspore(self):
+        ms = dependent.get("mindspore")
+        if not ms:
+            return
+        ms.set_seed(self.seed_num)
+        ms.set_context(deterministic="ON" if self.mode else "OFF")
+        if self.rm_dropout:
+            DropoutHandler.remove_for_ms()
+
+
+def seed_all(seed_num=666, mode=False, rm_dropout=True):
+    """Seed every available framework and enable deterministic execution.
+
+    Fix: this previously called ``SetSeed.all(seed_num, mode, rm_dropout)``,
+    but ``all`` took no arguments, so every invocation raised ``TypeError``.
+    The arguments now go through the constructor (which validates them) and
+    ``all()`` is invoked on the instance.
+    """
+    try:
+        SetSeed(seed_num, mode, rm_dropout).all()
+    except Exception as e:
+        raise MsprobeException(MsgConst.CALL_FAILED, f"Failed to set seed: {e}") from e
+    logger.info(f"Enable deterministic computation success! current seed is {seed_num}.")
+
+
+def sanitize_csv_value(value: str, errors=CsvCheckLevel.STRICT):
+    """Guard a value against CSV formula injection before it is written out.
+
+    Purely numeric strings are returned untouched (the ``float`` probe succeeds);
+    non-numeric strings matching the malicious pattern are replaced or rejected
+    depending on ``errors``.
+    """
+    if errors == CsvCheckLevel.IGNORE or not isinstance(value, str):
+        return value
+    sanitized_value = value
+    try:
+        float(value)
+    except Exception as e:
+        if not _MALICIOUS_CSV_PATTERN.search(value):
+            pass
+        elif errors == CsvCheckLevel.REPLACE:
+            sanitized_value = ""
+            logger.warning(f'Malicious CSV value detected and replaced: "{value}" -> "{sanitized_value}".')
+        else:
+            msg = f"Malicious value detected: {value}, please check the value written to the csv."
+            raise MsprobeException(MsgConst.RISK_ALERT, msg) from e
+    return sanitized_value
+
+
+def get_net_output_nodes_from_graph_def(graph_def):
+    """Return nodes that are never consumed as inputs, i.e. the graph's outputs."""
+    all_nodes = {node.name for node in graph_def.node}
+    input_nodes = set()
+    for node in graph_def.node:
+        for inp in node.input:
+            input_nodes.add(inp)
+    output_nodes = all_nodes - input_nodes
+    return list(output_nodes)
+
+
+def is_input_yes(prompt):
+    """Prompt the user and return True only for an explicit y/yes answer."""
+    confirm_pattern = re.compile(r"^\s*y(?:es)?\s*$", re.IGNORECASE)
+    try:
+        user_action = input(prompt).strip()
+    except (EOFError, KeyboardInterrupt):
+        logger.info('Input interrupted. 
Defaulting to "no".') + return False + return bool(confirm_pattern.fullmatch(user_action)) + + +def set_ld_preload(so_path): + ld_preload = evars.get("LD_PRELOAD", required=False) + if ld_preload: + evars.set("LD_PRELOAD", f"{so_path}:{ld_preload}") + else: + evars.set("LD_PRELOAD", so_path) + logger.info(f"Environment updated with .so library: {so_path}.") diff --git a/accuracy_tools/pyproject.toml b/accuracy_tools/pyproject.toml new file mode 100644 index 00000000000..1688f68d012 --- /dev/null +++ b/accuracy_tools/pyproject.toml @@ -0,0 +1,11 @@ +[tool.black] +line-length = 120 # 设置最大行长 +target-version = ['py37', 'py38', 'py39', 'py310', 'py311', 'py312'] # 兼容的 Python 版本 + +[tool.isort] +profile = "black" # 使 isort 与 Black 兼容 +line_length = 120 # 统一最大行长 +multi_line_output = 3 # 按分组方式输出多行 import 语句 +force_grid_wrap = 0 # 控制换行时的显示方式 +use_parentheses = true # 使用括号包裹长 import 语句 +combine_as_imports = true # 合并多行的 as 导入 diff --git a/accuracy_tools/requirements/requirements.txt b/accuracy_tools/requirements/requirements.txt new file mode 100644 index 00000000000..1939d520010 --- /dev/null +++ b/accuracy_tools/requirements/requirements.txt @@ -0,0 +1,7 @@ +numpy < 2.0 +protobuf >= 3.18, < 5.0 +onnx >= 1.12.0, < 2.0 +onnxruntime >= 1.10, < 2.0 +pandas >= 1.3, < 3.0 +PyYAML +tqdm diff --git a/accuracy_tools/requirements/requirements_tf.txt b/accuracy_tools/requirements/requirements_tf.txt new file mode 100644 index 00000000000..745fc5dd1a4 --- /dev/null +++ b/accuracy_tools/requirements/requirements_tf.txt @@ -0,0 +1,10 @@ +numpy >= 1.19.2, <= 1.21.6 +protobuf >= 3.9.2, <= 3.20.3 +scipy >= 1.5.2, <= 1.7.3 +pandas >= 1.2.0, <= 1.3.5 +decorator +sympy +attrs +psutil +PyYAML +tqdm diff --git a/accuracy_tools/setup.py b/accuracy_tools/setup.py new file mode 100644 index 00000000000..d5f102d6c89 --- /dev/null +++ b/accuracy_tools/setup.py @@ -0,0 +1,87 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "8.1.0" + +import os +import sys +from platform import machine +from subprocess import run + +from setuptools import find_packages, setup + +_COMPAT_REQUIREMENTS_MAP = {"tf": "requirements_tf.txt", "default": "requirements.txt"} + + +def parse_args(): + compat_flag = None + if "--compat" in sys.argv: + index = sys.argv.index("--compat") + compat_flag = sys.argv[index + 1] + sys.argv.remove("--compat") + sys.argv.remove(compat_flag) + return compat_flag + + +def get_requirements(compat_name=None): + requirements_parent_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements") + requirements_file = _COMPAT_REQUIREMENTS_MAP.get(compat_name, _COMPAT_REQUIREMENTS_MAP["default"]) + with open(os.path.join(requirements_parent_path, requirements_file)) as f: + required_lines = f.read().splitlines() + return required_lines + + +compat = parse_args() +required = get_requirements(compat) + +build_cmd = f"bash ./build.sh -j16 -a {machine()} -v {sys.version_info.major}.{sys.version_info.minor}" +p = run(build_cmd.split(), shell=False) +if p.returncode != 0: + raise RuntimeError(f"Failed to build source({p.returncode})") + + +setup( + name="mindstudio-probe", + version=__version__, + description="Ascend Probe Utils", + long_description=""" + MindStudio-Probe is a set of tools for diagnosing and improving model accuracy on Ascend NPU, + including API accuracy, args checker, grad tool etc. 
+ """, + long_description_content_type="text/markdown", + url="https://gitee.com/ascend/mstt/tree/master/accuracy_tools/msprobe", + author="Ascend Team", + author_email="pmail_mindstudio@huawei.com", + packages=find_packages(include=["msprobe", "msprobe*"]), + package_data={"": ["LICENSE", "lib/*.so"]}, + license="Apache-2.0", + keywords=["msprobe", "pytorch", "mindspore"], + python_requires=">=3.7", + install_requires=required, + zip_safe=False, + classifiers=[ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3", + "Programming Language :: C++", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"console_scripts": ["msprobe=msprobe.__main__:main"]}, +) diff --git a/accuracy_tools/test/ST/run_st.py b/accuracy_tools/test/ST/run_st.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/accuracy_tools/test/UT/.coveragerc b/accuracy_tools/test/UT/.coveragerc new file mode 100644 index 00000000000..99f94acd09d --- /dev/null +++ b/accuracy_tools/test/UT/.coveragerc @@ -0,0 +1,3 @@ +[run] +# 计算覆盖率时排除单元测试文件本身 +omit = */test_*.py diff --git a/accuracy_tools/test/UT/CMakeLists.txt b/accuracy_tools/test/UT/CMakeLists.txt new file mode 100644 index 00000000000..fe94d13c7b5 --- /dev/null +++ b/accuracy_tools/test/UT/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(csrc_ut) diff --git a/accuracy_tools/test/UT/base_ut/component_ut/test_manager_component.py b/accuracy_tools/test/UT/base_ut/component_ut/test_manager_component.py new file mode 100644 index 00000000000..695fd6fdc7b --- /dev/null +++ b/accuracy_tools/test/UT/base_ut/component_ut/test_manager_component.py @@ -0,0 +1,269 @@ +import unittest 
+from unittest.mock import MagicMock + +from msprobe.base.component.manager import BaseComponent, Component, ConsumerComp, ProducerComp, Scheduler +from msprobe.utils.constants import MsgConst +from msprobe.utils.exceptions import MsitException + + +class TestBaseComponent: + def test_initialization(self): + component = BaseComponent(priority=200) + assert component.priority == 200 + assert component.is_activated is False + + def test_do_activate(self): + component = BaseComponent() + component.activate = MagicMock() + assert component.is_activated is False + + component.do_activate() + assert component.is_activated is True + component.activate.assert_called_once() + + component.do_activate() + component.activate.assert_called_once() + + def test_do_deactivate(self): + component = BaseComponent() + component.deactivate = MagicMock() + component.activated = True + assert component.is_activated is True + + component.do_deactivate() + assert component.is_activated is False + component.deactivate.assert_called_once() + + component.do_deactivate() + component.deactivate.assert_called_once() + + def test_activate_does_not_change_state_directly(self): + component = BaseComponent() + component.activate() + assert component.is_activated is False + + def test_deactivate_does_not_change_state_directly(self): + component = BaseComponent() + component.activated = True + component.deactivate() + assert component.is_activated is True + + +class ConcreteProducerComp(ProducerComp): + def __init__(self, priority): + super(ConcreteProducerComp, self).__init__(priority) + self._data_generated = False + + def load_data(self): + if not self._data_generated: + self._data_generated = True + return "generated_data" + return None + + +class ConcreteConsumerComp(ConsumerComp): + def __init__(self, priority): + super(ConcreteConsumerComp, self).__init__(priority) + + def consume(self, packages): + print("Consuming data:", packages) + + +class HybridComp(ProducerComp, ConsumerComp): + def 
__init__(self, priority): + super(HybridComp, self).__init__(priority) + self._data_generated = False + + def load_data(self): + if not self._data_generated: + self._data_generated = True + return "generated_data" + return None + + def consume(self, packages): + print("Consuming:", packages) + + +class TestProducerComp(unittest.TestCase): + def setUp(self): + self.producer = ConcreteProducerComp(priority=100) + self.scheduler_mock = MagicMock() + self.producer.scheduler = self.scheduler_mock + + self.producer.activate = MagicMock() + self.producer.deactivate = MagicMock() + self.producer.publish = MagicMock() + + def test_do_activate(self): + self.assertFalse(self.producer.is_activated) + self.producer.do_activate() + self.assertTrue(self.producer.is_activated) + self.producer.activate.assert_called_once() + + def test_do_deactivate(self): + self.producer.activated = True + self.producer.do_deactivate() + self.assertFalse(self.producer.is_activated) + self.producer.deactivate.assert_called_once() + + def test_retrieve(self): + self.producer.publish("some_data", msg_id=1) + self.assertIsNone(self.producer.output_buffer) + + def test_do_load_data_when_output_buffer_is_none(self): + self.producer.load_data = MagicMock(return_value="generated_data") + self.producer.do_load_data() + self.producer.load_data.assert_called_once() + self.producer.publish.assert_called_once_with("generated_data") + + def test_do_load_data_when_output_buffer_is_not_none(self): + self.producer.output_buffer = ["some_data"] + self.producer.do_load_data() + self.producer.publish.assert_not_called() + + +class TestConsumerComp(unittest.TestCase): + def setUp(self): + self.producer = ConcreteProducerComp(priority=1) + self.consumer = ConcreteConsumerComp(priority=2) + self.comp_a = HybridComp(priority=100) + self.comp_b = HybridComp(priority=200) + self.comp_c = HybridComp(priority=300) + self.consumer.consume = MagicMock() + + def test_do_consume_with_empty_dependencies(self): + 
self.consumer.dependencies = {MagicMock(): None} + self.consumer.do_consume() + self.consumer.consume.assert_not_called() + + def test_do_consume_with_filled_dependencies(self): + mock_producer = MagicMock() + package_data = [mock_producer, "mock_data", 1] + + self.consumer.dependencies = {mock_producer: package_data} + self.consumer.do_consume() + + self.consumer.consume.assert_called_once_with([package_data]) + self.assertEqual(self.consumer.dependencies[mock_producer], None) + + def test_do_consume_partial_dependencies(self): + mock_producer1 = MagicMock() + mock_producer2 = MagicMock() + package_data = [mock_producer1, "mock_data", 1] + + self.consumer.dependencies = {mock_producer1: package_data, mock_producer2: None} + self.consumer.do_consume() + self.consumer.consume.assert_not_called() + + def test_subscribe_valid(self): + self.consumer.subscribe(self.producer) + self.assertIn(self.consumer, self.producer.get_subscribers()) + + def test_subscribe_invalid_type(self): + with self.assertRaises(MsitException): + self.consumer.subscribe(self.consumer) + + def test_no_cycle(self): + self.comp_a.subscribe(self.comp_b) + self.comp_b.subscribe(self.comp_c) + try: + self.comp_c.subscribe(self.comp_a) + self.assertTrue(True) + except MsitException as e: + self.fail(f"Unexpected cycle detection exception: {e}") + + def test_already_subscribed(self): + self.comp_a.subscribe(self.comp_b) + self.comp_b.subscribe(self.comp_c) + self.comp_c.subscribe(self.comp_a) + self.assertEqual(len(self.comp_c.dependencies), 1) + + def test_multiple_cycles(self): + self.comp_a.subscribe(self.comp_b) + self.comp_b.subscribe(self.comp_c) + self.comp_c.subscribe(self.comp_a) + with self.assertRaises(MsitException) as context: + self.comp_a.subscribe(self.comp_c) + self.assertIn(MsgConst.RISK_ALERT, str(context.exception)) + + def test_on_receive(self): + package = [self.producer, "test_data", 0] + self.consumer.on_receive(package) + 
self.assertEqual(self.consumer.dependencies[self.producer], package) + + def test_get_empty_dependencies(self): + self.consumer.subscribe(self.producer) + self.assertIn(self.producer, self.consumer.get_empty_dependencies()) + + def test_do_consume(self): + self.consumer.subscribe(self.producer) + package = [self.producer, "test_data", 0] + self.consumer.on_receive(package) + + +class TestRegisterDecorator(unittest.TestCase): + def setUp(self): + Component._component_type_map = {} + + def test_register_decorator(self): + @Component.register("ComponentB") + class ComponentB: + pass + + self.assertIn("ComponentB", Component._component_type_map) + self.assertEqual(Component._component_type_map["ComponentB"], ComponentB) + + def test_get_registered_component(self): + @Component.register("ComponentC") + class ComponentC: + pass + + component = Component.get("ComponentC") + self.assertEqual(component, ComponentC) + + +class TestScheduler(unittest.TestCase): + def setUp(self): + self.scheduler = Scheduler() + self.producer = MagicMock(ProducerComp) + self.consumer = MagicMock(ConsumerComp) + self.producer.is_ready = True + self.consumer.is_activated = False + self.consumer.get_empty_dependencies.return_value = [] + self.consumer.do_consume = MagicMock() + + def test_add_component(self): + self.scheduler.add([self.producer]) + self.assertIn(self.producer, self.scheduler.comp_ref) + self.assertEqual(self.scheduler.comp_ref[self.producer], 1) + + def test_remove_component(self): + self.scheduler.add([self.producer]) + self.scheduler.remove([self.producer]) + self.assertNotIn(self.producer, self.scheduler.comp_ref) + + def test_schedule_consumer_when_no_dependencies(self): + self.scheduler._schedule_consumer(self.consumer) + self.consumer.do_consume.assert_called_once() + self.assertIn(self.consumer, self.scheduler.comps_to_schedule) + + def test_schedule_consumer_with_unready_dependencies(self): + dependency_mock = MagicMock() + dependency_mock.is_ready = False + 
dependency_mock.do_load_data = MagicMock() + + self.consumer.get_empty_dependencies.return_value = [dependency_mock] + self.scheduler._schedule_consumer(self.consumer) + self.consumer.do_consume.assert_not_called() + dependency_mock.do_load_data.assert_called_once() + self.assertNotIn(dependency_mock, self.scheduler.comps_to_schedule) + + def test_schedule_consumer_with_ready_dependencies(self): + dependency_mock = MagicMock() + dependency_mock.is_ready = True + dependency_mock.do_load_data = MagicMock() + + self.consumer.get_empty_dependencies.return_value = [dependency_mock] + self.scheduler._schedule_consumer(self.consumer) + dependency_mock.do_load_data.assert_called_once() + self.assertIn(dependency_mock, self.scheduler.comps_to_schedule) diff --git a/accuracy_tools/test/UT/base_ut/service_ut/test_manager_service.py b/accuracy_tools/test/UT/base_ut/service_ut/test_manager_service.py new file mode 100644 index 00000000000..84fc7b35008 --- /dev/null +++ b/accuracy_tools/test/UT/base_ut/service_ut/test_manager_service.py @@ -0,0 +1,79 @@ +import unittest +from unittest.mock import MagicMock, call, create_autospec, patch + +from msprobe.base import BaseComponent, BaseService, Scheduler, Service +from msprobe.utils.constants import CfgConst, CmdConst + + +class TestService(unittest.TestCase): + def setUp(self): + Service._services_map.clear() + + @patch("msprobe.base.service.manager.load_json") + @patch("msprobe.base.service.manager.valid_task") + def test_service_initialization_via_config(self, mock_valid_task, mock_load_json): + mock_load_json.return_value = {CfgConst.TASK: CfgConst.TASK_STAT} + mock_valid_task.return_value = CfgConst.TASK_STAT + mock_service_cls = MagicMock() + Service._services_map[CmdConst.DUMP] = mock_service_cls + cmd_namespace = MagicMock() + cmd_namespace.config_path = "dummy_path" + service = Service(cmd_namespace=cmd_namespace, key="value") + mock_load_json.assert_called_once_with("dummy_path") + 
mock_service_cls.assert_called_once_with(cmd_namespace=cmd_namespace, key="value") + self.assertEqual(service.service_instance, mock_service_cls.return_value) + + def test_service_registration(self): + @Service.register("test_service") + class TestServiceImpl: + pass + + self.assertIs(Service._services_map["test_service"], TestServiceImpl) + + @patch("msprobe.base.service.manager.load_json") + @patch("msprobe.base.service.manager.valid_task") + def test_service_method_delegation(self, mock_valid_task, mock_load_json): + mock_instance = MagicMock() + mock_instance.target_method = MagicMock(return_value="result") + mock_service_cls = MagicMock(return_value=mock_instance) + Service._services_map[CmdConst.DUMP] = mock_service_cls + mock_load_json.return_value = {CfgConst.TASK: CfgConst.TASK_STAT} + mock_valid_task.return_value = CfgConst.TASK_STAT + cmd_namespace = MagicMock() + cmd_namespace.config_path = "valid_path" + service = Service(cmd_namespace=cmd_namespace) + result = service.target_method("arg", kw=456) + mock_instance.target_method.assert_called_once_with("arg", kw=456) + self.assertEqual(result, "result") + mock_service_cls.assert_called_once_with(cmd_namespace=cmd_namespace) + + +class TestBaseService(unittest.TestCase): + @patch.object(Scheduler, "add") + @patch.object(Scheduler, "remove") + def test_full_lifecycle(self, mock_remove, mock_add): + class TestService(BaseService): + def construct(self): + self.high_pri = create_autospec(BaseComponent, name="high_pri") + self.high_pri.priority = 1 + self.low_pri = create_autospec(BaseComponent, name="low_pri") + self.low_pri.priority = 2 + self.non_comp = "non comp" + + service = TestService() + service.start() + mock_add.assert_called_once() + + @patch.object(BaseService, "init_start") + @patch.object(BaseService, "finalize_start") + def test_hook_execution_order(self, mock_final, mock_init): + class HookTestService(BaseService): + def construct(self): + pass + + service = HookTestService() + 
service.start() + mock_init.assert_called_once() + mock_final.assert_called_once() + self.assertEqual(mock_init.call_args_list[0], call()) + self.assertEqual(mock_final.call_args_list[-1], call()) diff --git a/accuracy_tools/test/UT/base_ut/test_cmd.py b/accuracy_tools/test/UT/base_ut/test_cmd.py new file mode 100644 index 00000000000..a9e257d0eff --- /dev/null +++ b/accuracy_tools/test/UT/base_ut/test_cmd.py @@ -0,0 +1,86 @@ +import unittest +from argparse import RawTextHelpFormatter +from unittest.mock import MagicMock, patch + +from msprobe.base import Command, msprobeCommand +from msprobe.utils.constants import CmdConst, MsgConst +from msprobe.utils.exceptions import msprobeException + + +class TestCommandRegistration(unittest.TestCase): + def setUp(self): + Command._cmd_map.clear() + + def test_register_command(self): + parent_cmd = None + cmd_name = "test" + + @Command.register(parent_cmd, cmd_name) + class TestCommand(msprobeCommand): + pass + + self.assertIn(parent_cmd, Command._cmd_map) + self.assertIn(cmd_name, Command._cmd_map[parent_cmd]) + self.assertIs(Command._cmd_map[parent_cmd][cmd_name], TestCommand) + + def test_get_command(self): + parent1, parent2 = "parent1", "parent2" + cmd1, cmd2 = "cmd1", "cmd2" + + @Command.register(parent1, cmd1) + class Cmd1(msprobeCommand): + pass + + @Command.register(parent2, cmd2) + class Cmd2(msprobeCommand): + pass + + self.assertEqual(Command.get(parent1), {cmd1: Cmd1}) + self.assertEqual(Command.get(parent2), {cmd2: Cmd2}) + self.assertEqual(Command.get("invalid_parent"), {}) + + +class TestmsprobeCommand(unittest.TestCase): + class ConcreteCommand(msprobeCommand): + def add_arguments(self, parse): + pass + + def setUp(self): + self.cmd = self.ConcreteCommand() + self.cmd.subcommand_level = 0 + + @patch("msprobe.base.cmd.argv", ["script", "arg1", "arg2"]) + def test_input_module_valid(self): + self.cmd.subcommand_level = 1 + self.assertEqual(self.cmd.input_module, "arg1") + + @patch("msprobe.base.cmd.argv", 
["script"]) + def test_input_module_insufficient_args(self): + self.cmd.subcommand_level = 1 + self.assertIsNone(self.cmd.input_module) + + def test_input_module_invalid_level(self): + self.cmd.subcommand_level = "invalid" + with self.assertRaises(msprobeException) as cm: + _ = self.cmd.input_module + self.assertEqual(str(cm.exception), f"{MsgConst.INVALID_ARGU} Subcommand level must be a positive integer.") + + @patch("msprobe.base.Command.get") + def test_build_parser_with_subcommands(self, mock_get): + class MockSubCommand: + @classmethod + def add_arguments(cls, parser): + pass + + mock_get.side_effect = [{"subcmd": MockSubCommand}, {}] + parent_parser = MagicMock() + fake_subparser = MagicMock() + subparsers = MagicMock() + parent_parser.add_subparsers.return_value = subparsers + subparsers.add_parser.return_value = fake_subparser + self.cmd.subcommand_level = 0 + self.cmd.build_parser(parent_parser, MagicMock()) + parent_parser.add_subparsers.assert_called_once_with(dest="L1command") + subparsers.add_parser.assert_called_once_with( + name="subcmd", help=CmdConst.HELP_TOOL_MAP.get("subcmd"), formatter_class=RawTextHelpFormatter + ) diff --git a/accuracy_tools/test/UT/base_ut/test_config.py b/accuracy_tools/test/UT/base_ut/test_config.py new file mode 100644 index 00000000000..b539ca8f695 --- /dev/null +++ b/accuracy_tools/test/UT/base_ut/test_config.py @@ -0,0 +1,135 @@ +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.base import BaseConfig, Dict2Class +from msprobe.utils.constants import CfgConst, MsgConst +from msprobe.utils.exceptions import MsprobeException + + +class ConcreteConfig(BaseConfig): + def check_config(self): + pass + + +class TestBaseConfig(unittest.TestCase): + def setUp(self): + self.mock_config = { + CfgConst.TASK: "test_task", + "test_task": {"key": "value"}, + CfgConst.FRAMEWORK: "test_framework", + CfgConst.STEP: [], + CfgConst.RANK: [], + CfgConst.LEVEL: [CfgConst.LEVEL_API], + CfgConst.LOG_LEVEL: "info", + 
CfgConst.SEED: None, + } + self.config_path = "dummy_path.json" + + @patch("msprobe.base.config.load_json") + def test_initialization(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path, task="test_task", step=[], level=[]) + self.assertEqual(config.config_path, self.config_path) + self.assertEqual(config.config, self.mock_config) + self.assertEqual(config.task, "test_task") + self.assertEqual(config.step, []) + self.assertEqual(config.level, []) + mock_load_json.assert_called_once_with(self.config_path) + + @patch("msprobe.base.config.load_json") + def test_common_check_calls(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path) + + with patch.multiple( + "msprobe.base.config", + valid_task=MagicMock(return_value="test_task"), + valid_framework=MagicMock(return_value="valid_framework"), + valid_step_or_rank=MagicMock(side_effect=lambda x: x), + valid_level=MagicMock(return_value=["valid_level"]), + valid_log_level=MagicMock(return_value="valid_log_level"), + valid_seed=MagicMock(return_value=42), + ) as mocks: + config._common_check() + + self.assertEqual(config.config[CfgConst.TASK], "test_task") + self.assertEqual(config.config[CfgConst.FRAMEWORK], "valid_framework") + self.assertEqual(config.config[CfgConst.STEP], []) + self.assertEqual(config.config[CfgConst.RANK], []) + self.assertEqual(config.config[CfgConst.LEVEL], ["valid_level"]) + self.assertEqual(config.config[CfgConst.LOG_LEVEL], "valid_log_level") + self.assertEqual(config.config[CfgConst.SEED], 42) + + @patch("msprobe.base.config.load_json") + def test_get_task_dict_success(self, mock_load_json): + mock_load_json.return_value = {CfgConst.TASK: "existing_task", "existing_task": {"key": "value"}} + config = ConcreteConfig(self.config_path) + task_dict = config._get_task_dict() + self.assertEqual(task_dict, {"key": "value"}) + + @patch("msprobe.base.config.load_json") + def 
test_get_task_dict_raises_exception(self, mock_load_json): + mock_load_json.return_value = {CfgConst.TASK: "non_existing_task"} + config = ConcreteConfig(self.config_path) + with self.assertRaises(MsprobeException) as context: + config._get_task_dict() + self.assertIn(f'Missing dictionary for key "non_existing_task".', context.exception.error_msg) + + @patch("msprobe.base.config.load_json") + def test_update_config(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path) + test_dict = {} + mock_check = MagicMock(return_value="checked_value") + config._update_config(test_dict, "test_key", mock_check, "test_value") + mock_check.assert_called_once_with("test_value") + self.assertEqual(test_dict["test_key"], "checked_value") + + @patch("msprobe.base.config.load_json") + def test_check_config_wrapper(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path) + with patch.object(config, "_common_check") as mock_common_check, patch.object( + config, "check_config" + ) as mock_check_config: + config.check_config() + mock_common_check.assert_called_once() + mock_check_config.assert_called_once() + self.assertEqual(config.task_config, {"key": "value"}) + + +class TestDict2Class(unittest.TestCase): + def test_basic_conversion(self): + data = {"name": "test", "value": 10} + obj = Dict2Class(data) + self.assertEqual(obj.name, "test") + self.assertEqual(obj.value, 10) + + def test_nested_dict_conversion(self): + data = {"nested": {"key": "value"}} + obj = Dict2Class(data) + self.assertIsInstance(obj.nested, Dict2Class) + self.assertEqual(obj.nested.key, "value") + + def test_service_key_processing(self): + data = {CfgConst.TASK: "special", "special": {"input": [[224, 224], "path/to/input"], "param": 5}} + obj = Dict2Class(data) + self.assertEqual(obj.input_shape, [224, 224]) + self.assertEqual(obj.input_path, "path/to/input") + self.assertEqual(obj.param, 5) + + 
def test_max_recursion_depth(self): + data = {} + current = data + for _ in range(MsgConst.MAX_RECURSION_DEPTH + 1): + current["nested"] = {} + current = current["nested"] + with self.assertRaises(MsprobeException) as context: + Dict2Class(data) + self.assertIn(f"Maximum recursion depth of {MsgConst.MAX_RECURSION_DEPTH}", str(context.exception)) + + def test_missing_attribute(self): + obj = Dict2Class({"existing": 1}) + with self.assertRaises(MsprobeException) as context: + _ = obj.non_existing + self.assertIn("has no attribute non_existing", str(context.exception)) diff --git a/accuracy_tools/test/UT/common_ut/test_ascend.py b/accuracy_tools/test/UT/common_ut/test_ascend.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/accuracy_tools/test/UT/common_ut/test_cli.py b/accuracy_tools/test/UT/common_ut/test_cli.py new file mode 100644 index 00000000000..5e71221e58f --- /dev/null +++ b/accuracy_tools/test/UT/common_ut/test_cli.py @@ -0,0 +1,52 @@ +import unittest +from unittest.mock import ANY, MagicMock, PropertyMock, call, patch + +from msprobe.common.cli import MainCommand, MsprobeException + + +class TestMainCommand(unittest.TestCase): + def setUp(self): + self.main_cmd = MainCommand() + self.mock_second_commands = {"cmd1": MagicMock(), "cmd2": MagicMock()} + self.main_cmd.second_commands = self.mock_second_commands + + @patch("msprobe.common.cli.ArgumentParser") + def test_init(self, mock_argparse): + main_cmd = MainCommand() + mock_argparse.assert_called_once_with(prog="msprobe", description=ANY, formatter_class=ANY) + self.assertEqual(main_cmd.subcommand_level, 1) + self.assertIsNotNone(main_cmd.parser) + self.assertIsNotNone(main_cmd.subparser) + + def test_register(self): + with patch.object(MainCommand, "input_module", new_callable=PropertyMock) as mock_input_module: + mock_input_module.return_value = "cmd1" + mock_subparser = MagicMock() + self.main_cmd.subparser = mock_subparser + self.main_cmd.register() + expected_calls = 
[call(name="cmd1", help=None, formatter_class=self.main_cmd.formatter_class)] + mock_subparser.add_parser.assert_has_calls(expected_calls, any_order=True) + self.mock_second_commands["cmd1"].add_arguments.assert_called_once() + self.assertEqual(self.main_cmd.subcommand_level, 2) + + def test_parse(self): + mock_args = MagicMock() + self.main_cmd.parser.parse_args = MagicMock(return_value=mock_args) + result = self.main_cmd.parse() + self.assertEqual(result, mock_args) + self.main_cmd.parser.parse_args.assert_called_once() + + @patch("msprobe.common.cli.cann.get_atb_probe_so_path") + def test_set_env_failure(self, mock_get_so): + mock_get_so.return_value = None + with self.assertRaises(MsprobeException) as context: + self.main_cmd.set_env("invalid_framework") + + @patch("msprobe.common.cli.Service") + @patch("msprobe.common.cli.argv", ["msprobe", "invalid_service"]) + def test_execute_invalid_service(self, mock_service): + mock_service.get.return_value = False + args = MagicMock() + with self.assertRaises(MsprobeException) as context: + self.main_cmd.execute(args) + self.assertIn(" service is not registered", str(context.exception)) diff --git a/accuracy_tools/test/UT/common_ut/test_validation.py b/accuracy_tools/test/UT/common_ut/test_validation.py new file mode 100644 index 00000000000..94a9f229910 --- /dev/null +++ b/accuracy_tools/test/UT/common_ut/test_validation.py @@ -0,0 +1,254 @@ +import unittest +from argparse import Namespace +from unittest.mock import MagicMock, patch + +from msprobe.common.validation import ( + CheckConfigPath, + CheckExec, + CheckFramework, + SafePath, + check_int_border, + parse_hyphen, + valid_config_path, + valid_exec, + valid_framework, + valid_level, + valid_log_level, + valid_seed, + valid_step_or_rank, + valid_task, +) +from msprobe.utils.exceptions import MsprobeException + + +class TestValidationFunctions(unittest.TestCase): + def setUp(self): + self.mock_cfgconst = MagicMock() + self.mock_cfgconst.ALL_TASK = ["train", "eval", 
"predict"] + self.mock_cfgconst.ALL_FRAMEWORK = ["tf", "pytorch"] + self.mock_cfgconst.ALL_LEVEL = ["info", "debug", "warning"] + self.patcher = patch.dict( + "sys.modules", + { + "msprobe.utils.constants.CfgConst": self.mock_cfgconst, + "msprobe.utils.constants.PathConst": MagicMock( + SUFFIX_SH=".sh", + SUFFIX_PY=".py", + SUFFIX_OFFLINE_MODEL=(".onnx", ".pb"), + SUFFIX_ONLINE_SCRIPT=(".sh", ".py"), + SUFFIX_JSON=".json", + DIR="dir", + FILE="file", + ), + }, + ) + self.patcher.start() + + def tearDown(self): + self.patcher.stop() + + def test_valid_task_valid(self): + self.assertEqual(valid_task("tensor"), "tensor") + + def test_valid_task_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_task("invalid_task") + self.assertIn("must be one of ", str(cm.exception)) + + def test_valid_task_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_task(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_exec_none(self): + self.assertEqual(valid_exec([]), []) + + def test_valid_exec_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_exec(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + @patch("msprobe.common.validation.is_dir") + @patch("msprobe.common.validation.msprobePath") + def test_valid_exec_directory(self, mock_msprobepath, mock_is_dir): + mock_is_dir.return_value = True + values = "/valid/directory" + result = valid_exec(values) + self.assertEqual(result, [values]) + mock_msprobepath.assert_called_once() + + def test_valid_exec_bash_valid(self): + values = "bash script.sh" + self.assertEqual(valid_exec(values), ["bash", "script.sh"]) + + def test_valid_exec_bash_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_exec("bash invalid_script.py") + self.assertIn("[ERROR] Parsing failed.", str(cm.exception)) + + def test_valid_exec_python_invalid(self): + with self.assertRaises(MsprobeException) as cm: + 
valid_exec("python invalid_script.sh") + self.assertIn("[ERROR] Parsing failed.", str(cm.exception)) + + @patch("msprobe.common.validation.is_file") + @patch("msprobe.common.validation.msprobePath") + def test_valid_exec_model_file(self, mock_msprobepath, mock_is_file): + mock_is_file.return_value = True + values = "model.onnx" + self.assertEqual(valid_exec(values), [values]) + mock_msprobepath.assert_called_once() + + @patch("msprobe.common.validation.is_file") + @patch("msprobe.common.validation.msprobePath") + def test_invalid_exec_model_file(self, mock_msprobepath, mock_is_file): + mock_is_file.return_value = True + values = "model.tfv" + with self.assertRaises(MsprobeException) as cm: + valid_exec(values) + self.assertIn("('.pb', '.onnx', '.om', '.prototxt', '.py', '.sh').", str(cm.exception)) + + @patch("msprobe.common.validation.is_dir") + @patch("msprobe.common.validation.is_file") + @patch("msprobe.common.validation.SafePath") + def test_check_exec_action(self, mock_msprobepath, mock_is_file, mock_is_dir): + mock_is_dir.return_value = False + mock_is_file.return_value = True + action = CheckExec(option_strings=["-e", "--exec"], dest="exec") + mock_namespace = Namespace() + test_values = "valid_script.sh" + with patch.object(SafePath, "check") as mock_check: + mock_check.return_value = test_values[0] + action(None, mock_namespace, test_values) + self.assertEqual(mock_namespace.exec, [test_values]) + mock_is_dir.assert_called_once_with(test_values) + mock_is_file.assert_called_once_with(test_values) + + @patch("msprobe.common.validation.SafePath") + def test_valid_config_path_valid(self, mock_msprobepath): + mock_msprobepath.return_value.check.return_value = "valid.json" + result = valid_config_path("config.json") + self.assertEqual(result, "valid.json") + + def test_valid_config_path(self): + self.option_strings = ["-c", "--config"] + self.dest = "config_path" + self.action = CheckConfigPath(option_strings=self.option_strings, dest=self.dest) + test_value = 
"/valid/path/config.json" + expected_result = "/verified/path/config.json" + mock_namespace = Namespace() + + with patch("msprobe.common.validation.valid_config_path") as mock_validator: + mock_validator.return_value = expected_result + self.action(parser=MagicMock(), namespace=mock_namespace, values=test_value) + mock_validator.assert_called_once_with(test_value) + self.assertEqual(getattr(mock_namespace, self.dest), expected_result) + + def test_valid_framework_valid(self): + self.assertEqual(valid_framework("mindie_llm"), "mindie_llm") + + def test_valid_framework_invalid(self): + self.assertEqual(valid_framework(""), "") + + def test_valid_framework_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_framework(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_framework_more_element_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_framework("invalid_fw") + self.assertIn('[ERROR] invalid argument. 
"framework" must be one of', str(cm.exception)) + + def test_check_framework(self): + self.option_strings = ["-f", "--framework"] + self.dest = "framework" + self.action = CheckFramework(option_strings=self.option_strings, dest=self.dest) + test_value = "mindie_llm" + mock_namespace = Namespace() + with patch("msprobe.common.validation.valid_framework") as mock_validator: + mock_validator.return_value = test_value + self.action(parser=MagicMock(), namespace=mock_namespace, values=test_value) + mock_validator.assert_called_once_with(test_value) + self.assertEqual(getattr(mock_namespace, self.dest), test_value) + + def test_check_int_border_valid(self): + check_int_border(0, 500000, 1000000) + + def test_valid_check_int_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + check_int_border([0.35]) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_check_int_border_invalid(self): + with self.assertRaises(MsprobeException) as cm: + check_int_border(-1) + self.assertIn("The integer range is limited to [0, 1000000.0], currently: -1.", str(cm.exception)) + with self.assertRaises(MsprobeException): + check_int_border(1000001) + + def test_parse_hyphen_valid(self): + self.assertEqual(parse_hyphen("100-200"), list(range(100, 201))) + self.assertEqual(parse_hyphen("100-200-2"), list(range(100, 201, 2))) + + def test_parse_hyphen_invalid(self): + with self.assertRaises(MsprobeException): + parse_hyphen("100-200-300-400") + with self.assertRaises(MsprobeException): + parse_hyphen("200-100") + + def test_valid_step_or_rank(self): + self.assertEqual(valid_step_or_rank([10, "20-22", "30-35-2"]), [10, 20, 21, 22, 30, 32, 34]) + + def test_valid_step_or_rank_none(self): + self.assertEqual(valid_step_or_rank([]), []) + + def test_valid_step_or_rank_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_step_or_rank(123) + + def test_valid_step_or_rank_invalid(self): + with self.assertRaises(MsprobeException) as cm: + 
valid_step_or_rank([0.35]) + + def test_valid_level_valid_none(self): + self.assertEqual(valid_level(""), "") + + def test_valid_level_invalid_type(self): + with self.assertRaises(MsprobeException) as cm: + valid_level(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_level_valid(self): + self.assertEqual(valid_level(["kernel", "layer"]), ["kernel", "layer"]) + + def test_valid_log_level_valid_none(self): + self.assertEqual(valid_log_level(""), "") + + def test_valid_log_level_invalid_type(self): + with self.assertRaises(MsprobeException) as cm: + valid_log_level(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_level_invalid(self): + with self.assertRaises(MsprobeException): + valid_level(["invalid_level"]) + + def test_valid_log_level_valid(self): + self.assertEqual(valid_log_level("info"), "info") + + def test_valid_log_level_invalid(self): + with self.assertRaises(MsprobeException): + valid_log_level("invalid") + + def test_valid_seed_valid_none(self): + self.assertEqual(valid_seed(""), "") + + def test_valid_seed_valid(self): + self.assertEqual(valid_seed(42), 42) + + def test_valid_seed_invalid(self): + with self.assertRaises(MsprobeException): + valid_seed("not_an_int") + with self.assertRaises(MsprobeException): + valid_seed(-1) diff --git a/accuracy_tools/test/UT/core_ut/probe_ut/base_ut/test_dump_actuator.py b/accuracy_tools/test/UT/core_ut/probe_ut/base_ut/test_dump_actuator.py new file mode 100644 index 00000000000..c2a1c21f4b3 --- /dev/null +++ b/accuracy_tools/test/UT/core_ut/probe_ut/base_ut/test_dump_actuator.py @@ -0,0 +1,133 @@ +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np + +from msprobe.core.probe.base import OfflineModelActuator +from msprobe.utils.exceptions import MsitException + + +class TestOfflineModelActuator(unittest.TestCase): + def setUp(self): + self.mock_logger = MagicMock() + self.mock_dependent = MagicMock() + 
self.mock_DirPool = MagicMock() + self.mock_save_npy = MagicMock() + self.mock_load_npy = MagicMock() + self.mock_load_bin = MagicMock() + + self.patcher1 = patch("msit.module.probe.base.dump_actuator.logger", self.mock_logger) + self.patcher2 = patch("msit.module.probe.base.dump_actuator.dependent", self.mock_dependent) + self.patcher3 = patch("msit.module.probe.base.dump_actuator.DirPool", self.mock_DirPool) + self.patcher4 = patch("msit.module.probe.base.dump_actuator.save_npy", self.mock_save_npy) + self.patcher5 = patch("msit.module.probe.base.dump_actuator.load_npy", self.mock_load_npy) + self.patcher6 = patch("msit.module.probe.base.dump_actuator.load_bin_data", self.mock_load_bin) + + self.patcher1.start() + self.patcher2.start() + self.patcher3.start() + self.patcher4.start() + self.patcher5.start() + self.patcher6.start() + + def tearDown(self): + self.patcher1.stop() + self.patcher2.stop() + self.patcher3.stop() + self.patcher4.stop() + self.patcher5.stop() + self.patcher6.stop() + + def test_is_dynamic_shape(self): + self.assertFalse(OfflineModelActuator._is_dynamic_shape([1, 3, 224, 224])) + self.assertTrue(OfflineModelActuator._is_dynamic_shape([None, 3, 224, 224])) + self.assertTrue(OfflineModelActuator._is_dynamic_shape(["batch", 3, 224, 224])) + + def test_process_tensor_shape_dynamic_valid(self): + actuator = OfflineModelActuator( + model_path="model.onnx", input_shape={"input1": [1, 224, 224, 3]}, input_path="" + ) + result = actuator.process_tensor_shape("input1", "tensor(float32)", [None, 224, 224, 3]) + expected = [{"name": "input1", "shape": [1, 224, 224, 3], "type": "tensor(float32)"}] + self.assertEqual(result, expected) + self.mock_logger.info.assert_called_with("The dynamic shape of input1 has been fixed to [1, 224, 224, 3].") + + def test_process_tensor_shape_dynamic_missing_input(self): + actuator = OfflineModelActuator(model_path="model.onnx", input_shape={}, input_path="") + with self.assertRaises(MsitException) as context: + 
actuator.process_tensor_shape("input1", "tensor(float32)", [None, 224, 224, 3]) + self.assertIn("dynamic shape", str(context.exception)) + + def test_check_input_shape_mismatch(self): + with self.assertRaises(MsitException) as context: + OfflineModelActuator._check_input_shape("input1", [1, 3, 224, 224], [1, 4, 224, 224]) + self.assertIn("does not match", str(context.exception)) + + @patch("os.path.exists") + def test_get_inputs_data_generate_random(self, mock_exists): + mock_exists.return_value = True + self.mock_DirPool.get_input_dir.return_value = "/mock/input" + self.mock_dependent.get_tensorflow.return_value = (None, None, None) + actuator = OfflineModelActuator(model_path="model.onnx", input_shape={}, input_path="") + inputs_info = [{"name": "input1", "shape": [1, 3, 224, 224], "type": "tensor(float16)"}] + with patch("numpy.random.random") as mock_random: + mock_random.return_value = np.zeros((1, 3, 224, 224), dtype=np.float32) + result = actuator.get_inputs_data(inputs_info) + self.mock_save_npy.assert_called_once() + self.assertIn("input1", result) + self.assertEqual(result["input1"].shape, (1, 3, 224, 224)) + + def test_get_inputs_data_load_existing(self): + self.mock_dependent.get_tensorflow.return_value = (None, None, None) + actuator = OfflineModelActuator(model_path="model.onnx", input_shape={}, input_path=["input1.npy"]) + inputs_info = [{"name": "input1", "shape": [1, 3, 224, 224], "type": "tensor(float16)"}] + self.mock_load_npy.return_value = np.zeros((1, 3, 224, 224), dtype=np.float32) + result = actuator.get_inputs_data(inputs_info) + self.mock_load_npy.assert_called_with("input1.npy") + self.assertEqual(result["input1"].shape, (1, 3, 224, 224)) + + def test_read_input_shape_mismatch(self): + self.mock_dependent.get_tensorflow.return_value = (None, None, None) + actuator = OfflineModelActuator(model_path="model.onnx", input_shape={}, input_path=["input1.bin"]) + inputs_info = [{"name": "input1", "shape": [1, 3, 224, 224], "type": 
"tensor(float16)"}] + self.mock_load_bin.return_value = np.zeros((2, 3, 224, 224)) + with self.assertRaises(MsitException) as context: + actuator.get_inputs_data(inputs_info) + self.assertIn("does not match", str(context.exception)) + + def test_type_conversion(self): + self.mock_dependent.get_tensorflow.return_value = (None, None, None) + dtype = OfflineModelActuator._tensor2numpy_for_type("tensor(int32)") + self.assertEqual(dtype, np.int32) + + def test_invalid_type_conversion(self): + self.mock_dependent.get_tensorflow.return_value = (None, None, None) + with self.assertRaises(MsitException) as context: + OfflineModelActuator._tensor2numpy_for_type("tensor(unknown)") + self.assertIn("invalid data type", str(context.exception)) + + def test_valid_static_shape(self): + OfflineModelActuator._check_input_shape("input1", model_shape=[1, 3, 224, 224], input_shape=[1, 3, 224, 224]) + + def test_missing_input_shape(self): + with self.assertRaises(MsitException) as ctx: + OfflineModelActuator._check_input_shape("input1", model_shape=[1, 3, 224, 224], input_shape=[]) + self.assertIn("Required argument missing", str(ctx.exception)) + + def test_dimension_mismatch(self): + with self.assertRaises(MsitException) as ctx: + OfflineModelActuator._check_input_shape("input1", model_shape=[1, 3, 224, 224], input_shape=[1, 3, 224]) + self.assertIn("Unequal lengths", str(ctx.exception)) + + def test_dynamic_dimension_skip(self): + OfflineModelActuator._check_input_shape("input1", model_shape=[None, 3, 224, 224], input_shape=[2, 3, 224, 224]) + OfflineModelActuator._check_input_shape( + "input1", model_shape=["batch", 3, 224, 224], input_shape=[4, 3, 224, 224] + ) + + def test_static_shape_processing(self): + actuator = OfflineModelActuator(model_path="model.onnx", input_shape={}, input_path="") + result = actuator.process_tensor_shape( + tensor_name="input1", tensor_type="tensor(float16)", tensor_shape=[1, 3, 224, 224] + ) + self.assertEqual(result, [{"name": "input1", "shape": [1, 
import unittest
from unittest.mock import patch

# Path fixed: BaseDumper is defined in msprobe/core/base/dump_dumper.py in this
# patch's tree; the original "msprobe.core.probe.base" carries a spurious
# ".probe" segment. NOTE(review): confirm the package re-export layout.
from msprobe.core.base.dump_dumper import BaseDumper


class TestBaseDumper(unittest.TestCase):
    """Unit tests for BaseDumper hook bookkeeping (handler list and release)."""

    def setUp(self):
        # BaseDumper is abstract; provide the one required hook to instantiate.
        class ConcreteDumper(BaseDumper):
            def register_hook(self):
                pass

        self.dumper = ConcreteDumper()

    def test_init_handler_is_empty_list(self):
        self.assertEqual(self.dumper.handler, [])

    # Patch target fixed: the original patched
    # "msit.module.probe.base.dump_dumper.release", a stale pre-rename package
    # that does not exist in this tree; `release` is looked up in the module
    # that defines BaseDumper.
    @patch("msprobe.core.base.dump_dumper.release")
    def test_release_hook_with_multiple_handlers(self, mock_release):
        test_handlers = [0xDEADBEEF, 0xCAFEBABE]
        self.dumper.handler = test_handlers
        self.dumper.release_hook()
        # One release() call per registered handler, each with the handle value.
        self.assertEqual(mock_release.call_count, len(test_handlers))
        mock_release.assert_any_call(test_handlers[0])
        mock_release.assert_any_call(test_handlers[1])

    @patch("msprobe.core.base.dump_dumper.release")
    def test_release_hook_with_empty_handler(self, mock_release):
        self.dumper.handler = []
        self.dumper.release_hook()
        mock_release.assert_not_called()

    def test_abstract_method_enforcement(self):
        # A subclass that omits register_hook must not be instantiable.
        with self.assertRaises(TypeError):

            class InvalidDumper(BaseDumper):
                pass

            InvalidDumper()
import sys
import unittest
from unittest.mock import MagicMock, call, patch

# Import path fixed: the original used a leftover local-workspace path
# "refactor.msit.inference_tools.msit.module.probe.components.dumper_writer";
# the file shipped by this patch is msprobe/core/components/dumper_writer.py.
from msprobe.core.components.dumper_writer import _WITHOUT_CALL_STACK, WriterDump

from msprobe.common.dirs import DirPool
from msprobe.utils.constants import CfgConst, DumpConst
from msprobe.utils.log import logger
from msprobe.utils.toolkits import get_valid_name

# Stale "msit.module.probe.base.dump_writer" patch root rewritten to the new
# package. NOTE(review): helpers (stack/save_json/save_npy/MsitPath) are
# patched in the *base* dump_writer module, matching the original's choice of
# module; confirm that is where WriterDump resolves these names.
_DW = "msprobe.core.base.dump_writer"


class TestWriterDump(unittest.TestCase):
    """Unit tests for WriterDump caching, stack capture, and flush behavior."""

    def setUp(self):
        # Concrete subclass: summ_dump_data is the abstract summary hook.
        class TestWriterDumpConcrete(WriterDump):
            def summ_dump_data(self):
                return "test_result"

        self.writer_cls = TestWriterDumpConcrete
        self.task = CfgConst.TASK_TENSOR
        self.mock_get_rank_dir = patch.object(DirPool, "get_rank_dir", return_value="/fake/rank_dir").start()
        self.mock_get_tensor_dir = patch.object(DirPool, "get_tensor_dir", return_value="/fake/tensor_dir").start()
        self.mock_get_model_dir = patch.object(DirPool, "get_model_dir", return_value="/fake/model_dir").start()
        self.mock_make_tensor_dir = patch.object(DirPool, "make_tensor_dir").start()
        # Paths fixed from stale "msit.common.*" to the shipped msprobe modules.
        self.mock_dirpool = patch("msprobe.common.dirs.DirPool").start()
        self.mock_datastat = patch("msprobe.common.stat.DataStat").start()
        self.addCleanup(patch.stopall)

    def test_init(self):
        writer = self.writer_cls(self.task)
        self.assertEqual(writer.task, self.task)
        self.assertEqual(writer.max_cache_size, 1_048_576)
        self.assertEqual(writer.cache_dump_json_size, 0)
        self.assertIn(CfgConst.TASK, writer.cache_dump_json)

    @patch(_DW + ".stack")
    def test_call_stack(self, mock_stack):
        # Frames from the tool's own code must be filtered out of the stack.
        mock_stack.return_value = [
            (None, "msit/core/module.py", 10, "func1", ["code1"], None),
            (None, "user_script.py", 20, "func2", ["code2"], None),
        ]
        writer = self.writer_cls(self.task)
        stack_info = writer._call_stack("test_node")
        self.assertIn("test_node", stack_info)
        self.assertNotIn("msit/core", stack_info["test_node"][0])

    @patch(_DW + ".stack", side_effect=Exception("mock error"))
    @patch.object(logger, "warning")
    def test_call_stack_exception_handling(self, mock_warn, mock_stack):
        # Stack retrieval failure degrades to a warning plus a placeholder.
        writer = self.writer_cls(self.task)
        result = writer._call_stack("test_node")
        mock_warn.assert_called_once_with("The call stack of test_node failed to retrieve, mock error.")
        self.assertEqual(result, {"test_node": [_WITHOUT_CALL_STACK]})

    def test_remove_colon_with_colon(self):
        writer = self.writer_cls(self.task)
        self.assertEqual(writer._remove_colon("node:output"), "node")

    def test_remove_colon_without_colon(self):
        writer = self.writer_cls(self.task)
        self.assertEqual(writer._remove_colon("node_output"), "node_output")

    @patch.object(WriterDump, "_save_stack_json")
    def test_update_stack_behavior(self, mock_save):
        writer = self.writer_cls(self.task)
        mock_stack_data = {"node1": ["stack_line1"]}
        with patch.object(writer, "_call_stack", return_value=mock_stack_data):
            # Below the cache threshold: data accumulates, nothing is flushed.
            writer.update_stack("node1")
            self.assertEqual(writer.cache_stack_json, mock_stack_data)
            self.assertEqual(writer.cache_stack_json_size, sys.getsizeof(mock_stack_data))
            mock_save.assert_not_called()
            # Crossing the threshold triggers a flush and resets the size counter.
            writer.cache_stack_json_size = writer.max_cache_size - 1
            writer.update_stack("node2")
            mock_save.assert_called_once()
            self.assertEqual(writer.cache_stack_json_size, 0)

    @patch.object(WriterDump, "_save_dump_json")
    def test_update_stat_flushes_cache(self, mock_save):
        writer = self.writer_cls(self.task)
        writer.max_cache_size = 100
        self.mock_datastat.collect_stats_for_numpy.return_value = {"mean": 0.5}

        # Force the cached-json size over the (shrunk) limit to provoke a flush.
        with patch("sys.getsizeof", return_value=150):
            writer.cache_dump_json[DumpConst.DATA].setdefault(get_valid_name("node1"), {})
            writer.update_stat("node1", "input", "arg1", MagicMock())
            mock_save.assert_called_once()

    @patch.object(WriterDump, "update_stat")
    @patch.object(WriterDump, "_save_tensor_data")
    def test_through_inputs(self, mock_save_tensor, mock_update_stat):
        writer = self.writer_cls(self.task)
        inputs = [MagicMock(name="input1"), "input2"]
        input_map = {"input1": "data1", "input2": "data2"}
        writer.through_inputs(inputs, "node1", input_map)
        self.assertEqual(mock_update_stat.call_count, 2)
        self.assertEqual(mock_save_tensor.call_count, 2)

    @patch.object(WriterDump, "update_stat")
    @patch.object(WriterDump, "_save_tensor_data")
    @patch.object(logger, "info")
    def test_through_outputs_behavior(self, mock_info, mock_save_tensor, mock_update_stat):
        writer = self.writer_cls(self.task)
        mock_output_obj = MagicMock(name="output_obj")
        mock_output_obj.name = "output2"
        outputs = ["output1", mock_output_obj]
        output_map = {"output1": "data1", "output2": "data2"}
        writer.net_output_nodes = ["output2"]
        writer.through_outputs(outputs, "test_node", output_map)
        calls = [
            call("test_node", DumpConst.OUTPUT_ARGS, "output1", "data1"),
            call("test_node", DumpConst.OUTPUT_ARGS, "output2", "data2"),
        ]
        mock_update_stat.assert_has_calls(calls)
        # The designated net-output node is announced once with its index.
        mock_info.assert_called_once_with("net_output node index is: 0, node name: output2.")
        # In tensor task mode, the raw tensors are persisted as well.
        writer.task = CfgConst.TASK_TENSOR
        writer.through_outputs(outputs, "test_node", output_map)
        mock_save_tensor.assert_has_calls(
            [call("test_node", DumpConst.OUTPUT, 0, "data1"), call("test_node", DumpConst.OUTPUT, 1, "data2")]
        )

    @patch.object(WriterDump, "_save_dump_json")
    @patch.object(WriterDump, "_save_stack_json")
    def test_flush_remaining_cache(self, mock_save_stack, mock_save_dump):
        # Any residual cached bytes must be written out on final flush.
        writer = self.writer_cls(self.task)
        writer.cache_dump_json_size = 500
        writer.cache_stack_json_size = 500
        writer._flush_remaining_cache()
        mock_save_dump.assert_called_once()
        mock_save_stack.assert_called_once()

    @patch.object(DirPool, "get_rank_dir", return_value="/mock/rank")
    @patch(_DW + ".save_json")
    def test_save_stack_json(self, mock_save, mock_dir):
        test_data = {"node1": ["stack_info"]}
        writer = self.writer_cls(self.task)
        writer.cache_stack_json = test_data
        writer._save_stack_json()
        mock_save.assert_called_once_with(test_data, "/mock/rank/stack.json", indent=4)

    @patch(_DW + ".save_json")
    def test_save_dump_json(self, mock_save):
        writer = self.writer_cls(self.task)
        writer.cache_dump_json = {"data": "test"}
        writer._save_dump_json()
        mock_save.assert_called_once_with({"data": "test"}, "/fake/rank_dir/dump.json", indent=4)
        self.mock_get_rank_dir.assert_called_once()

    @patch(_DW + ".save_npy")
    @patch(_DW + ".MsitPath.check")
    def test_save_tensor_data(self, mock_msitpath, mock_save_npy):
        writer = self.writer_cls(self.task)
        mock_msitpath.return_value = "/fake/tensor/path.npy"
        self.mock_dirpool.get_tensor_dir.return_value = "/fake/tensor_dir"
        writer._save_tensor_data("node1", "input", 0, "tensor_data")
        mock_save_npy.assert_called_once_with("tensor_data", "/fake/tensor/path.npy")

    @patch.object(WriterDump, "_flush_remaining_cache")
    @patch(_DW + ".save_json")
    def test_summ_dump_data_decorator(self, mock_save, mock_flush):
        # The summary hook's decorator flushes caches and records net outputs.
        writer = self.writer_cls(self.task)
        writer.net_output_nodes = ["output1"]
        self.mock_dirpool.get_model_dir.return_value = "/fake/model_dir"
        result = writer.summ_dump_data()
        self.assertEqual(result, "test_result")
        mock_flush.assert_called_once()
        mock_save.assert_called_once_with(["output1"], "/fake/model_dir/net_output_nodes.json")
import unittest
from unittest.mock import patch

from msprobe.base import BaseComponent
# Path fixed: the component ships at
# msprobe/core/components/dumper_offline_model.py; the original
# "msprobe.core.probe.components" inserted a spurious ".probe" segment.
from msprobe.core.components.dumper_offline_model import OfflineModelActuatorComp


class TestOfflineModelActuatorComp(unittest.TestCase):
    """Tests that OfflineModelActuatorComp wires its priority into BaseComponent."""

    def test_inheritance(self):
        self.assertTrue(issubclass(OfflineModelActuatorComp, BaseComponent))

    @patch.object(BaseComponent, "__init__", return_value=None)
    def test_init_default_priority(self, mock_base_init):
        # Default construction forwards priority 100 to the base class.
        OfflineModelActuatorComp()
        mock_base_init.assert_called_once_with(100)

    @patch.object(BaseComponent, "__init__", return_value=None)
    def test_init_custom_priority(self, mock_base_init):
        custom_priority = 200
        OfflineModelActuatorComp(priority=custom_priority)
        mock_base_init.assert_called_once_with(custom_priority)

    @patch.object(BaseComponent, "__init__", return_value=None)
    def test_instance_type(self, mock_base_init):
        instance = OfflineModelActuatorComp()
        self.assertIsInstance(instance, BaseComponent)
import unittest
from unittest.mock import MagicMock, patch

# Path fixed: validate_params ships at
# msprobe/core/config_initiator/validate_params.py; the original import carried
# a spurious ".probe" segment and patch strings used the stale
# "msit.module.probe" root.
from msprobe.core.config_initiator.validate_params import (
    OfflineModelInput,
    valid_data_mode,
    valid_device,
    valid_dump_extra,
    valid_dump_ge_graph,
    valid_dump_graph_level,
    valid_dump_last_logits,
    valid_dump_path,
    valid_dump_time,
    valid_dump_weight,
    valid_fusion_switch_file,
    valid_input,
    valid_list,
    valid_onnx_fusion_switch,
    valid_op_id,
    valid_saved_model_signature,
    valid_saved_model_tag,
    valid_weight_path,
)
from msprobe.utils.exceptions import MsitException

# Single source of truth for the patched module path.
_VP = "msprobe.core.config_initiator.validate_params"


class TestValidators(unittest.TestCase):
    """Unit tests for the individual config parameter validators."""

    @patch(_VP + ".MsitPath")
    def test_valid_dump_path(self, mock_msit_path):
        mock_instance = MagicMock()
        mock_instance.check.return_value = "valid"
        mock_msit_path.return_value = mock_instance
        result = valid_dump_path("some/path")
        self.assertEqual(result, "valid")
        mock_msit_path.assert_called_once()

    def test_valid_list_with_none(self):
        value = ("", ["level1", "level2"])
        result = valid_list(value)
        self.assertEqual(result, {"level1": ""})

    def test_valid_list_with_dict(self):
        value = ({"level1": ["sdgf"]}, ["level1", "level2"])
        result = valid_list(value)
        self.assertEqual(result, {"level1": ["sdgf"]})

    def test_valid_list_invalid_dict_key(self):
        value = ({"bad": [1]}, ["allowed"])
        with self.assertRaises(MsitException):
            valid_list(value)

    def test_valid_list_invalid_list(self):
        value = ({"level1": 12}, ["level1", "level2"])
        with self.assertRaises(MsitException):
            valid_list(value)

    def test_invalid_list(self):
        value = (12, ["level1", "level2"])
        with self.assertRaises(MsitException):
            valid_list(value)

    def test_valid_data_mode_none(self):
        result = valid_data_mode([])
        self.assertEqual(result, [])

    @patch(_VP + ".DumpConst.ALL_DATA_MODE", ["mode1"])
    def test_valid_data_mode_valid(self):
        result = valid_data_mode(["mode1"])
        self.assertEqual(result, ["mode1"])

    @patch(_VP + ".DumpConst.ALL_DATA_MODE", ["mode1"])
    def test_valid_data_mode_invalid(self):
        with self.assertRaises(MsitException):
            valid_data_mode(["invalid"])

    @patch(_VP + ".DumpConst.ALL_DATA_MODE", ["mode1"])
    def test_invalid_data_mode(self):
        with self.assertRaises(MsitException):
            valid_data_mode(12)

    @patch(_VP + ".DumpConst.ALL_DATA_MODE", ["mode1", "mode2"])
    def test_invalid_data_mode_more_element(self):
        with self.assertRaises(MsitException):
            valid_data_mode(["mode1", "mode2"])

    # Name fixed: the original "valid_dump_extra_none" lacked the "test_"
    # prefix, so unittest discovery silently skipped it.
    @patch(_VP + ".DumpConst.ALL_DUMP_EXTRA", ["extra1"])
    def test_valid_dump_extra_none(self):
        self.assertIsNone(valid_dump_extra(None))
        result = valid_dump_extra([])
        self.assertEqual(result, [])

    @patch(_VP + ".DumpConst.ALL_DUMP_EXTRA", ["extra1"])
    def test_valid_dump_extra_valid(self):
        result = valid_dump_extra(["extra1"])
        self.assertEqual(result, ["extra1"])

    @patch(_VP + ".DumpConst.ALL_DUMP_EXTRA", ["extra1"])
    def test_valid_dump_extra_invalid(self):
        with self.assertRaises(MsitException):
            valid_dump_extra(123)
        with self.assertRaises(MsitException):
            valid_dump_extra(["bad"])

    @patch(_VP + ".DumpConst.ALL_DUMP_TIME", ["before", "after"])
    def test_valid_dump_time(self):
        result = valid_dump_time("")
        self.assertEqual(result, "")
        result = valid_dump_time("before")
        self.assertEqual(result, "before")

    @patch(_VP + ".DumpConst.ALL_DUMP_TIME", ["before", "after"])
    def test_valid_dump_time_invalid_element(self):
        with self.assertRaises(MsitException):
            valid_dump_time(["bad"])
        with self.assertRaises(MsitException):
            valid_dump_time("invalid")

    def test_valid_op_id_none(self):
        result = valid_op_id("")
        self.assertEqual(result, "")

    def test_valid_op_id(self):
        op_ids = [1, "3_1", "4_2_3"]
        result = valid_op_id(op_ids)
        self.assertEqual(result, op_ids)

    def test_valid_op_id_invalid_element_format(self):
        with self.assertRaises(MsitException):
            valid_op_id(12)
        with self.assertRaises(MsitException):
            valid_op_id([["invalid"]])

    def test_valid_dump_last_logits(self):
        self.assertIsNone(valid_dump_last_logits(None))
        self.assertTrue(valid_dump_last_logits(True))
        with self.assertRaises(MsitException):
            valid_dump_last_logits("true")

    def test_valid_dump_weight(self):
        self.assertIsNone(valid_dump_weight(None))
        self.assertTrue(valid_dump_weight(True))
        with self.assertRaises(MsitException):
            valid_dump_weight("true")

    def test_valid_dump_ge_graph(self):
        self.assertIsNone(valid_dump_ge_graph(None))
        with self.assertRaises(MsitException):
            valid_dump_ge_graph(123)
        with self.assertRaises(MsitException):
            valid_dump_ge_graph("8")
        self.assertEqual(valid_dump_ge_graph("2"), "2")

    def test_valid_dump_graph_level(self):
        self.assertIsNone(valid_dump_graph_level(None))
        with self.assertRaises(MsitException):
            valid_dump_graph_level(123)
        with self.assertRaises(MsitException):
            valid_dump_graph_level("8")
        self.assertEqual(valid_dump_graph_level("2"), "2")

    @patch(_VP + ".MsitPath")
    def test_valid_fusion_switch_file(self, mock_msit_path):
        self.assertIsNone(valid_fusion_switch_file(None))
        mock_instance = MagicMock()
        mock_instance.check.return_value = "valid"
        mock_msit_path.return_value = mock_instance
        result = valid_fusion_switch_file("some/path")
        self.assertEqual(result, "valid")
        mock_msit_path.assert_called_once()

    def test_valid_device(self):
        self.assertIsNone(valid_device(None))
        self.assertEqual(valid_device("cpu"), "cpu")
        with self.assertRaises(MsitException):
            valid_device("gpu")
        with self.assertRaises(MsitException):
            valid_device(123)

    @patch(_VP + ".MsitPath")
    def test_valid_weight_path(self, mock_path):
        self.assertIsNone(valid_weight_path(None))
        mock_path().check.return_value = "checked"
        result = valid_weight_path("file.caffemodel")
        self.assertEqual(result, "checked")

    def test_valid_onnx_fusion_switch(self):
        self.assertIsNone(valid_onnx_fusion_switch(None))
        self.assertTrue(valid_onnx_fusion_switch(True))
        with self.assertRaises(MsitException):
            valid_onnx_fusion_switch(123)

    def test_valid_saved_model_tag(self):
        self.assertIsNone(valid_saved_model_tag(None))
        with self.assertRaises(MsitException):
            valid_saved_model_tag(123)
        with self.assertRaises(MsitException):
            valid_saved_model_tag(["%qsc/"])
        self.assertEqual(valid_saved_model_tag(["qazx"]), ["qazx"])

    def test_valid_saved_model_signature(self):
        self.assertIsNone(valid_saved_model_signature(None))
        with self.assertRaises(MsitException):
            valid_saved_model_signature(["%qsc/"])
        # Fixed: this call sat after the raising call inside one assertRaises
        # block in the original, so it never executed. It needs its own block.
        with self.assertRaises(MsitException):
            valid_saved_model_signature(123)
        self.assertEqual(valid_saved_model_signature("wsx"), "wsx")
class TestValidInputAndOfflineModelInput(unittest.TestCase):
    """Unit tests for valid_input and the OfflineModelInput parser helpers.

    Patch targets fixed from the stale "msit.module.probe.config_initiator"
    root to the shipped module msprobe/core/config_initiator/validate_params.py.
    """

    # Patched-module path, kept local so each test reads unambiguously.
    _VP = "msprobe.core.config_initiator.validate_params"

    def test_valid_input_none(self):
        self.assertIsNone(valid_input(None))

    @patch(_VP + ".OfflineModelInput")
    def test_valid_input_calls_parse(self, mock_input_cls):
        # valid_input delegates the heavy lifting to OfflineModelInput.parse().
        mock_parser = MagicMock()
        mock_input_cls.return_value = mock_parser
        valid_input([{"name": "x"}])
        mock_parser.parse.assert_called_once()

    def test_check_form_not_list(self):
        with self.assertRaisesRegex(MsitException, "The input must be a list."):
            OfflineModelInput("invalid")

    def test_check_form_element_not_dict(self):
        with self.assertRaisesRegex(MsitException, "Each element in the input must be a dictionary."):
            OfflineModelInput([1, 2])

    def test_check_name_missing(self):
        with self.assertRaisesRegex(MsitException, "Each input must have a name."):
            OfflineModelInput([{}])._check_name({})

    def test_check_input_shape_invalid_type(self):
        with self.assertRaisesRegex(MsitException, "must be a list"):
            OfflineModelInput([{}])._check_input_shape({"shape": "not_list"}, "input1")

    def test_check_input_shape_element_not_int(self):
        with self.assertRaisesRegex(MsitException, "support only integers"):
            OfflineModelInput([{}])._check_input_shape({"shape": [1, "a"]}, "input1")

    @patch(_VP + ".MsitPath")
    def test_check_input_path_invalid_type(self, mock_path):
        with self.assertRaisesRegex(MsitException, "must be a string"):
            OfflineModelInput([{}])._check_input_path({"path": 123}, "input1")

    @patch(_VP + ".MsitPath")
    def test_check_input_path_invalid_suffix(self, mock_path):
        with self.assertRaisesRegex(MsitException, "can only accept .npy or .bin"):
            OfflineModelInput([{}])._check_input_path({"path": "file.txt"}, "input1")

    @patch(_VP + ".MsitPath")
    def test_check_input_path_valid(self, mock_path):
        mock_check = mock_path.return_value.check
        mock_check.return_value = True
        self.assertIsNone(OfflineModelInput([{}])._check_input_path({"path": "input.npy"}, "input1"))

    @patch(_VP + ".parse_hyphen")
    def test_parse_shape_range_for_str_hyphen(self, mock_parse):
        mock_parse.return_value = [1, 2]
        result = OfflineModelInput([{}])._parse_shape_range_for_str("1-2")
        self.assertEqual(result, [1, 2])

    def test_parse_shape_range_for_str_comma_valid(self):
        result = OfflineModelInput([{}])._parse_shape_range_for_str("2,3")
        self.assertEqual(result, [2, 3])

    def test_parse_shape_range_for_str_invalid_format(self):
        with self.assertRaisesRegex(MsitException, "can only contain hyphen"):
            OfflineModelInput([{}])._parse_shape_range_for_str("wrong")

    @patch(_VP + ".check_int_border")
    @patch(_VP + ".OfflineModelInput._parse_shape_range_for_str")
    def test_parse_dym_shape_range_mixed(self, mock_parse, mock_check):
        # Mixed string-range and plain-int entries are both accepted.
        mock_parse.return_value = [1, 2]
        input_obj = OfflineModelInput([{}])
        result = input_obj._parse_dym_shape_range(["1-2", 3], "input1")
        self.assertIsInstance(result, list)

    def test_parse_dym_shape_range_invalid_type(self):
        with self.assertRaisesRegex(MsitException, "must be a list"):
            OfflineModelInput([{}])._parse_dym_shape_range("not_list", "input1")

    def test_parse_dym_shape_range_element_type_error(self):
        with self.assertRaisesRegex(MsitException, "support only string and integers"):
            OfflineModelInput([{}])._parse_dym_shape_range([1.5], "input1")

    @patch(_VP + ".logger")
    @patch.object(OfflineModelInput, "_parse_dym_shape_range")
    def test_check_dym_shape_with_path(self, mock_parse_dym, mock_logger):
        # A dynamic shape overrides any static shape/path previously supplied.
        mock_parse_dym.return_value = [[1, 2], [3, 4]]
        input_obj = OfflineModelInput([{}])
        result = input_obj._check_dym_shape({"name": "x", "dym_shape": ["1-2"], "path": "input.npy"}, "x")
        self.assertEqual(result["path"], "")
        self.assertEqual(result["shape"], [])

    def test_draw_shape_and_path_static(self):
        input_obj = OfflineModelInput([{}])
        input_obj.is_need_expand_shape = False
        shapes, paths = input_obj._draw_shape_and_path([{"name": "x", "shape": [1, 2], "path": "x.npy"}])
        self.assertEqual(shapes["x"], [1, 2])
        self.assertEqual(paths, ["x.npy"])

    def test_draw_shape_and_path_dynamic_valid(self):
        input_obj = OfflineModelInput([{}])
        input_obj.is_need_expand_shape = True
        input_data = [{"name": "x", "dym_shape": [[1], [2]]}, {"name": "y", "dym_shape": [[3], [4]]}]
        shapes, paths = input_obj._draw_shape_and_path(input_data)
        self.assertEqual(len(shapes), 2)
        self.assertIsNone(paths)

    def test_draw_shape_and_path_dynamic_invalid(self):
        # Expanded dynamic shapes must be the same length across all inputs.
        input_obj = OfflineModelInput([{}])
        input_obj.is_need_expand_shape = True
        input_data = [{"name": "x", "dym_shape": [[1], [2], [3]]}, {"name": "y", "dym_shape": [[4], [5]]}]
        with self.assertRaisesRegex(MsitException, "same expanded dynamic shape length"):
            input_obj._draw_shape_and_path(input_data)

    @patch.object(OfflineModelInput, "_check_name", return_value="x")
    @patch.object(OfflineModelInput, "_check_input_shape")
    @patch.object(OfflineModelInput, "_check_input_path")
    @patch.object(OfflineModelInput, "_check_dym_shape", side_effect=lambda x, y: x)
    @patch.object(OfflineModelInput, "_draw_shape_and_path", return_value=({}, []))
    def test_parse_calls_all(self, mock_draw, mock_dym, mock_path, mock_shape, mock_name):
        input_obj = OfflineModelInput([{"name": "x"}])
        result = input_obj.parse()
        self.assertEqual(result, ({}, []))
import unittest
from unittest.mock import MagicMock, patch

import numpy as np

# Path fixed: the module ships at msprobe/core/dump/caffe_model.py; the
# original import carried a spurious ".probe" segment and patch strings used
# the stale "msit.module.probe" root.
from msprobe.core.dump.caffe_model import CaffeModelActuator, CaffeModelDataWriter
from msprobe.utils.exceptions import MsitException

_CM = "msprobe.core.dump.caffe_model"


class TestCaffeModelActuator(unittest.TestCase):
    """Unit tests for Caffe model loading, metadata extraction, and inference."""

    @patch(_CM + ".load_caffe_model")
    def test_load_model(self, mock_load_model):
        mock_model = MagicMock()
        mock_load_model.return_value = mock_model
        actuator = CaffeModelActuator(
            model_path="model.prototxt",
            input_shape=(1, 3, 224, 224),
            input_path="input.npy",
            weight_path="model.caffemodel",
        )
        actuator.load_model()
        mock_load_model.assert_called_once_with("model.prototxt", "model.caffemodel")
        self.assertEqual(actuator.model, mock_model)

    def test_missing_weight_raises_exception(self):
        # Caffe needs both prototxt and caffemodel; omitting the weight fails fast.
        with self.assertRaises(MsitException) as context:
            CaffeModelActuator(model_path="model.prototxt", input_shape=(1, 3, 224, 224), input_path="input.npy")
        self.assertIn("a weight file", str(context.exception))

    @patch(_CM + ".logger")
    def test_get_input_tensor_info(self, mock_logger):
        mock_model = MagicMock()
        mock_blob = MagicMock()
        mock_blob.data = np.zeros((1, 3, 224, 224), dtype=np.float32)
        mock_model.blobs = {"data": mock_blob}
        mock_model.inputs = ["data"]

        actuator = CaffeModelActuator(
            model_path="model.prototxt",
            input_shape=(1, 3, 224, 224),
            input_path="input.npy",
            weight_path="model.caffemodel",
        )
        actuator.model = mock_model
        result = actuator.get_input_tensor_info()
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["name"], "data")
        self.assertEqual(result[0]["shape"], (1, 3, 224, 224))
        self.assertEqual(result[0]["type"], "float32")

    def test_infer_success(self):
        actuator = CaffeModelActuator(
            model_path="model.prototxt",
            input_shape=(1, 3, 224, 224),
            input_path="input.npy",
            weight_path="model.caffemodel",
        )

        mock_model = MagicMock()
        mock_blob = MagicMock()
        mock_blob.data = np.zeros((1, 3, 224, 224))
        mock_model.blobs = {"data": mock_blob}
        mock_model.forward.return_value = {"prob": np.array([1.0])}
        actuator.model = mock_model

        input_data = {"data": np.zeros((1, 3, 224, 224))}
        result = actuator.infer(input_data)
        self.assertIn("prob", result)

    def test_infer_failure(self):
        # A runtime failure inside net.forward() is wrapped in MsitException.
        actuator = CaffeModelActuator(
            model_path="model.prototxt",
            input_shape=(1, 3, 224, 224),
            input_path="input.npy",
            weight_path="model.caffemodel",
        )
        mock_model = MagicMock()
        mock_blob = MagicMock()
        mock_blob.data = np.zeros((1, 3, 224, 224))
        mock_model.blobs = {"data": mock_blob}
        mock_model.forward.side_effect = RuntimeError("Failure")
        actuator.model = mock_model
        with self.assertRaises(MsitException):
            actuator.infer({"data": np.zeros((1, 3, 224, 224))})


class TestCaffeModelDataWriter(unittest.TestCase):
    """Unit tests for the Caffe dump-data writer."""

    def test_get_input_output_map(self):
        writer = CaffeModelDataWriter(task="mock_task", data_mode=["input", "output"])
        mock_net = MagicMock()
        mock_net.blobs = {"conv1": MagicMock(data=np.ones((1, 3, 224, 224)))}
        mock_net.params = {"conv1": [MagicMock(data=np.ones((64, 3, 3, 3))), MagicMock(data=np.ones((64,)))]}
        mock_net.bottom_names = {"conv1": []}
        input_map, output_map = writer.get_input_output_map(mock_net)
        # Layer params are exposed as synthetic <layer>_weight/<layer>_bias inputs.
        self.assertIn("conv1_weight", input_map)
        self.assertIn("conv1_bias", input_map)
        self.assertIn("conv1", output_map)

    # NOTE(review): the original patched DirPool/get_valid_name in the
    # *onnx_model* module from within this caffe test — almost certainly a
    # copy-paste slip; remapped to the caffe_model module. Confirm where
    # CaffeModelDataWriter.summ_dump_data resolves these names.
    @patch("msprobe.core.base.dump_writer.save_json")
    @patch(_CM + ".DirPool.get_model_dir", return_value="/mock/model/dir")
    @patch(_CM + ".get_valid_name")
    def test_summ_dump_data(self, mock_get_valid_name, mock_model_dir, mock_save_json):
        mock_get_valid_name.side_effect = lambda name: f"valid_{name}"
        writer = CaffeModelDataWriter(task="mock_task", data_mode=["input", "output"])
        mock_net = MagicMock()
        mock_net.outputs = ["fc"]
        mock_net.blobs = {"fc": MagicMock(data=np.ones((1, 1000)))}
        mock_net.top_names = {"fc": ["fc_out"]}
        mock_net.bottom_names = {"fc": ["fc_in"]}
        writer.caffe_net = mock_net
        writer.through_inputs = MagicMock()
        writer.through_outputs = MagicMock()
        input_map = {"fc_weight": np.ones((1000, 512)), "fc_bias": np.ones((1000,))}
        output_map = {"fc": np.ones((1, 1000))}
        writer.summ_dump_data(input_map, output_map)
        writer.through_inputs.assert_called()
        writer.through_outputs.assert_called()
import unittest
from unittest.mock import MagicMock, patch

import numpy as np

# Path fixed: the dump package is msprobe/core/dump (no ".probe" segment);
# patch strings below target msprobe/core/dump/onnx_model.py instead of the
# stale "msit.module.probe.dump.onnx_model".
from msprobe.core.dump import OnnxModelActuator, OnnxModelDataWriter
from msprobe.utils.exceptions import MsitException

_OM = "msprobe.core.dump.onnx_model"


class TestOnnxModelActuator(unittest.TestCase):
    """Unit tests for ONNX session loading, inference, and model re-export."""

    @patch(_OM + ".load_onnx_session")
    def test_infer_success(self, mock_load_session):
        mock_output_node = MagicMock()
        mock_output_node.name = "output1"
        mock_session = MagicMock()
        mock_session.get_outputs.return_value = [mock_output_node]
        mock_session.run.return_value = ["dummy_output"]
        mock_load_session.return_value = mock_session
        result = OnnxModelActuator.infer("dummy_path", {"input1": np.array([1])})
        # Output names are collected from the session and forwarded to run().
        mock_session.run.assert_called_once_with(["output1"], {"input1": np.array([1])})
        self.assertEqual(result, ["dummy_output"])

    @patch(_OM + ".load_onnx_session")
    def test_infer_failure(self, mock_load_session):
        mock_session = MagicMock()
        mock_session.get_outputs.return_value = [MagicMock(name="output1")]
        mock_session.run.side_effect = Exception("Runtime error")
        mock_load_session.return_value = mock_session
        with self.assertRaises(MsitException) as context:
            OnnxModelActuator.infer("dummy_path", {})
        self.assertIn("Please check if the input shape", str(context.exception))

    @patch(_OM + ".load_onnx_model")
    @patch(_OM + ".load_onnx_session")
    def test_load_model(self, mock_load_session, mock_load_model):
        actuator = OnnxModelActuator("model_path", None, None)
        actuator.load_model()

        mock_load_model.assert_called_once_with("model_path")
        mock_load_session.assert_called_once_with("model_path", True)
        self.assertEqual(actuator.model_session, mock_load_session.return_value)

    def test_get_input_tensor_info(self):
        actuator = OnnxModelActuator("model_path", None, None)
        mock_input = MagicMock()
        mock_input.name = "input1"
        mock_input.type = "tensor(float32)"
        mock_input.shape = [1, 3, 224, 224]
        actuator.model_session = MagicMock()
        actuator.model_session.get_inputs.return_value = [mock_input]

        result = actuator.get_input_tensor_info()
        self.assertIn("input1", str(result))

    @patch(_OM + ".DirPool.get_uninfer_model_path")
    @patch(_OM + ".is_file", return_value=False)
    @patch(_OM + ".dependent.get")
    @patch(_OM + ".convert_bytes", return_value="10MB")
    @patch(_OM + ".save_onnx_model")
    @patch(_OM + ".logger")
    def test_export_uninfer_model(
        self,
        mock_logger,
        mock_save_model,
        mock_convert_bytes,
        mock_dependent_get,
        mock_is_file,
        mock_get_uninfer_model_path,
    ):
        # export_uninfer_model promotes every intermediate node output to a
        # graph output, then saves the modified model.
        mock_get_uninfer_model_path.return_value = "/fake/path/model_uninfer.onnx"
        mock_onnx = MagicMock()
        mock_value_info_proto = MagicMock()
        mock_onnx.ValueInfoProto.return_value = mock_value_info_proto
        mock_dependent_get.return_value = mock_onnx
        fake_graph = MagicMock()
        fake_graph.node = [MagicMock(output=["out1", "out2"])]
        fake_graph.output = []
        actuator = OnnxModelActuator("model.onnx", (1, 3, 224, 224), "input.npy")
        actuator.origin_model = MagicMock()
        actuator.origin_model.graph = fake_graph
        actuator.origin_model.ByteSize.return_value = 1024 * 1024 * 10
        result_path = actuator.export_uninfer_model()
        self.assertEqual(result_path, "/fake/path/model_uninfer.onnx")
        self.assertEqual(len(fake_graph.output), 2)
        mock_save_model.assert_called_once_with(actuator.origin_model, "/fake/path/model_uninfer.onnx")
        mock_logger.info.assert_any_call("The size of the modified ONNX model to be saved is 10MB.")
        mock_logger.info.assert_any_call(
            "The modified ONNX model has been successfully saved to /fake/path/model_uninfer.onnx."
        )
1 + initializer.dims = [1] + mock_model.graph.initializer = [initializer] + mock_get_valid_name.return_value = "valid_init1" + mock_load_npy.return_value = np.array([123]) + input_map = {"input1": np.array([1])} + output_map = {"output1": np.array([2])} + result = self.writer._augment_input_map(input_map, output_map, mock_model) + expected_keys = {"input1", "valid_init1", "output1"} + self.assertTrue(expected_keys.issubset(result.keys())) + + @patch.object(OnnxModelDataWriter, "_augment_input_map") + @patch.object(OnnxModelDataWriter, "_get_output_map") + def test_get_input_output_map(self, mock_get_output_map, mock_augment_input_map): + input_map = {"a": 1} + output_list = ["out"] + origin_model = MagicMock() + mock_get_output_map.return_value = {"out1": 100} + mock_augment_input_map.return_value = {"a": 1, "out1": 100} + result_input, result_output = self.writer.get_input_output_map(input_map, output_list, origin_model) + self.assertEqual(result_output, {"out1": 100}) + self.assertEqual(result_input, {"a": 1, "out1": 100}) + + @patch("msit.module.probe.base.dump_writer.save_json") + @patch("msit.module.probe.dump.onnx_model.DirPool.get_model_dir", return_value="/mock/model/dir") + @patch("msit.module.probe.dump.onnx_model.get_valid_name") + def test_summ_dump_data(self, mock_get_valid_name, mock_model_dir, mock_save_json): + mock_model_session = MagicMock() + mock_output_info = MagicMock() + mock_output_info.name = "output_node_1" + mock_model_session.get_outputs.return_value = [mock_output_info] + mock_origin_model = MagicMock() + node1 = MagicMock() + node1.name = "node1" + node1.input = ["input1"] + node1.output = ["output1"] + mock_origin_model.graph.node = [node1] + mock_get_valid_name.side_effect = lambda name: f"valid_{name}" + input_map = {"valid_input1": "input_data"} + output_map = {"valid_output1": "output_data"} + self.writer.through_inputs = MagicMock() + self.writer.through_outputs = MagicMock() + self.writer.summ_dump_data(input_map, output_map, 
mock_origin_model, mock_model_session) + self.writer.through_inputs.assert_called_once_with(["input1"], "node1", input_map) + self.writer.through_outputs.assert_called_once_with(["output1"], "node1", output_map) + self.assertIn("valid_node1", self.writer.cache_dump_json["data"]) diff --git a/accuracy_tools/test/UT/core_ut/probe_ut/dump_ut/test_tf_model.py b/accuracy_tools/test/UT/core_ut/probe_ut/dump_ut/test_tf_model.py new file mode 100644 index 00000000000..085a32d27e3 --- /dev/null +++ b/accuracy_tools/test/UT/core_ut/probe_ut/dump_ut/test_tf_model.py @@ -0,0 +1,269 @@ +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np + +from msprobe.core.probe.dump.tf_model import ( + FrozenGraphActuator, + FrozenGraphActuatorCPU, + FrozenGraphActuatorNPU, + FrozenGraphDataWriter, +) +from msprobe.utils.exceptions import MsitException + + +class TestFrozenGraphActuator(unittest.TestCase): + + @patch("msit.module.probe.dump.tf_model.dependent.get_tensorflow") + def setUp(self, mock_get_tf): + mock_tf = MagicMock() + mock_rewriter_config = MagicMock() + mock_get_tf.return_value = (mock_tf, mock_rewriter_config, None) + self.actuator = FrozenGraphActuator( + model_path="fake_model.pb", input_shape=(1, 224, 224, 3), input_path="input.npy" + ) + self.tf = mock_tf + self.actuator.tf = mock_tf + + @patch("msit.module.probe.dump.tf_model.dependent.get_tensorflow") + def test_import_tf_success(self, mock_get_tf): + mock_tf = MagicMock() + mock_rewriter = MagicMock() + mock_get_tf.return_value = (mock_tf, mock_rewriter, "extra") + tf, rewriter = FrozenGraphActuator._import_tf() + self.assertEqual(tf, mock_tf) + self.assertEqual(rewriter, mock_rewriter) + mock_tf.compat.v1.disable_eager_execution.assert_called_once() + + @patch("msit.module.probe.dump.tf_model.dependent.get_tensorflow") + def test_import_tf_none(self, mock_get_tf): + mock_get_tf.return_value = (None, None, None) + tf, rewriter = FrozenGraphActuator._import_tf() + self.assertIsNone(tf) + 
self.assertIsNone(rewriter) + + @patch("msit.module.probe.dump.tf_model.load_pb_frozen_graph_model") + def test_load_model(self, mock_load_pb): + mock_graph_def = MagicMock() + mock_load_pb.return_value = mock_graph_def + self.actuator.load_model() + mock_load_pb.assert_called_once_with("fake_model.pb") + self.assertEqual(self.actuator.graph_def, mock_graph_def) + + def test_get_tensor_name(self): + name = FrozenGraphActuator._get_tensor_name("input:0") + self.assertEqual(name, "input") + name = FrozenGraphActuator._get_tensor_name("no_colon") + self.assertEqual(name, "no_colon") + + def test_tf_shape_to_list(self): + mock_shape = MagicMock() + dim1 = MagicMock(size=1) + dim2 = MagicMock(size=-1) + dim3 = MagicMock(size=3) + mock_shape.dim = [dim1, dim2, dim3] + result = FrozenGraphActuator._tf_shape_to_list(mock_shape) + self.assertEqual(result, [1, None, 3]) + + def test_get_input_tensor_info(self): + mock_dtype = MagicMock() + mock_dtype.type = 1 + mock_tensor_shape = MagicMock() + mock_tensor_shape.dim = [MagicMock(size=1), MagicMock(size=224)] + node = MagicMock() + node.name = "input_node" + node.op = "Placeholder" + node.attr = {"dtype": mock_dtype, "shape": MagicMock(shape=mock_tensor_shape)} + self.actuator.graph_def = MagicMock() + self.actuator.graph_def.node = [node] + self.actuator.tf.dtypes.as_dtype.return_value = "float32" + self.actuator.process_tensor_shape = MagicMock( + return_value=[{"name": "input_node", "shape": [1, 224], "type": "float32"}] + ) + result = self.actuator.get_input_tensor_info() + self.assertEqual(len(result), 1) + self.assertIn("input_node", self.actuator.all_node_names) + + def test_close_session(self): + mock_sess = MagicMock() + self.actuator.sess = mock_sess + self.actuator.close() + mock_sess.close.assert_called_once() + self.assertIsNone(self.actuator.sess) + + def test_close_session_no_attr(self): + self.actuator.sess = None + try: + self.actuator.close() + except Exception as e: + self.fail(f"close() raised an exception 
unexpectedly: {e}") + + def test_get_tf_ops_success(self): + self.actuator.all_node_names = ["input"] + mock_graph = MagicMock() + tensor = MagicMock() + mock_graph.get_tensor_by_name.return_value = tensor + + self.actuator.sess = MagicMock() + self.actuator.sess.graph = mock_graph + + ops = self.actuator._get_tf_ops() + self.assertEqual(len(ops), 1) + self.assertEqual(ops[0], tensor) + + def test_get_tf_ops_failure(self): + self.actuator.all_node_names = ["bad_node"] + self.actuator.sess = MagicMock() + self.actuator.sess.graph.get_tensor_by_name.side_effect = Exception("fail") + + with self.assertRaises(MsitException): + self.actuator._get_tf_ops() + + def test_build_feed_success(self): + tensor = MagicMock() + input_map = {"input": np.ones((1, 224, 224, 3))} + + self.actuator.sess = MagicMock() + self.actuator.sess.graph.get_tensor_by_name.return_value = tensor + + feed_dict = self.actuator._build_feed(input_map) + self.assertEqual(feed_dict[tensor].shape, (1, 224, 224, 3)) + + def test_build_feed_failure(self): + input_map = {"bad_input": np.zeros((1,))} + + self.actuator.sess = MagicMock() + self.actuator.sess.graph.get_tensor_by_name.side_effect = Exception("fail") + + with self.assertRaises(MsitException): + self.actuator._build_feed(input_map) + + def test_infer_success(self): + mock_sess = MagicMock() + mock_sess.run.return_value = ["result"] + self.actuator._open_session = MagicMock(return_value=mock_sess) + self.actuator._renew_all_node_names = MagicMock() + self.actuator._get_tf_ops = MagicMock(return_value=["fake_op"]) + self.actuator._build_feed = MagicMock(return_value={"input": "fake_data"}) + self.actuator.close = MagicMock() + result = self.actuator.infer({"input": "data"}) + self.assertEqual(result, ["result"]) + mock_sess.run.assert_called_once() + + def test_infer_failure(self): + mock_sess = MagicMock() + mock_sess.run.side_effect = RuntimeError("bad inference") + self.actuator._open_session = MagicMock(return_value=mock_sess) + 
self.actuator._renew_all_node_names = MagicMock() + self.actuator._get_tf_ops = MagicMock(return_value=["fake_op"]) + self.actuator._build_feed = MagicMock(return_value={"input": "fake_data"}) + self.actuator.close = MagicMock() + with self.assertRaises(MsitException) as context: + self.actuator.infer({"input": "data"}) + self.assertIn("input shape or data", str(context.exception)) + + +class TestFrozenGraphActuatorCPU(unittest.TestCase): + @patch("msit.module.probe.dump.tf_model.FrozenGraphActuator._import_tf") + def test_open_session(self, mock_import_tf): + mock_tf = MagicMock() + mock_tf.compat.v1.Session.return_value = "mock_session" + mock_import_tf.return_value = (mock_tf, MagicMock()) + actuator = FrozenGraphActuatorCPU("model", {}, "input") + session = actuator._open_session() + self.assertEqual(session, "mock_session") + + +class TestFrozenGraphActuatorNPU(unittest.TestCase): + @patch("msit.module.probe.dump.tf_model.DirPool.get_rank_dir", return_value="/mock/rank_dir") + @patch("msit.module.probe.dump.tf_model.FrozenGraphActuator._import_tf") + @patch("msit.module.probe.dump.tf_model.dependent.get") + def test_open_session_npu(self, mock_dependent_get, mock_import_tf, mock_rank_dir): + mock_tf = MagicMock() + mock_import_tf.return_value = (mock_tf, MagicMock()) + mock_device = MagicMock() + mock_device.compat.enable_v1 = MagicMock() + mock_dependent_get.return_value = mock_device + mock_tf.compat.v1.ConfigProto.return_value = MagicMock() + mock_tf.compat.v1.Session.return_value = "npu_session" + + actuator = FrozenGraphActuatorNPU("model", {}, "input") + session = actuator._open_session() + self.assertEqual(session, "npu_session") + + @patch("msit.module.probe.dump.tf_model.cann.model2json") + @patch("msit.module.probe.dump.tf_model.get_name_and_ext") + @patch("msit.module.probe.dump.tf_model.glob") + @patch("msit.module.probe.dump.tf_model.DirPool.get_model_dir") + def test_convert_txt2json(self, mock_get_model_dir, mock_glob, mock_get_name_ext, 
mock_model2json): + mock_get_model_dir.return_value = "/mock/dir" + mock_glob.return_value = ["/mock/dir/mock_Build.txt"] + mock_get_name_ext.return_value = ("mock", ".txt") + actuator = FrozenGraphActuatorNPU("model", {}, "input") + actuator.convert_txt2json() + mock_model2json.assert_called_once() + + +class TestFrozenGraphDataWriter(unittest.TestCase): + def setUp(self): + self.task = MagicMock() + self.dump_mode = ["all"] + self.writer = FrozenGraphDataWriter(self.task, self.dump_mode) + self.writer.cache_dump_json = {"data": {}} + self.writer.through_inputs = MagicMock() + self.writer.through_outputs = MagicMock() + + @patch("msit.module.probe.dump.tf_model.get_valid_name") + def test_get_output_map(self, mock_get_valid_name): + tf_ops = [MagicMock(name="Tensor1"), MagicMock(name="Tensor2")] + infer_output = ["out1", "out2"] + mock_get_valid_name.side_effect = lambda x: f"valid_{x}" + + output_map = self.writer._get_output_map(tf_ops, infer_output) + self.assertEqual(output_map, {f"valid_{tf_ops[0].name}": "out1", f"valid_{tf_ops[1].name}": "out2"}) + + @patch("msit.module.probe.dump.tf_model.get_valid_name") + @patch("msit.module.probe.dump.tf_model.logger") + def test_get_input_map(self, mock_logger, mock_get_valid_name): + # Mock input tensors + tensor_input = MagicMock() + tensor_input.name = "input_tensor:0" + + tf_op = MagicMock() + tf_op.op.name = "nodeA" + tf_op.op.inputs = [tensor_input] + + output_map = {"input_tensor": "data123"} + mock_get_valid_name.side_effect = lambda x: x.split(":")[0] + + input_map = self.writer._get_input_map([tf_op], output_map) + self.assertEqual(input_map, {"input_tensor": "data123"}) + + @patch("msit.module.probe.dump.tf_model.FrozenGraphDataWriter._get_output_map") + @patch("msit.module.probe.dump.tf_model.FrozenGraphDataWriter._get_input_map") + def test_get_input_output_map(self, mock_get_input_map, mock_get_output_map): + tf_ops = ["op1"] + infer_output = ["output1"] + mock_get_output_map.return_value = {"x": "y"} + 
mock_get_input_map.return_value = {"a": "b"} + input_map, output_map = self.writer.get_input_output_map(tf_ops, infer_output) + self.assertEqual(input_map, {"a": "b"}) + self.assertEqual(output_map, {"x": "y"}) + + @patch("msit.module.probe.base.dump_writer.save_json") + @patch("msit.module.probe.dump.onnx_model.DirPool.get_model_dir", return_value="/mock/model/dir") + @patch("msit.module.probe.dump.tf_model.get_net_output_nodes_from_graph_def") + @patch("msit.module.probe.dump.tf_model.get_valid_name") + def test_summ_dump_data(self, mock_get_valid_name, mock_get_net_output_nodes, mock_model_dir, mock_save_json): + node_mock = MagicMock() + node_mock.name = "nodeA" + node_mock.op.inputs = ["input1"] + node_mock.op.outputs = ["output1"] + mock_get_valid_name.side_effect = lambda x: f"valid_{x}" + mock_get_net_output_nodes.return_value = ["output_node"] + self.writer.dump_mode = ["input", "output"] + self.writer.cache_dump_json = {"data": {}} + self.writer.summ_dump_data([node_mock], {"input1": "data"}, {"output1": "out"}, MagicMock()) + self.writer.through_inputs.assert_called_once() + self.writer.through_outputs.assert_called_once() + self.assertIn("valid_nodeA", self.writer.cache_dump_json["data"]) diff --git a/accuracy_tools/test/UT/csrc_ut/CMakeLists.txt b/accuracy_tools/test/UT/csrc_ut/CMakeLists.txt new file mode 100644 index 00000000000..c395e02fa60 --- /dev/null +++ b/accuracy_tools/test/UT/csrc_ut/CMakeLists.txt @@ -0,0 +1,24 @@ +project(msit VERSION 1.0.0 LANGUAGES CXX C) +cmake_minimum_required(VERSION 3.14) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(cpython MODULE REQUIRED) +find_package(gtest MODULE REQUIRED) +find_package(mockcpp MODULE REQUIRED) +find_package(nlohmannjson MODULE REQUIRED) + +add_executable(msit_test) +target_link_libraries(msit_test PRIVATE ${gtest_LIBRARIES}) +target_link_libraries(msit_test PRIVATE ${mockcpp_LIBRARIES}) +target_link_libraries(msit_test PRIVATE msprobe_c) + 
+target_include_directories(msit_test PRIVATE $ENV{PROJECT_ROOT_PATH}/msit/csrc) +target_include_directories(msit_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_include_directories(msit_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mock) + +target_compile_definitions(msit_test PRIVATE __RESOURCES_PATH__="${CMAKE_CURRENT_SOURCE_DIR}/../resources") + +file(GLOB_RECURSE SOURCES "*.cpp") +target_sources(msit_test PUBLIC ${SOURCES}) diff --git a/accuracy_tools/test/UT/csrc_ut/utils_ut/test_log.cpp b/accuracy_tools/test/UT/csrc_ut/utils_ut/test_log.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/accuracy_tools/test/UT/pytest.ini b/accuracy_tools/test/UT/pytest.ini new file mode 100644 index 00000000000..c24fe5bb9e6 --- /dev/null +++ b/accuracy_tools/test/UT/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning diff --git a/accuracy_tools/test/UT/run_ut.py b/accuracy_tools/test/UT/run_ut.py new file mode 100644 index 00000000000..0a60699b6a3 --- /dev/null +++ b/accuracy_tools/test/UT/run_ut.py @@ -0,0 +1,35 @@ +import os + +from msprobe.utils.log import logger +from msprobe.utils.toolkits import run_subprocess + + +class RunUT: + def __init__(self): + self.cur_dir = os.path.realpath(os.path.dirname(__file__)) + self.cov_dir = os.path.join(os.path.dirname(self.cur_dir), "../msprobe") + self.report_dir = os.path.join(self.cur_dir, "report") + self.cov_config_path = os.path.join(self.cur_dir, ".coveragerc") + self.final_xml_path = os.path.join(self.report_dir, "final.xml") + self.html_cov_report = os.path.join(self.report_dir, "htmlcov") + self.xml_cov_report = os.path.join(self.report_dir, "coverage.xml") + self.cmd = [ + "python3", + "-m", + "pytest", + self.cur_dir, + f"--junitxml={self.final_xml_path}", + f"--cov-config={self.cov_config_path}", + f"--cov={self.cov_dir}", + "--cov-branch", + f"--cov-report=html:{self.html_cov_report}", + f"--cov-report=xml:{self.xml_cov_report}", + ] + + def execute(self): + 
run_subprocess(self.cmd) + logger.info("Unit tests executed successfully.") + + +if __name__ == "__main__": + RunUT().execute() diff --git a/accuracy_tools/test/UT/run_ut.sh b/accuracy_tools/test/UT/run_ut.sh new file mode 100644 index 00000000000..0fb1bb8576d --- /dev/null +++ b/accuracy_tools/test/UT/run_ut.sh @@ -0,0 +1,56 @@ +#!/bin/bash +CUR_DIR=$(dirname $(readlink -f $0)) +TOP_DIR=${CUR_DIR}/../.. +MSIT_TEST=${TOP_DIR}/output/release/test/UT/csrc_ut/msprobe_test +TEST_DIR=${TOP_DIR}/test/UT +SRC_DIR=${TOP_DIR}/ + +run_ut_cpp() { + echo "[INFO] Start compiling msprobe_test..." + ARCH_TYPE=$(uname -m) + PYTHON_VERSION=$(python3 -c 'import platform; print(".".join(platform.python_version_tuple()[:2]))') + BUILD_SCRIPT=${TOP_DIR}/build.sh + if [[ ! -f "${BUILD_SCRIPT}" ]]; then + echo "[ERROR] build.sh not found at ${BUILD_SCRIPT}" + exit 1 + fi + bash "${BUILD_SCRIPT}" \ + --release \ + -t \ + -a "${ARCH_TYPE}" \ + -v "${PYTHON_VERSION}" \ + -j 16 \ + + if [[ -x "${MSIT_TEST}" ]]; then + echo "[INFO] Running C++ unit test binary: ${MSIT_TEST}" + ${MSIT_TEST} + else + echo "[ERROR] msprobe_test binary not found or not executable at ${MSIT_TEST}" + exit 1 + fi +} + +install_pytest() { + if ! pip show pytest &> /dev/null; then + echo "pytest not found, trying to install..." + pip install pytest + fi + + if ! pip show pytest-cov &> /dev/null; then + echo "pytest-cov not found, trying to install..." 
+ pip install pytest-cov + fi +} + +run_ut_py() { + install_pytest + export PYTHONPATH=${SRC_DIR}:${PYTHONPATH} + python3 run_ut.py +} + +main() { + run_ut_cpp + cd ${TEST_DIR} && run_ut_py +} + +main $@ diff --git a/accuracy_tools/test/UT/test___main__.py b/accuracy_tools/test/UT/test___main__.py new file mode 100644 index 00000000000..00d4b1b278a --- /dev/null +++ b/accuracy_tools/test/UT/test___main__.py @@ -0,0 +1,38 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + + +class TestMainFunction(TestCase): + @patch("msprobe.__main__.MainCommand") + def test_main_execution_flow(self, mock_main_command): + mock_instance = MagicMock() + mock_main_command.return_value = mock_instance + mock_args = MagicMock() + mock_instance.parse.return_value = mock_args + from msprobe.__main__ import main + + main() + mock_main_command.assert_called_once() + mock_instance.parse.assert_called_once() + mock_instance.execute.assert_called_once_with(mock_args) + + @patch("msprobe.__main__.MainCommand") + def test_direct_execution(self, mock_main_command): + with patch("sys.argv", ["script_name"]): + from msprobe.__main__ import main + + main() + mock_main_command.return_value.execute.assert_called_once() + + def test_main_called_in_if_main(self): + mock_instance = MagicMock() + mock_instance.parse.return_value = "mock_args" + MockMainCommand = MagicMock(return_value=mock_instance) + with patch("msprobe.__main__.MainCommand", MockMainCommand): + from msprobe.__main__ import main + + main() + MockMainCommand.assert_called_once() + mock_instance.register.assert_called_once() + mock_instance.parse.assert_called_once() + mock_instance.execute.assert_called_once_with("mock_args") diff --git a/accuracy_tools/test/UT/utils_ut/test_dependencies.py b/accuracy_tools/test/UT/utils_ut/test_dependencies.py new file mode 100644 index 00000000000..fbbf5d61626 --- /dev/null +++ b/accuracy_tools/test/UT/utils_ut/test_dependencies.py @@ -0,0 +1,113 @@ +import os +import sys 
+import unittest +from unittest.mock import MagicMock, patch + +from msprobe.utils.dependencies import DependencyManager, temporary_tf_log_level +from msprobe.utils.exceptions import MsitException + + +class TestDependencyManager(unittest.TestCase): + def setUp(self): + DependencyManager._instance = None + self.manager = DependencyManager() + + def tearDown(self): + if "TF_CPP_MIN_LOG_LEVEL" in os.environ: + del os.environ["TF_CPP_MIN_LOG_LEVEL"] + + @patch.dict(os.environ, {"TF_CPP_MIN_LOG_LEVEL": "0"}) + def test_temporary_tf_log_level(self): + @temporary_tf_log_level + def mock_function(): + return os.environ["TF_CPP_MIN_LOG_LEVEL"] + + self.assertEqual(mock_function(), "2") + self.assertEqual(os.environ["TF_CPP_MIN_LOG_LEVEL"], "0") + + @patch("msit.utils.dependencies.import_module") + def test_get_tensorflow(self, mock_import_module): + mock_tf = MagicMock() + mock_tf.__version__ = "2.6.5" + mock_rewriter_config = MagicMock() + mock_convert_variables = MagicMock() + + def side_effect(name): + if name == "tensorflow": + return mock_tf + return MagicMock() + + mock_import_module.side_effect = side_effect + sys.modules["tensorflow"] = mock_tf + sys.modules["tensorflow.core.protobuf.rewriter_config_pb2"] = MagicMock(RewriterConfig=mock_rewriter_config) + sys.modules["tensorflow.python.framework.graph_util"] = MagicMock( + convert_variables_to_constants=mock_convert_variables + ) + dm = DependencyManager() + tf, re_writer_config, sm2pb = dm.get_tensorflow() + + self.assertIsNotNone(tf, "TensorFlow is not None") + self.assertEqual(tf, mock_tf) + self.assertEqual(re_writer_config, mock_rewriter_config) + self.assertEqual(sm2pb, mock_convert_variables) + + @patch("msit.utils.dependencies.import_module") + def test_import_package_non_tensorflow(self, mock_import): + mock_module = MagicMock() + mock_import.return_value = mock_module + result = self.manager._import_package("abc") + mock_import.assert_called_once_with("abc") + self.assertEqual(result, mock_module) + 
self.assertIn("abc", self.manager._dependencies) + + @patch.object(DependencyManager, "_import_tensorflow") + def test_import_package_tensorflow(self, mock_import_tf): + mock_tf = MagicMock() + + def simulate_import(): + self.manager._dependencies["tensorflow"] = mock_tf + return mock_tf + + mock_import_tf.side_effect = simulate_import + result = self.manager._import_package("tensorflow") + mock_import_tf.assert_called_once() + self.assertEqual(result, mock_tf) + self.assertIn("tensorflow", self.manager._dependencies) + + @patch("msit.utils.dependencies.import_module") + def test_import_tensorflow_wrong_version(self, mock_import): + mock_tf = MagicMock() + mock_tf.__version__ = "2.7.0" + mock_import.return_value = mock_tf + with self.assertRaises(MsitException) as context: + self.manager._import_tensorflow() + self.assertIn("Incompatible versions", str(context.exception)) + + @patch("msit.utils.dependencies.import_module") + def test_import_tensorflow_environment_reset(self, mock_import): + original_level = "0" + os.environ["TF_CPP_MIN_LOG_LEVEL"] = original_level + mock_tf = MagicMock() + mock_tf.__version__ = "2.6.5" + mock_import.return_value = mock_tf + self.manager._import_tensorflow() + self.assertEqual(os.environ["TF_CPP_MIN_LOG_LEVEL"], original_level) + + @patch("msit.utils.dependencies.import_module") + def test_import_package_missing_dependency(self, mock_import): + mock_import.side_effect = ImportError("No module named 'missing_package'") + result = self.manager._import_package("missing_package") + self.assertIsNone(result) + self.assertNotIn("missing_package", self.manager._dependencies) + + @patch("msit.utils.dependencies.logger.warning") + @patch("msit.utils.dependencies.import_module") + def test_safely_import_decorator(self, mock_import, mock_warning): + mock_import.side_effect = ImportError("Test error") + result = self.manager._import_package("test_package") + self.assertIsNone(result) + mock_warning.assert_called_once_with("test_package is not 
installed. Please install it if needed.") + mock_warning.reset_mock() + result = self.manager._import_package("test_package") + self.assertIsNone(result) + mock_warning.assert_not_called() diff --git a/accuracy_tools/test/UT/utils_ut/test_env.py b/accuracy_tools/test/UT/utils_ut/test_env.py new file mode 100644 index 00000000000..08a790c55b0 --- /dev/null +++ b/accuracy_tools/test/UT/utils_ut/test_env.py @@ -0,0 +1,98 @@ +import os +import unittest +from unittest import mock + +from msprobe.utils.env import EnvVarManager +from msprobe.utils.exceptions import MsitException + + +class TestEnvVarManager(unittest.TestCase): + def setUp(self): + self.manager = EnvVarManager() + self.manager.set_prefix("") + self.env_patcher = mock.patch.dict(os.environ, clear=True) + self.env_patcher.start() + + def tearDown(self): + self.env_patcher.stop() + + def test_singleton_instance(self): + manager1 = EnvVarManager() + manager2 = EnvVarManager() + self.assertIs(manager1, manager2) + + def test_set_prefix(self): + self.manager.set_prefix("TEST_") + self.assertEqual(self.manager.prefix, "TEST_") + + def test_get_existing_var_no_prefix(self): + os.environ["KEY"] = "value" + result = self.manager.get("KEY") + self.assertEqual(result, "value") + + def test_get_existing_var_with_prefix(self): + self.manager.set_prefix("TEST_") + os.environ["TEST_KEY"] = "value" + result = self.manager.get("TEST_KEY") + self.assertEqual(result, "value") + + def test_get_missing_var_required(self): + with self.assertRaises(MsitException) as cm: + self.manager.get("MISSING_KEY", required=True) + self.assertIn("MISSING_KEY", str(cm.exception)) + + def test_get_missing_var_optional_with_default(self): + result = self.manager.get("MISSING_KEY", default="default_val", required=False) + self.assertEqual(result, "default_val") + + def test_get_cast_type_success(self): + os.environ["INT_VAL"] = "123" + result = self.manager.get("INT_VAL", cast_type=int) + self.assertEqual(result, 123) + + def 
test_get_cast_type_failure(self): + os.environ["INVALID_INT"] = "abc" + with self.assertRaises(MsitException) as cm: + self.manager.get("INVALID_INT", cast_type=int) + self.assertIn("Failed to cast", str(cm.exception)) + + def test_set_var_with_prefix(self): + self.manager.set_prefix("TEST_") + self.manager.set("NEW_KEY", "value") + self.assertEqual(os.environ["NEW_KEY"], "value") + + def test_delete_existing_var(self): + os.environ["TEST_KEY"] = "value" + self.manager.set_prefix("TEST_") + self.manager.delete("KEY") + self.assertIn("TEST_KEY", os.environ) + + def test_delete_non_existing_var(self): + self.manager.set_prefix("TEST_") + try: + self.manager.delete("NON_EXISTENT") + except Exception: + self.fail("Deleting non-existent variable raised unexpected exception") + + def test_list_all_with_prefix(self): + os.environ.update({"TEST_A": "1", "TEST_B": "2", "OTHER": "3"}) + self.manager.set_prefix("TEST_") + result = self.manager.list_all() + expected = {"TEST_A": "1", "TEST_B": "2"} + self.assertDictEqual(result, expected) + + def test_list_all_without_prefix(self): + os.environ["KEY"] = "value" + result = self.manager.list_all() + self.assertIn("KEY", result) + + @mock.patch("msit.utils.log.logger.debug") + def test_logging_on_get(self, mock_debug): + os.environ["LOGGED_KEY"] = "log_value" + self.manager.get("LOGGED_KEY") + mock_debug.assert_called_with("Accessed environment variable LOGGED_KEY, Value: log_value.") + + @mock.patch("msit.utils.log.logger.debug") + def test_logging_on_set(self, mock_debug): + self.manager.set("LOGGED_SET", "value") + mock_debug.assert_called_with("Set environment variable LOGGED_SET to value.") diff --git a/accuracy_tools/test/UT/utils_ut/test_hijack.py b/accuracy_tools/test/UT/utils_ut/test_hijack.py new file mode 100644 index 00000000000..7cafcee694a --- /dev/null +++ b/accuracy_tools/test/UT/utils_ut/test_hijack.py @@ -0,0 +1,294 @@ +import sys +import unittest +from unittest.mock import ANY, MagicMock, Mock, call, patch + 
+from msprobe.utils.exceptions import MsitException +from msprobe.utils.hijack import ( + POST_HOOK, + PRE_HOOK, + REPLACE, + HiJackerManager, + HiJackerPathFinder, + HijackerUnit, + HiJackerWrapperFunction, + HiJackerWrapperModule, + HiJackerWrapperObj, + HijackHandler, + hijacker, + release, +) + + +class TestHijackerUnit(unittest.TestCase): + def test_valid_parameters(self): + stub = MagicMock() + unit = HijackerUnit(stub, "module", "cls", "func", REPLACE, 100) + self.assertEqual(unit.stub, stub) + self.assertEqual(unit.module, "module") + self.assertEqual(unit.cls, "cls") + self.assertEqual(unit.function, "func") + self.assertEqual(unit.action, REPLACE) + self.assertEqual(unit.priority, 100) + + def test_invalid_stub(self): + with self.assertRaises(MsitException): + HijackerUnit("not_callable", "module", "", "", REPLACE, 100) + + def test_missing_module(self): + with self.assertRaises(MsitException): + HijackerUnit(MagicMock(), "", "", "", REPLACE, 100) + + def test_invalid_action(self): + with self.assertRaises(MsitException): + HijackerUnit(MagicMock(), "module", "", "", 999, 100) + + def test_replace_module_error(self): + with self.assertRaises(MsitException): + HijackerUnit(MagicMock(), "module", "", "", REPLACE, 100) + + +class TestHijackerUnit(unittest.TestCase): + + def test_valid_parameters(self): + mock_stub = MagicMock() + unit = HijackerUnit(mock_stub, "module_name", "ClassName", "function_name", REPLACE, 1) + self.assertEqual(unit.module, "module_name") + + def test_invalid_stub(self): + with self.assertRaises(MsitException) as context: + HijackerUnit("not_callable", "module_name", "ClassName", "function_name", REPLACE, 1) + self.assertIn('"stub" should be callable.', str(context.exception)) + + def test_missing_module(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, None, "ClassName", "function_name", REPLACE, 1) + self.assertIn('"module" is required.', str(context.exception)) + + def 
test_invalid_module_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, 123, "ClassName", "function_name", REPLACE, 1) + self.assertIn('"module" should be a str.', str(context.exception)) + + def test_invalid_cls_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, "module_name", 123, "function_name", REPLACE, 1) + self.assertIn('"cls" should be a str.', str(context.exception)) + + def test_missing_function_when_cls_present(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", None, REPLACE, 1) + self.assertIn('"function" should be used when "cls" used.', str(context.exception)) + + def test_invalid_function_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", 123, REPLACE, 1) + self.assertIn('"function" should be a str.', str(context.exception)) + + def test_invalid_action(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", "function_name", "INVALID_ACTION", 1) + self.assertIn('"action" should be REPLACE, PRE_HOOK, or POST_HOOK.', str(context.exception)) + + def test_module_replacement_not_supported(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, "module_name", None, None, REPLACE, 1) + self.assertIn("replacement of a module is not supported", str(context.exception)) + + def test_invalid_priority_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsitException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", "function_name", REPLACE, "high") + self.assertIn('"priority" should be an int.', str(context.exception)) + + +class TestRelease(unittest.TestCase): + 
@patch("msprobe.utils.hijack.HiJackerManager")
+    def test_release_valid_handler(self, mock_manager):
+        handler = MagicMock(spec=HijackHandler)
+        handler.released = False
+        handler.unit = "test_unit"
+        release(handler)
+        self.assertTrue(handler.released)
+        mock_manager.remove_unit.assert_called_once_with("test_unit")
+
+    def test_release_with_invalid_handler_type(self):
+        invalid_handler = "not_a_handler"
+        with self.assertRaises(MsitException) as context:
+            release(invalid_handler)
+        self.assertIn("Handler must be an instance of HijackHandler.", str(context.exception))
+
+
+class TestHijackerManager(unittest.TestCase):
+    def setUp(self):
+        HiJackerManager._initialized = False
+        HiJackerManager._hijacker_units = {}
+        HiJackerManager._hijacker_wrappers = {}
+
+    @patch("sys.meta_path", [])
+    def test_initialize(self):
+        HiJackerManager.initialize()
+        self.assertTrue(HiJackerManager._initialized)
+        self.assertIsInstance(sys.meta_path[0], HiJackerPathFinder)
+
+    def test_add_and_remove_unit(self):
+        stub = MagicMock()
+        unit = HijackerUnit(stub, "test_module", "", "test_func", REPLACE, 100)
+        handler = HiJackerManager.add_unit(unit)
+        self.assertIn(handler, HiJackerManager._hijacker_units)
+        wrapper = HiJackerManager._hijacker_wrappers.get("test_module--test_func")
+        self.assertIsInstance(wrapper, HiJackerWrapperFunction)
+        self.assertEqual(len(wrapper.replacement), 1)
+        HiJackerManager.remove_unit(handler)
+        self.assertNotIn(handler, HiJackerManager._hijacker_units)
+        self.assertNotIn("test_module--test_func", HiJackerManager._hijacker_wrappers)
+
+
+class ConcreteHiJackerWrapper(HiJackerWrapperObj):
+    def activate(self):
+        pass
+
+    def deactivate(self):
+        pass
+
+
+class TestRemoveUnit(unittest.TestCase):
+    def setUp(self):
+        self.hijacker = ConcreteHiJackerWrapper("mod-class-func")
+        self.unit_replace = Mock(action=REPLACE, priority=1)
+        self.unit_pre = Mock(action=PRE_HOOK, priority=2)
+        self.unit_post = Mock(action=POST_HOOK, priority=3)
+
+    def test_remove_replace_unit(self):
+        self.hijacker.replacement.append(self.unit_replace)
+        self.hijacker.remove_unit(self.unit_replace)
+        self.assertNotIn(self.unit_replace, self.hijacker.replacement)
+        self.assertEqual(len(self.hijacker.replacement), 0)
+
+    def test_remove_pre_hook_unit(self):
+        self.hijacker.pre_hooks.append(self.unit_pre)
+        self.hijacker.remove_unit(self.unit_pre)
+        self.assertNotIn(self.unit_pre, self.hijacker.pre_hooks)
+        self.assertEqual(len(self.hijacker.pre_hooks), 0)
+
+    def test_remove_post_hook_unit(self):
+        self.hijacker.post_hooks.append(self.unit_post)
+        self.hijacker.remove_unit(self.unit_post)
+        self.assertNotIn(self.unit_post, self.hijacker.post_hooks)
+        self.assertEqual(len(self.hijacker.post_hooks), 0)
+
+    def test_remove_non_existent_unit_raises_error(self):
+        with self.assertRaises(ValueError):
+            self.hijacker.remove_unit(self.unit_replace)
+        self.hijacker.pre_hooks.append(Mock(action=PRE_HOOK))
+        with self.assertRaises(ValueError):
+            self.hijacker.remove_unit(self.unit_pre)
+        self.hijacker.post_hooks.append(Mock(action=POST_HOOK))
+        with self.assertRaises(ValueError):
+            self.hijacker.remove_unit(self.unit_post)
+
+    def test_remove_from_multiple_units(self):
+        unit1 = Mock(action=REPLACE, priority=1)
+        unit2 = Mock(action=REPLACE, priority=2)
+        self.hijacker.replacement = [unit1, unit2]
+        self.hijacker.remove_unit(unit1)
+        self.assertEqual(self.hijacker.replacement, [unit2])
+
+    def test_remove_does_not_affect_other_lists(self):
+        self.hijacker.replacement.append(self.unit_replace)
+        self.hijacker.pre_hooks.append(self.unit_pre)
+        self.hijacker.post_hooks.append(self.unit_post)
+
+        self.hijacker.remove_unit(self.unit_replace)
+        self.assertIn(self.unit_pre, self.hijacker.pre_hooks)
+        self.assertIn(self.unit_post, self.hijacker.post_hooks)
+
+
+class TestHijackerWrapperModule(unittest.TestCase):
+    def setUp(self):
+        self.wrapper = HiJackerWrapperModule("test_module--")
+
+    @patch("msprobe.utils.hijack.HiJackerPathFinder.add_mod")
+    def test_activate(self, mock_add_mod):
+        self.wrapper.activate()
+        mock_add_mod.assert_called_once_with("test_module")
+
+    def test_exec_pre_post_hooks(self):
+        pre_unit = MagicMock(action=PRE_HOOK, stub=MagicMock())
+        post_unit = MagicMock(action=POST_HOOK, stub=MagicMock())
+        self.wrapper.add_unit(pre_unit)
+        self.wrapper.add_unit(post_unit)
+        mock_module = MagicMock()
+        self.wrapper.exec_pre_hook()
+        pre_unit.stub.assert_called_once()
+        self.wrapper.exec_post_hook(mock_module)
+        post_unit.stub.assert_called_once_with(mock_module)
+
+
+class TestHiJackerWrapperFunction(unittest.TestCase):
+    def setUp(self):
+        self.target_name = "test_mod-TestClass-test_method"
+        self.wrapper = HiJackerWrapperFunction(self.target_name)
+
+        self.mock_module = MagicMock()
+        self.mock_class = MagicMock()
+        self.original_method = MagicMock()
+        self.mock_module.TestClass = self.mock_class
+        self.mock_class.test_method = self.original_method
+
+    def test_initialization(self):
+        self.assertEqual(self.wrapper.mod_name, "test_mod")
+        self.assertEqual(self.wrapper.class_name, "TestClass")
+        self.assertEqual(self.wrapper.func_name, "test_method")
+
+    @patch("msprobe.utils.hijack.hijacker")
+    @patch.dict("msprobe.utils.hijack.sys.modules", {"test_mod": None})
+    def test_activate_module_not_loaded(self, mock_hijacker):
+        self.wrapper.activate()
+        mock_hijacker.assert_called_once_with(stub=ANY, module="test_mod", action=POST_HOOK, priority=0)
+
+    def test_wrapper_execution_flow(self):
+        pre_hook = MagicMock()
+        pre_hook.stub = MagicMock(return_value=(("modified_args",), {"new_kw": 1}))
+        replacement = MagicMock()
+        replacement.stub = MagicMock(return_value="replaced_result")
+        post_hook = MagicMock()
+        post_hook.stub = MagicMock(return_value="final_result")
+
+        self.wrapper.pre_hooks = [pre_hook]
+        self.wrapper.replacement = [replacement]
+        self.wrapper.post_hooks = [post_hook]
+        self.wrapper.ori_obj = MagicMock()
+
+        result = self.wrapper._get_wrapper()("arg1", kw1=2)
+
+        pre_hook.stub.assert_called_once_with("arg1", kw1=2)
+        replacement.stub.assert_called_once_with("modified_args", new_kw=1)
+        post_hook.stub.assert_called_once_with("replaced_result", "modified_args", new_kw=1)
+        self.assertEqual(result, "final_result")
+
+    def test_pre_hook_type_check(self):
+        invalid_hook = MagicMock()
+        invalid_hook.stub = MagicMock(return_value="invalid_type")
+        self.wrapper.pre_hooks = [invalid_hook]
+        self.wrapper.ori_obj = MagicMock()
+
+        with self.assertRaises(MsitException) as cm:
+            self.wrapper._get_wrapper()("arg1")
+        self.assertIn("Pre-hook must return a tuple", str(cm.exception))
+
+    @patch("msprobe.utils.hijack.release")
+    @patch.dict("msprobe.utils.hijack.sys.modules", {"test_mod": sys})
+    def test_deactivate_with_missing_class(self, mock_release):
+        self.wrapper.class_name = "NonExistentClass"
+        self.wrapper.ori_obj = self.original_method
+        self.wrapper.mod_hijacker = MagicMock()
+        self.wrapper.deactivate()
+        self.assertIsNone(self.wrapper.ori_obj)
+        mock_release.assert_called_once()
diff --git a/accuracy_tools/test/UT/utils_ut/test_io.py b/accuracy_tools/test/UT/utils_ut/test_io.py
new file mode 100644
index 00000000000..11d67c718ca
--- /dev/null
+++ b/accuracy_tools/test/UT/utils_ut/test_io.py
@@ -0,0 +1,569 @@
+import os
+import pickle
+import tempfile
+import unittest
+from unittest.mock import MagicMock, mock_open, patch
+
+import numpy as np
+import pandas as pd
+
+from msprobe.utils.constants import MsgConst
+from msprobe.utils.dependencies import dependent
+from msprobe.utils.exceptions import MsitException
+from msprobe.utils.io import (
+    SafelyOpen,
+    _load_dir,
+    _load_file,
+    _save_dir,
+    _save_file,
+    load_bin_data,
+    load_caffe_model,
+    load_csv_by_builtin,
+    load_csv_by_pandas,
+    load_json,
+    load_npy,
+    load_npy_from_buffer,
+    load_om_model,
+    load_onnx_model,
+    load_onnx_session,
+    load_pb_frozen_graph_model,
+    load_saved_model,
+    load_torch_obj,
+    load_yaml,
+    save_bin_from_ndarray,
+    save_csv_by_pandas,
+    save_json,
save_npy, + save_onnx_model, + save_pb_frozen_graph_model, + save_yaml, + savedmodel2pb, +) +from msprobe.utils.path import AUTHORITY_DIR, AUTHORITY_FILE, MsitPath, PathConst, change_permission + + +class TestSafelyOpen(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.temp_path = self.temp_dir.name + + def tearDown(self): + self.temp_dir.cleanup() + + def test_read_existing_file(self): + file_path = os.path.join(self.temp_path, "test.txt") + with open(file_path, "w") as f: + f.write("content") + with SafelyOpen(file_path, "r", path_exist=True) as f: + self.assertEqual(f.read(), "content") + + def test_write_new_file(self): + file_path = os.path.join(self.temp_path, "new.txt") + with SafelyOpen(file_path, "w", path_exist=False) as f: + f.write("content") + with open(file_path, "r") as f: + self.assertEqual(f.read(), "content") + + def test_suffix_mismatch(self): + file_path = os.path.join(self.temp_path, "file.csv") + with open(file_path, "w") as f: + f.write("data") + with self.assertRaises(MsitException): + SafelyOpen(file_path, "r", suffix=".txt", path_exist=True) + + def test_file_size_exceeded(self): + file_path = os.path.join(self.temp_path, "large.txt") + with open(file_path, "w") as f: + f.write("a" * 1024) + with self.assertRaises(MsitException): + SafelyOpen(file_path, "r", file_size_limitation=512, path_exist=True) + + +class TestMsitPath(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.temp_path = self.temp_dir.name + + def tearDown(self): + self.temp_dir.cleanup() + + def test_check_file_is_dir(self): + dir_path = os.path.join(self.temp_path, "dir") + os.mkdir(dir_path) + with self.assertRaises(MsitException): + MsitPath(dir_path, PathConst.FILE, "r").check(path_exist=True) + + def test_suffix_check(self): + file_path = os.path.join(self.temp_path, "file.txt") + with open(file_path, "w") as f: + f.write("data") + MsitPath(file_path, PathConst.FILE, "r", 
suffix=".txt").check()
+        with self.assertRaises(MsitException):
+            MsitPath(file_path, PathConst.FILE, "r", suffix=".csv").check()
+
+    def test_file_size_limitation(self):
+        file_path = os.path.join(self.temp_path, "file.txt")
+        with open(file_path, "w") as f:
+            f.write("a" * 1024)
+        with self.assertRaises(MsitException):
+            MsitPath(file_path, PathConst.FILE, "r", size_limitation=512).check()
+
+    def test_path_existence(self):
+        non_existent = os.path.join(self.temp_path, "nonexistent.txt")
+        with self.assertRaises(MsitException):
+            MsitPath(non_existent, PathConst.FILE, "r").check(path_exist=True)
+        MsitPath(non_existent, PathConst.FILE, "w").check(path_exist=False)
+
+
+class TestDecorators(unittest.TestCase):
+    @staticmethod
+    @_load_file("r", None, ".txt", True)
+    def dummy_load(f):
+        return f.read()
+
+    def test_load_file_decorator_file_not_found(self):
+        with self.assertRaises(MsitException):
+            self.dummy_load("nonexistent.txt")
+
+    @staticmethod
+    @_save_file("w", None, ".txt", True)
+    def dummy_save(data, f):
+        f.write(data)
+
+    @patch("msprobe.utils.path.MsitPath._check_write_permission_for_group_others")
+    def test_save_file_decorator_success(self, mock_check):
+        mock_check.return_value = None
+        temp_file = tempfile.NamedTemporaryFile(delete=False)
+        temp_path = temp_file.name
+        temp_file.close()
+        self.dummy_save("data", temp_path)
+        with open(temp_path, "r") as f:
+            self.assertEqual(f.read(), "data")
+        os.unlink(temp_path)
+
+    @patch("builtins.open")
+    def test_save_file_decorator_permission_error(self, mock_open):
+        mock_open.side_effect = PermissionError("Permission denied")
+        with self.assertRaises(MsitException):
+            self.dummy_save("data", "/unauthorized.txt")
+
+
+class TestSaveDirDecorator(unittest.TestCase):
+    @patch("msprobe.utils.io.MsitPath.check")
+    @patch("msprobe.utils.io.change_permission")
+    def test_save_dir_success(self, mock_change_perm, mock_msit_path):
+        mock_path_instance = MagicMock()
+        mock_msit_path.return_value = mock_path_instance
+
+        @_save_dir(dir_size=1024)
+        def test_func(data, path, *args, **kwargs):
+            return
+
+        result = test_func("test_data", "/test/path")
+        mock_msit_path.assert_called_once_with(path_exist=False)
+        mock_change_perm.assert_called_once_with(mock_path_instance, AUTHORITY_DIR)
+        self.assertEqual(result, None)
+
+    @patch("msprobe.utils.io.MsitPath.check")
+    @patch("msprobe.utils.io.change_permission")
+    def test_save_dir_exception_handling(self, mock_change_perm, mock_msit_path):
+        mock_path_instance = MagicMock()
+        mock_msit_path.return_value = mock_path_instance
+
+        @_save_dir(dir_size=2048)
+        def failing_func(data, path, *args, **kwargs):
+            raise MsitException("Test error")
+
+        with self.assertRaises(MsitException) as cm:
+            failing_func("test_data", "/failing/path")
+        self.assertIn(MsgConst.IO_FAILURE, cm.exception.error_msg)
+        mock_change_perm.assert_not_called()
+
+
+class TestLoadNpyFromBuffer(unittest.TestCase):
+    def test_load_valid_buffer(self):
+        expected_array = np.array([1, 2, 3, 4], dtype=np.int32)
+        raw_data = expected_array.tobytes()
+        result = load_npy_from_buffer(raw_data, dtype=np.int32, shape=(4,))
+        np.testing.assert_array_equal(result, expected_array)
+
+    def test_invalid_dtype(self):
+        test_data = np.array([1, 2, 3, 4], dtype=np.int32)
+        raw_data = test_data.tobytes()
+        with self.assertRaises(MsitException) as cm:
+            load_npy_from_buffer(raw_data, dtype=np.float64, shape=(4,))
+        self.assertIn(MsgConst.IO_FAILURE, cm.exception.error_msg)
+        self.assertIsInstance(cm.exception.__cause__, ValueError)
+
+    def test_mismatched_shape(self):
+        test_data = np.array([1, 2, 3, 4], dtype=np.int32)
+        raw_data = test_data.tobytes()
+        with self.assertRaises(MsitException) as cm:
+            load_npy_from_buffer(raw_data, dtype=np.int32, shape=(2, 3))
+        self.assertIn(MsgConst.IO_FAILURE, cm.exception.error_msg)
+        self.assertIn("reshape", str(cm.exception.__cause__).lower())
+
+    def test_invalid_raw_data_type(self):
+        with self.assertRaises(MsitException) as cm:
+            
load_npy_from_buffer("invalid_data", dtype=np.int32, shape=(1,)) + + self.assertIn(MsgConst.IO_FAILURE, cm.exception.error_msg) + self.assertIsInstance(cm.exception.__cause__, TypeError) + + +class TestPermissionManagement(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.temp_file = tempfile.NamedTemporaryFile(delete=False) + self.temp_file.close() + + def tearDown(self): + os.unlink(self.temp_file.name) + self.temp_dir.cleanup() + + @patch("os.chmod") + def test_change_permission_file(self, mock_chmod): + change_permission(self.temp_file.name, 0o644) + mock_chmod.assert_called_once_with(self.temp_file.name, 0o644) + + @patch("os.chmod") + def test_change_permission_dir(self, mock_chmod): + change_permission(self.temp_dir.name, 0o755) + mock_chmod.assert_called_once_with(self.temp_dir.name, 0o755) + + +class TestModelLoading(unittest.TestCase): + def setUp(self): + self.mock_onnx = MagicMock() + self.mock_ort = MagicMock() + self.mock_caffe = MagicMock() + self.mock_tf = MagicMock() + self.mock_rewriter_config = MagicMock() + self.mock_convert_vars = MagicMock() + self.mock_graph = MagicMock() + self.mock_session = MagicMock() + self.mock_saved_model = MagicMock() + self.mock_graph_def = MagicMock() + + dependent._dependencies["onnx"] = self.mock_onnx + dependent._dependencies["onnxruntime"] = self.mock_ort + dependent._dependencies["caffe"] = self.mock_caffe + dependent._dependencies["tensorflow"] = self.mock_tf + dependent._dependencies["tensorflow/RewriterConfig"] = self.mock_rewriter_config + dependent._dependencies["tensorflow/convert_variables_to_constants"] = self.mock_convert_vars + + self.mock_tf.compat.v1.Graph.return_value = self.mock_graph + self.mock_tf.compat.v1.Session.return_value = self.mock_session + self.mock_tf.compat.v1.saved_model.loader.load.return_value = self.mock_saved_model + self.mock_tf.compat.v1.gfile.GFile.return_value.read.return_value = b"proto_data" + 
self.mock_tf.compat.v1.GraphDef.return_value = self.mock_graph_def
+
+    def tearDown(self):
+        dependent._dependencies.clear()
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_onnx_model(self, mock_check):
+        mock_check.return_value = "dummy.onnx"
+        mock_model = MagicMock()
+        self.mock_onnx.load_model.return_value = mock_model
+        result = load_onnx_model("dummy.onnx")
+        mock_check.assert_called_once()
+        self.mock_onnx.load_model.assert_called_once_with("dummy.onnx")
+        self.assertEqual(result, mock_model)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_onnx_session(self, mock_check):
+        mock_check.return_value = "dummy.onnx"
+        mock_session = MagicMock()
+        self.mock_ort.InferenceSession.return_value = mock_session
+        result = load_onnx_session("dummy.onnx", provider="CPUExecutionProvider")
+        mock_check.assert_called_once()
+        self.mock_ort.InferenceSession.assert_called_once_with(
+            "dummy.onnx", sess_options=self.mock_ort.SessionOptions(), providers=["CPUExecutionProvider"]
+        )
+        self.assertEqual(result, mock_session)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_caffe_model(self, mock_check):
+        mock_check.return_value = "model.prototxt"
+        mock_net = MagicMock()
+        self.mock_caffe.Net.return_value = mock_net
+        result = load_caffe_model("model.prototxt", "weights.caffemodel")
+        mock_check.assert_called_once()
+        self.mock_caffe.Net.assert_called_once_with("model.prototxt", "weights.caffemodel", self.mock_caffe.TEST)
+        self.assertEqual(result, mock_net)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    @patch("msprobe.utils.dependencies.dependent.get")
+    def test_save_small_model(self, mock_dependent_get, mock_check):
+        mock_check.return_value = "model.onnx"
+        mock_onnx = MagicMock()
+        mock_onnx.save_model = MagicMock()
+        mock_dependent_get.return_value = mock_onnx
+
+        mock_model = MagicMock()
+        mock_model.ByteSize.return_value = PathConst.SIZE_2G - 1
+        save_onnx_model(mock_model, "model.onnx")
+        mock_check.assert_called_once()
+        mock_dependent_get.assert_called_once_with("onnx")
+        mock_onnx.save_model.assert_called_once_with(mock_model, "model.onnx", save_as_external_data=False)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    @patch("msprobe.utils.dependencies.dependent.get")
+    def test_save_large_model(self, mock_dependent_get, mock_check):
+        mock_check.return_value = "large_model.onnx"
+        mock_onnx = MagicMock()
+        mock_dependent_get.return_value = mock_onnx
+        mock_model = MagicMock()
+        mock_model.ByteSize.return_value = PathConst.SIZE_2G + 1
+        save_onnx_model(mock_model, "large_model.onnx")
+        mock_check.assert_called_once()
+        mock_onnx.save_model.assert_called_once_with(mock_model, "large_model.onnx", save_as_external_data=True)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    @patch("msprobe.utils.dependencies.dependent.get")
+    def test_onnx_dependency_missing(self, mock_dependent_get, mock_check):
+        mock_check.return_value = "model.onnx"
+        mock_dependent_get.return_value = None
+        mock_model = MagicMock()
+        with self.assertRaises(MsitException) as ctx:
+            save_onnx_model(mock_model, "model.onnx")
+        self.assertIn("using . Please check permissions or disk space.", str(ctx.exception))
+
+    @patch("numpy.load")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_npy(self, mock_check, mock_np_load):
+        mock_check.return_value = "data.npy"
+        mock_data = MagicMock()
+        mock_np_load.return_value = mock_data
+        result = load_npy("data.npy")
+        mock_check.assert_called_once()
+        mock_np_load.assert_called_once_with("data.npy", allow_pickle=False)
+        np.testing.assert_array_equal(result, mock_data)
+
+    @patch("numpy.save")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_save_npy(self, mock_check, mock_np_save):
+        mock_check.return_value = "save.npy"
+        data = np.array([1, 2, 3])
+        save_npy(data, "save.npy")
+        mock_check.assert_called_once()
+        mock_np_save.assert_called_once_with("save.npy", data)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_saved_model(self, mock_check):
+        mock_check.return_value = "saved_model"
+        result_model, result_sess = load_saved_model("saved_model", ["serve"])
+        tf_module, rewriter_config, convert_vars = dependent.get_tensorflow()
+        self.assertIsNotNone(tf_module)
+        self.assertIsNotNone(rewriter_config)
+        self.assertIsNotNone(convert_vars)
+        self.mock_tf.compat.v1.reset_default_graph.assert_called_once()
+        self.mock_tf.compat.v1.Graph.assert_called_once()
+        self.mock_tf.compat.v1.Session.assert_called_once_with(graph=self.mock_graph)
+        self.mock_tf.compat.v1.saved_model.loader.load.assert_called_once_with(
+            self.mock_session, set(["serve"]), "saved_model"
+        )
+        self.assertEqual(result_model, self.mock_saved_model)
+        self.assertEqual(result_sess, self.mock_session)
+
+    @patch("msprobe.utils.path.MsitPath.check", side_effect=MsitException("File not found"))
+    def test_load_onnx_model_failure(self, mock_check):
+        with self.assertRaises(MsitException):
+            load_onnx_model("invalid.onnx")
+
+    def test_load_caffe_model_no_dependency(self):
+        dependent._dependencies["caffe"] = None
+        with self.assertRaises(MsitException):
+            result = load_caffe_model("model.prototxt", "weights.caffemodel")
+            self.assertIsNone(result)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_pb_frozen_graph_model_success(self, mock_check):
+        mock_check.return_value = "model.pb"
+        result = load_pb_frozen_graph_model("model.pb")
+        self.mock_tf.compat.v1.gfile.GFile.assert_called_once_with("model.pb", "rb")
+        self.mock_graph_def.ParseFromString.assert_called_once_with(b"proto_data")
+        self.assertEqual(result, self.mock_graph_def)
+
+    def test_load_pb_frozen_graph_model_no_tf(self):
+        dependent._dependencies["tensorflow"] = None
+        with self.assertRaises(MsitException):
+            result = load_pb_frozen_graph_model("model.pb")
+            self.assertIsNone(result)
+
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_save_pb_frozen_graph_model(self, mock_check):
+        mock_gfile = MagicMock()
+        mock_gfile_instance = MagicMock()
+        mock_gfile.__enter__.return_value = mock_gfile_instance
+        self.mock_tf.io.gfile.GFile.return_value = mock_gfile
+        mock_check.return_value = "save.pb"
+        mock_frozen_graph = b"dummy_frozen_graph_data"
+        save_pb_frozen_graph_model(mock_frozen_graph, "save.pb")
+        self.mock_tf.io.gfile.GFile.assert_called_once_with("save.pb", "wb")
+        mock_gfile_instance.write.assert_called_once_with(mock_frozen_graph)
+
+
+class TestBinFileOperations(unittest.TestCase):
+    @patch("msprobe.utils.io.np.fromfile")
+    @patch("msprobe.utils.io.get_file_size")
+    @patch("msprobe.utils.path.MsitPath.check")
+    @patch("builtins.open", new_callable=mock_open)
+    def test_load_bin_float32_with_valid_size(self, mock_open_file, mock_check, mock_get_size, mock_fromfile):
+
+        mock_check.return_value = "data.bin"
+        mock_get_size.return_value = 8
+
+        mock_fromfile.side_effect = [
+            np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16),
+            np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
+        ]
+
+        result = load_bin_data("data.bin", dtype=np.float32, shape=(2, 2))
+
+        mock_get_size.assert_called_once_with("data.bin")
+        mock_fromfile.assert_any_call("data.bin", dtype=np.float16)
+        self.assertEqual(result.dtype, np.float32)
+        np.testing.assert_array_equal(result, np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
+
+    @patch("msprobe.utils.io.np.fromfile")
+    @patch("msprobe.utils.io.get_file_size")
+    @patch("msprobe.utils.io.MsitPath.check")
+    @patch("builtins.open", new_callable=mock_open)
+    def test_load_bin_float32_with_invalid_size(self, mock_open_file, mock_check, mock_get_size, mock_fromfile):
+        mock_check.return_value = "data.bin"
+        mock_get_size.return_value = 10
+        mock_fromfile.return_value = np.array([1.0, 2.0], dtype=np.float32)
+        result = load_bin_data("data.bin", dtype=np.float32, shape=(2, 2))
+        mock_fromfile.assert_called_once_with("data.bin", dtype=np.float32)
+        self.assertEqual(result.dtype, np.float32)
+
+
+class TestSavedModelToPb(unittest.TestCase):
+    def setUp(self):
+        self.mock_tf = MagicMock()
+        self.mock_rewriter_config = MagicMock()
+        self.mock_sm2pb = MagicMock()
+        self.mock_sess = MagicMock()
+        self.mock_meta_graph = MagicMock()
+
+        self.mock_tf.compat.v1.saved_model.loader.load.return_value = self.mock_meta_graph
+        dependent._dependencies["tensorflow"] = self.mock_tf
+        dependent._dependencies["tensorflow/RewriterConfig"] = self.mock_rewriter_config
+        dependent._dependencies["tensorflow/convert_variables_to_constants"] = self.mock_sm2pb
+
+    @patch("msprobe.utils.io.load_saved_model")
+    @patch("msprobe.utils.io.save_pb_frozen_graph_model")
+    def test_savedmodel2pb_success(self, mock_save_pb, mock_load_model):
+        mock_load_model.return_value = (self.mock_meta_graph, self.mock_sess)
+        mock_signature = MagicMock()
+        self.mock_meta_graph.signature_def.get.return_value = mock_signature
+        mock_signature.inputs = {"input": MagicMock(name="input:0")}
+        mock_signature.outputs = {"output": MagicMock(name="output:0")}
+
+        result = savedmodel2pb("model_dir", ["serve"], "serving_default", "output_dir")
+        self.mock_sm2pb.assert_called_once()
+        mock_save_pb.assert_called_once()
+        self.assertIn("model_dir.pb", result)
+
+    def 
test_savedmodel2pb_signature_not_found(self):
+        with self.assertRaises(MsitException):
+            savedmodel2pb("model_dir", ["serve"], "invalid_signature", "output_dir")
+
+
+class TestYamlJsonOperations(unittest.TestCase):
+    @patch("yaml.safe_load")
+    @patch("msprobe.utils.io.SafelyOpen")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_yaml(self, mock_check, mock_safely_open, mock_yaml_load):
+        mock_check.return_value = "dummy.yaml"
+        mock_file = MagicMock()
+        mock_safely_open.return_value.__enter__.return_value = mock_file
+        mock_yaml_load.return_value = {"key": "value"}
+        result = load_yaml("dummy.yaml")
+        mock_safely_open.assert_called_once_with("dummy.yaml", "r", PathConst.SIZE_500M, PathConst.SUFFIX_YAML, "utf-8")
+        mock_yaml_load.assert_called_once_with(mock_file)
+        self.assertEqual(result, {"key": "value"})
+
+    @patch("yaml.dump")
+    @patch("msprobe.utils.io.SafelyOpen")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_save_yaml(self, mock_check, mock_safely_open, mock_yaml_dump):
+        mock_check.return_value = "save.yaml"
+        mock_file = MagicMock()
+        mock_safely_open.return_value.__enter__.return_value = mock_file
+        save_yaml({"key": "value"}, "save.yaml")
+        mock_safely_open.assert_called_once_with("save.yaml", "w", None, PathConst.SUFFIX_YAML, path_exist=False)
+        mock_yaml_dump.assert_called_once_with({"key": "value"}, mock_file)
+
+    @patch("json.load")
+    @patch("msprobe.utils.io.SafelyOpen")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_json(self, mock_check, mock_safely_open, mock_json_load):
+        mock_check.return_value = "data.json"
+        mock_file = MagicMock()
+        mock_safely_open.return_value.__enter__.return_value = mock_file
+        mock_json_load.return_value = {"name": "test"}
+        result = load_json("data.json")
+        mock_safely_open.assert_called_once_with("data.json", "r", PathConst.SIZE_2G, PathConst.SUFFIX_JSON, "utf-8")
+        self.assertEqual(result, {"name": "test"})
+
+    @patch("json.dump")
+    @patch("msprobe.utils.io.SafelyOpen")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_save_json(self, mock_check, mock_safely_open, mock_json_dump):
+        mock_check.return_value = "save.json"
+        mock_file = MagicMock()
+        mock_safely_open.return_value.__enter__.return_value = mock_file
+        save_json({"id": 1}, "save.json")
+        mock_safely_open.assert_called_once_with("save.json", "w", None, PathConst.SUFFIX_JSON, path_exist=False)
+        mock_json_dump.assert_called_once_with({"id": 1}, mock_file, indent=None, default=str)
+
+
+class TestCsvOperations(unittest.TestCase):
+    @patch("csv.reader")
+    @patch("msprobe.utils.io.SafelyOpen")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_csv_by_builtin(self, mock_check, mock_safely_open, mock_csv_reader):
+        mock_check.return_value = "data.csv"
+        mock_file = MagicMock()
+        mock_safely_open.return_value.__enter__.return_value = mock_file
+        mock_csv_reader.return_value = [["a", "1"], ["b", "2"]]
+        result = load_csv_by_builtin("data.csv")
+        mock_safely_open.assert_called_once_with(
+            "data.csv", "r", PathConst.SIZE_500M, PathConst.SUFFIX_CSV, "utf-8-sig"
+        )
+        self.assertEqual(result, [["a", "1"], ["b", "2"]])
+
+    @patch("pandas.read_csv")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_load_csv_by_pandas(self, mock_check, mock_pd_read):
+        mock_check.return_value = "data.csv"
+        mock_df = pd.DataFrame({"col1": ["a", "b"], "col2": [1, 2]})
+        mock_pd_read.return_value = mock_df
+        result = load_csv_by_pandas("data.csv")
+        pd.testing.assert_frame_equal(result, mock_df)
+
+    @patch("pandas.DataFrame.to_csv")
+    @patch("msprobe.utils.path.MsitPath.check")
+    def test_save_csv_by_pandas(self, mock_check, mock_to_csv):
+        mock_check.return_value = "save.csv"
+        df = pd.DataFrame({"A": [1, 2]})
+        save_csv_by_pandas(df, "save.csv")
+        mock_to_csv.assert_called_once_with("save.csv", sep=",", index=False)
+
+
+class TestTorchOperations(unittest.TestCase):
+    def setUp(self):
+        self.mock_torch = MagicMock()
+        dependent._dependencies["torch"] = self.mock_torch
+
+    @patch("msprobe.utils.io.is_input_yes")
+    @patch("msprobe.utils.io.MsitPath.check")
+    def test_load_torch_obj_safe(self, mock_check, mock_input):
+        mock_check.return_value = "model.pt"
+        self.mock_torch.load.side_effect = [pickle.UnpicklingError(), MagicMock()]
+        mock_input.return_value = True
+        result = load_torch_obj("model.pt")
+        self.mock_torch.load.assert_called_with("model.pt", weights_only=False)
diff --git a/accuracy_tools/test/UT/utils_ut/test_log.py b/accuracy_tools/test/UT/utils_ut/test_log.py
new file mode 100644
index 00000000000..d649be5d677
--- /dev/null
+++ b/accuracy_tools/test/UT/utils_ut/test_log.py
@@ -0,0 +1,116 @@
+import time
+import unittest
+from unittest.mock import patch
+
+from msit.lib.msprobe_c import log
+
+from msprobe.utils.log import LOG_LEVEL, MsitLogger, get_current_timestamp, logger, print_log_with_star
+
+
+class TestGetCurrentTimestamp(unittest.TestCase):
+    def test_used_for_log_false_no_microsecond(self):
+        result = get_current_timestamp(microsecond=False)
+        self.assertIsInstance(result, int)
+        self.assertAlmostEqual(result, int(time.time()), delta=1)
+
+    @patch("msprobe.utils.log.perf_counter")
+    def test_used_for_log_false_with_microsecond(self, mock_time):
+        mock_time.return_value = 1620000000.123456
+        expected = round(1620000000.123456 * 1e6) % 10**10
+        result = get_current_timestamp(microsecond=True)
+        self.assertEqual(result, expected)
+
+
+class TestPrintLogWithStar(unittest.TestCase):
+    @patch.object(logger, "info")
+    def test_print_log_with_star_normal(self, mock_info):
+        test_message = "Test Message"
+        print_log_with_star(test_message)
+        self.assertEqual(mock_info.call_count, 3)
+        args_list = [call.args[0] for call in mock_info.call_args_list]
+        self.assertEqual(args_list[0], "*" * 80)
+        self.assertEqual(args_list[2], "*" * 80)
+        middle_line = args_list[1]
+        self.assertEqual(len(middle_line), 80)
+        self.assertTrue(middle_line.startswith("*"))
+        self.assertTrue(middle_line.endswith("*"))
+        expected_content = 
f"*{test_message.center(78)}*" + self.assertEqual(middle_line, expected_content) + + @patch.object(logger, "info") + def test_print_log_with_star_long_message(self, mock_info): + test_message = "A" * 79 + print_log_with_star(test_message) + middle_line = mock_info.call_args_list[1].args[0] + self.assertEqual(len(middle_line), 81) + + +class TestMsitLogger(unittest.TestCase): + def setUp(self): + MsitLogger._instance = None + self.logger = MsitLogger() + + def tearDown(self): + MsitLogger._instance = None + + def test_get_level_id_valid(self): + for idx, level in enumerate(LOG_LEVEL): + self.assertEqual(MsitLogger.get_level_id(level), idx, f"Failed for level: {level}") + + def test_get_level_id_case_insensitive(self): + self.assertEqual(MsitLogger.get_level_id("debug"), LOG_LEVEL.index("DEBUG")) + + def test_get_level_id_invalid(self): + self.assertEqual(MsitLogger.get_level_id("INVALID_LEVEL"), LOG_LEVEL.index("INFO")) + + def test_set_level_valid(self): + test_levels = ["ERROR", "WARNING", "DEBUG", "INFO"] + for level in test_levels: + with self.subTest(level=level): + self.logger.set_level(level) + self.assertEqual(log.get_log_level(), LOG_LEVEL.index(level)) + + def test_set_level_invalid(self): + self.logger.set_level("INVALID_LEVEL") + self.assertEqual(log.get_log_level(), LOG_LEVEL.index("INFO")) + + @patch.object(log, "print_log") + def test_error_log_when_level_allows(self, mock_print): + self.logger.set_level("ERROR") + test_msg = "Test error message" + self.logger.error(test_msg) + mock_print.assert_called_once_with(LOG_LEVEL.index("ERROR"), test_msg) + + @patch.object(log, "print_log") + def test_error_log_when_level_denies(self, mock_print): + self.logger.set_level("WARNING") + self.logger.error("Should print") + mock_print.assert_called() + mock_print.reset_mock() + self.logger.set_level("INVALID_LEVEL") + self.logger.error("Should ALSO print") + mock_print.assert_called() + + @patch.object(log, "print_log") + def test_error_special_char_filter(self, 
mock_print): + test_msg = "Bad\nmessage\twith\rspecial" + expected_msg = "Bad_message_with_special" + + self.logger.error(test_msg) + mock_print.assert_called_once_with(LOG_LEVEL.index("ERROR"), test_msg) + + @patch.object(log, "print_log") + def test_debug_log_when_level_allows(self, mock_print): + self.logger.set_level("DEBUG") + test_msg = "Debug message" + self.logger.debug(test_msg) + mock_print.assert_called_once_with(LOG_LEVEL.index("DEBUG"), test_msg) + + @patch.object(log, "print_log") + def test_debug_special_char_filter(self, mock_print): + test_msg = f"Special\tchars" + expected_msg = "Special_chars" + + self.logger.set_level("DEBUG") + self.logger.debug(test_msg) + mock_print.assert_called_once_with(LOG_LEVEL.index("DEBUG"), test_msg) diff --git a/accuracy_tools/test/UT/utils_ut/test_path.py b/accuracy_tools/test/UT/utils_ut/test_path.py new file mode 100644 index 00000000000..54d238dbe01 --- /dev/null +++ b/accuracy_tools/test/UT/utils_ut/test_path.py @@ -0,0 +1,660 @@ +import os +import tempfile +import unittest +from collections import namedtuple +from unittest.mock import MagicMock, patch + +from msprobe.utils.constants import PathConst +from msprobe.utils.exceptions import MsitException +from msprobe.utils.path import ( + _MAX_DIR_DEPTH, + _MAX_LAST_NAME_LENGTH, + _MAX_PATH_LENGTH, + _MODE, + AUTHORITY_DIR, + SOFT_LINK_LEVEL_IGNORE, + SOFT_LINK_LEVEL_STRICT, + SOFT_LINK_LEVEL_WARNING, + MsitPath, + change_permission, + convert_bytes, + get_basename_from_path, + get_dir_size, + get_file_size, + get_name_and_ext, + is_dir, + is_enough_disk_space, + is_file, + is_saved_model_scene, + join_path, + make_dir, +) + + +class TestIsFile(unittest.TestCase): + def test_existing_file(self): + with tempfile.NamedTemporaryFile() as tmp: + self.assertTrue(is_file(tmp.name)) + + def test_non_existing_path(self): + self.assertFalse(is_file("/non/existent/path")) + + def test_directory(self): + with tempfile.TemporaryDirectory() as tmpdir: + 
self.assertFalse(is_file(tmpdir)) + + +class TestIsDir(unittest.TestCase): + def test_existing_dir(self): + with tempfile.TemporaryDirectory() as tmpdir: + self.assertTrue(is_dir(tmpdir)) + + def test_non_existing_path(self): + self.assertFalse(is_dir("/non/existent/path")) + + def test_file(self): + with tempfile.NamedTemporaryFile() as tmp: + self.assertFalse(is_dir(tmp.name)) + + +class TestGetBasenameFromPath(unittest.TestCase): + def test_normal_path(self): + self.assertEqual(get_basename_from_path("/path/to/file.txt"), "file.txt") + + def test_trailing_slash(self): + self.assertEqual(get_basename_from_path("/path/to/dir/"), "dir") + + def test_root_path(self): + self.assertEqual(get_basename_from_path("/"), "") + + +class TestGetFileSize(unittest.TestCase): + def test_file_size(self): + with tempfile.NamedTemporaryFile() as tmp: + content = b"12345" + tmp.write(content) + tmp.flush() + self.assertEqual(get_file_size(tmp.name), len(content)) + + def test_non_existing_file(self): + with self.assertRaises(FileNotFoundError): + get_file_size("/invalid/path") + + +class TestGetNameAndExt(unittest.TestCase): + def test_with_extension(self): + self.assertEqual(get_name_and_ext("/path/to/file.txt"), ("file", ".txt")) + + def test_multiple_dots(self): + self.assertEqual(get_name_and_ext("file.tar.gz"), ("file.tar", ".gz")) + + def test_no_extension(self): + self.assertEqual(get_name_and_ext("/path/to/file"), ("file", "")) + + +class TestJoinPath(unittest.TestCase): + def test_basic_join(self): + self.assertEqual(join_path("a", "b", "c"), os.path.join("a", "b", "c")) + + def test_nested_iterables(self): + self.assertEqual(join_path(["a", ["b", "c"]], "d"), os.path.join("a", "b", "c", "d")) + + def test_max_depth_exceeded(self): + deep_nested = ["a", ["b", ["c"]]] + with self.assertRaises(MsitException) as e: + join_path(deep_nested, max_depth=2) + self.assertIn("Maximum recursion depth 2 exceeded", str(e.exception)) + + def test_invalid_max_depth_type(self): + with 
self.assertRaises(MsitException) as e: + join_path("a", max_depth="invalid") + self.assertIn("max_depth must be a positive integer.", str(e.exception)) + + +class TestIsSavedModelScene(unittest.TestCase): + def create_valid_structure(self, path): + os.makedirs(os.path.join(path, "variables")) + with open(os.path.join(path, "saved_model.pb"), "w") as f: + f.write("") + + def test_valid_model(self): + with tempfile.TemporaryDirectory() as tmpdir: + self.create_valid_structure(tmpdir) + self.assertTrue(is_saved_model_scene(tmpdir)) + + def test_missing_pb(self): + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, "variables")) + self.assertFalse(is_saved_model_scene(tmpdir)) + + def test_missing_variables(self): + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "saved_model.pb"), "w") as f: + f.write("") + self.assertFalse(is_saved_model_scene(tmpdir)) + + +class TestConvertBytes(unittest.TestCase): + def test_bytes(self): + self.assertEqual(convert_bytes(500), "500 Bytes") + + def test_kb(self): + self.assertEqual(convert_bytes(2048), "2.00 KB") + + def test_mb(self): + self.assertEqual(convert_bytes(2 * 1024 * 1024), "2.00 MB") + + def test_gb(self): + self.assertEqual(convert_bytes(3 * 1024 * 1024 * 1024), "3.00 GB") + + def test_zero(self): + self.assertEqual(convert_bytes(0), "0 Bytes") + + +class TestMsitPathInitialization(unittest.TestCase): + def test_valid_initialization(self): + msit_path = MsitPath( + path="/tmp/test", path_type=PathConst.FILE, mode="r", size_limitation=1024, suffix=".txt", max_dir_depth=10 + ) + self.assertEqual(msit_path.path, "/tmp/test") + + def test_invalid_path_type(self): + with self.assertRaisesRegex(MsitException, "path type must be one of") as e: + MsitPath("/tmp", "invalid_type", "r") + self.assertIn("The path type must be one of", str(e.exception)) + + def test_invalid_mode(self): + with self.assertRaisesRegex(MsitException, "Mode must be one of") as e: + 
MsitPath("/tmp", PathConst.DIR, "invalid_mode") + self.assertIn("Mode must be one of ['r', 'rb', 'w', 'wb', 'a', 'ab', 'a+', 'e']", str(e.exception)) + + def test_negative_size_limitation(self): + with self.assertRaisesRegex(MsitException, "greater than 0") as e: + MsitPath("/tmp", PathConst.FILE, "r", size_limitation=-1) + self.assertIn("The value must be an integer greater than 0, currently: -1.", str(e.exception)) + + +class TestMsitPath(unittest.TestCase): + def test_check_path_type_valid(self): + self.assertEqual(MsitPath._check_path_type(PathConst.FILE), PathConst.FILE) + self.assertEqual(MsitPath._check_path_type(PathConst.DIR), PathConst.DIR) + + def test_check_path_type_invalid(self): + with self.assertRaises(MsitException) as e: + MsitPath._check_path_type("invalid_type") + self.assertIn("The path type must be one of ", str(e.exception)) + + def test_check_mode_valid(self): + for mode in _MODE: + self.assertEqual(MsitPath._check_mode(mode), mode) + + def test_check_mode_invalid(self): + with self.assertRaises(MsitException): + MsitPath._check_mode("invalid_mode") + + def test_check_positive_int_valid(self): + self.assertEqual(MsitPath._check_positive_int(5), 5) + + def test_check_positive_int_invalid(self): + with self.assertRaises(MsitException): + MsitPath._check_positive_int(0) + with self.assertRaises(MsitException): + MsitPath._check_positive_int(-1) + with self.assertRaises(MsitException): + MsitPath._check_positive_int("not_int") + + @patch("msit.utils.path.os.path.getsize") + @patch("msit.utils.path.os.path.exists") + @patch("msit.utils.path.os.path.islink") + @patch("msit.utils.path.os.path.realpath") + @patch("msit.utils.path.os.stat") + @patch("msit.utils.path.is_file") + @patch("msit.utils.path.is_dir") + def test_check_existing_file( + self, mock_is_dir, mock_is_file, mock_stat, mock_realpath, mock_islink, mock_exists, mock_getsize + ): + mock_exists.return_value = True + mock_is_file.return_value = True + mock_is_dir.return_value = False + 
mock_islink.return_value = False + mock_realpath.return_value = "/valid/path/file.txt" + mock_getsize.return_value = 512 + stat_mock = MagicMock() + stat_mock.st_uid = 0 + stat_mock.st_mode = 0o755 + mock_stat.return_value = stat_mock + msit_path = MsitPath("/valid/path/file.txt", PathConst.FILE, "r", size_limitation=1024, suffix=".txt") + result = msit_path.check() + self.assertEqual(result, "/valid/path/file.txt") + mock_getsize.assert_called_once_with("/valid/path/file.txt") + + @patch("os.path.exists") + @patch("os.path.islink") + @patch("os.path.realpath") + @patch("msit.utils.path.is_dir") + @patch("os.stat") + def test_check_write_mode_new_file(self, mock_stat, mock_is_dir, mock_realpath, mock_islink, mock_exists): + mock_exists.side_effect = lambda x: False if x == "/new/file.txt" else True + mock_is_dir.return_value = True + mock_islink.return_value = False + mock_realpath.return_value = "/valid/parent" + + mock_stat_result = MagicMock() + mock_stat_result.st_uid = os.geteuid() + mock_stat_result.st_mode = 0o755 + mock_stat.return_value = mock_stat_result + + msit_path = MsitPath("/new/file.txt", PathConst.FILE, "w") + result = msit_path.check(path_exist=False) + self.assertTrue(result.endswith("/new/file.txt")) + + @patch("os.path.abspath") + @patch("os.path.normpath") + @patch("os.path.exists") + @patch("os.path.islink") + @patch("os.path.realpath") + @patch("msit.utils.path.is_file") + @patch("os.stat") + @patch("os.path.getsize") + def test_soft_link_validation( + self, + mock_getsize, + mock_stat, + mock_is_file, + mock_realpath, + mock_islink, + mock_exists, + mock_normpath, + mock_abspath, + ): + mock_abspath.side_effect = lambda x: x + mock_normpath.side_effect = lambda x: x + + mock_stat_result = MagicMock() + mock_stat_result.st_uid = os.geteuid() + mock_stat_result.st_mode = 0o755 + mock_stat.return_value = mock_stat_result + + mock_is_file.return_value = True + mock_exists.return_value = True + mock_islink.return_value = True + 
mock_realpath.return_value = "/real/path" + mock_getsize.return_value = 512 + msit_path = MsitPath("/symlink/path", PathConst.FILE, "r") + with self.assertRaises(MsitException) as e: + msit_path.check(soft_link_level=SOFT_LINK_LEVEL_STRICT) + self.assertIn("is a symlink. Usage prohibited.", str(e.exception)) + mock_islink.assert_called_with("/symlink/path") + mock_realpath.assert_called_with("/symlink/path") + + @patch("os.path.exists") + @patch("os.path.islink") + @patch("os.path.realpath") + def test_soft_link_non_validation(self, mock_realpath, mock_islink, mock_exists): + mock_exists.return_value = True + mock_islink.return_value = True + mock_realpath.return_value = "/real/path" + msit_path = MsitPath("/symlink/path", PathConst.FILE, "r") + with self.assertRaises(MsitException) as context: + msit_path.check(soft_link_level=4) + self.assertIn("The validation level of symbolic links must be one of ", str(context.exception)) + + @patch("msit.utils.path.is_file") + @patch("msit.utils.path.is_dir") + @patch("os.path.exists") + @patch("os.path.islink") + @patch("os.path.realpath") + @patch("os.stat") + def test_soft_link_ignore_validation( + self, mock_stat, mock_realpath, mock_islink, mock_exists, mock_is_dir, mock_is_file + ): + mock_stat_result = MagicMock() + mock_stat_result.st_uid = os.geteuid() + mock_stat_result.st_mode = 0o755 + mock_stat.return_value = mock_stat_result + mock_exists.return_value = True + mock_islink.return_value = True + mock_realpath.return_value = "/real/path" + mock_is_file.return_value = True + mock_is_dir.return_value = False + msit_path = MsitPath("/symlink/path", PathConst.FILE, "r") + result = msit_path.check(soft_link_level=SOFT_LINK_LEVEL_IGNORE) + self.assertEqual(result, "/real/path") + mock_is_file.assert_called_with("/real/path") + + @patch("msit.utils.path.is_file") + @patch("os.path.exists") + @patch("os.path.islink") + @patch("os.path.realpath") + @patch("os.stat") + def test_path_length_validation(self, mock_stat, 
mock_realpath, mock_islink, mock_exists, mock_is_file): + mock_stat_result = MagicMock() + mock_stat_result.st_uid = os.geteuid() + mock_stat_result.st_mode = 0o755 + mock_stat.return_value = mock_stat_result + mock_realpath.return_value = "/real/path" + mock_islink.return_value = False + mock_exists.return_value = True + mock_is_file.return_value = True + long_path = "/" + "a" * (_MAX_PATH_LENGTH + 1) + msit_path = MsitPath(long_path, PathConst.FILE, "r") + with self.assertRaises(MsitException) as e: + msit_path.check(path_exist=False) + self.assertIn("Current path length (4098) exceeds the limit (4096).", str(e.exception)) + + @patch("msit.utils.path.is_file") + @patch("os.stat") + @patch("os.geteuid") + def test_permission_validation(self, mock_geteuid, mock_stat, mock_is_file): + mock_geteuid.return_value = 1000 + mock_stat_result = MagicMock() + mock_stat_result.st_uid = 1000 + mock_stat_result.st_mode = 0o777 + mock_stat.return_value = mock_stat_result + mock_is_file.return_value = True + msit_path = MsitPath("/unsafe/path", PathConst.FILE, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("Permissions for files (or directories) should not exceed 0o755 (rwxr-xr-x).", str(e.exception)) + + @patch("msit.utils.path.is_file") + @patch("os.stat") + @patch("os.geteuid") + def test_permission_owner_validation(self, mock_geteuid, mock_stat, mock_is_file): + mock_geteuid.return_value = 1000 + mock_stat_result = MagicMock() + mock_stat_result.st_uid = 500 + mock_stat_result.st_mode = 0o777 + mock_stat.return_value = mock_stat_result + mock_is_file.return_value = True + msit_path = MsitPath("/unsafe/path", PathConst.FILE, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("The owner of /unsafe/path must be root or the current user.", str(e.exception)) + + @patch("os.stat") + @patch("os.path.exists") + @patch("msit.utils.path.is_file") + def test_read_permission_denied(self, mock_is_file, mock_exists, 
mock_stat): + mock_exists.return_value = True + mock_is_file.return_value = True + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o300 + mock_stat.return_value = stat_mock + msit_path = MsitPath("/no_read.txt", PathConst.FILE, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("not authorized to read", str(e.exception)) + + @patch("os.stat") + @patch("os.path.exists") + @patch("msit.utils.path.is_file") + def test_read_permission_granted(self, mock_is_file, mock_exists, mock_stat): + mock_exists.return_value = True + mock_is_file.return_value = True + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o400 + mock_stat.return_value = stat_mock + msit_path = MsitPath("/readable.txt", PathConst.FILE, "r") + msit_path.check() + + @patch("os.stat") + @patch("os.path.exists") + @patch("msit.utils.path.is_file") + def test_write_permission_denied(self, mock_is_file, mock_exists, mock_stat): + mock_exists.return_value = True + mock_is_file.return_value = True + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o500 + mock_stat.return_value = stat_mock + msit_path = MsitPath("/no_write.txt", PathConst.FILE, "w") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("not authorized to write", str(e.exception)) + + @patch("os.stat") + @patch("os.path.exists") + @patch("msit.utils.path.is_file") + def test_write_permission_granted(self, mock_is_file, mock_exists, mock_stat): + mock_exists.return_value = True + mock_is_file.return_value = True + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o600 + mock_stat.return_value = stat_mock + msit_path = MsitPath("/writable.txt", PathConst.FILE, "w") + msit_path.check() + + @patch("os.stat") + @patch("os.path.exists") + @patch("msit.utils.path.is_file") + def test_execute_permission_denied(self, mock_is_file, mock_exists, mock_stat): + 
mock_exists.return_value = True + mock_is_file.return_value = True + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o600 + mock_stat.return_value = stat_mock + msit_path = MsitPath("/no_execute.txt", PathConst.FILE, "e") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("not authorized to execute", str(e.exception)) + + @patch("os.stat") + @patch("os.path.exists") + @patch("msit.utils.path.is_file") + def test_execute_permission_granted(self, mock_is_file, mock_exists, mock_stat): + mock_exists.return_value = True + mock_is_file.return_value = True + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o500 + mock_stat.return_value = stat_mock + msit_path = MsitPath("/executable.txt", PathConst.FILE, "e") + msit_path.check() + + @patch("msit.utils.path.is_dir") + @patch("os.path.exists") + @patch("os.stat") + def test_directory_validation(self, mock_stat, mock_exists, mock_is_dir): + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o750 + mock_stat.return_value = stat_mock + mock_exists.return_value = True + mock_is_dir.return_value = True + msit_path = MsitPath("/valid/dir", PathConst.DIR, "r") + msit_path.check() + mock_is_dir.return_value = False + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("is not a directory", str(e.exception)) + + @patch("msit.utils.path.is_dir") + @patch("os.path.exists") + @patch("os.stat") + def test_special_char_validation(self, mock_stat, mock_exists, mock_is_dir): + stat_mock = MagicMock() + stat_mock.st_uid = os.geteuid() + stat_mock.st_mode = 0o750 + mock_stat.return_value = stat_mock + mock_exists.return_value = True + mock_is_dir.return_value = True + msit_path = MsitPath("/valid/123%", PathConst.DIR, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("Path /valid/123% contains special characters.", str(e.exception)) + + 
@patch("os.path.exists") + @patch("msit.utils.path.is_dir") + def test_directory_depth_validation(self, mock_is_dir, mock_exists): + mock_exists.return_value = True + mock_is_dir.return_value = True + over_depth_path = "/level1/" + "/".join([f"level{i}" for i in range(2, _MAX_DIR_DEPTH + 2)]) + msit_path = MsitPath(over_depth_path, PathConst.DIR, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn(f"Exceeded max directory depth ({_MAX_DIR_DEPTH})", str(e.exception)) + + @patch("os.path.exists") + @patch("msit.utils.path.is_file") + def test_filename_length_validation(self, mock_is_file, mock_exists): + mock_exists.return_value = True + mock_is_file.return_value = True + long_name = "a" * (_MAX_LAST_NAME_LENGTH + 1) + invalid_path = f"/normal_dir/{long_name}" + msit_path = MsitPath(invalid_path, PathConst.FILE, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn(f"length ({_MAX_LAST_NAME_LENGTH + 1}) exceeds", str(e.exception)) + + @patch("os.path.exists") + @patch("msit.utils.path.is_dir") + def test_multiple_long_directory_names(self, mock_is_dir, mock_exists): + mock_exists.return_value = True + mock_is_dir.return_value = True + long_dir = "b" * (_MAX_LAST_NAME_LENGTH + 1) + invalid_path = f"/{long_dir}/{long_dir}" + msit_path = MsitPath(invalid_path, PathConst.DIR, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertEqual(str(e.exception).count("exceeds the limit"), 1) + + @patch("os.path.exists") + @patch("msit.utils.path.is_dir") + def test_mixed_error_conditions(self, mock_is_dir, mock_exists): + mock_exists.return_value = True + mock_is_dir.return_value = True + long_dir = "c" * (_MAX_LAST_NAME_LENGTH + 1) + deep_path = "/" + "/".join([long_dir] * (_MAX_DIR_DEPTH + 2)) + msit_path = MsitPath(deep_path, PathConst.DIR, "r") + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("Current path length (8738) exceeds the limit (4096).", 
str(e.exception)) + + @patch("os.path.exists") + @patch("msit.utils.path.is_dir") + @patch("msit.utils.path.get_dir_size") + def test_directory_size_validation(self, mock_get_dir_size, mock_is_dir, mock_exists): + mock_get_dir_size.return_value = 1024 * 1024 * 1024 + mock_is_dir.return_value = True + mock_exists.return_value = True + msit_path = MsitPath("/large/dir", PathConst.DIR, "r", size_limitation=100) + with self.assertRaises(MsitException) as e: + msit_path.check() + self.assertIn("Directory size exceeds the limit (100 Bytes).", str(e.exception)) + + +class TestGetDirSize(unittest.TestCase): + @patch("os.walk") + @patch("os.path.getsize") + def test_get_dir_size_success(self, mock_getsize, mock_walk): + dir_path = "/test" + mock_walk.return_value = [ + (dir_path, ["sub1"], ["file1", "file2"]), + (os.path.join(dir_path, "sub1"), [], ["file3"]), + ] + mock_getsize.side_effect = [100, 200, 300] + result = get_dir_size(dir_path, max_dir_depth=2) + self.assertEqual(result, 600) + + @patch("os.walk") + def test_get_dir_size_exceed_max_depth(self, mock_walk): + dir_path = "/test" + mock_walk.return_value = [ + (dir_path, ["sub1"], []), + (os.path.join(dir_path, "sub1"), ["sub2"], []), + (os.path.join(dir_path, "sub1", "sub2"), [], ["file"]), + ] + with self.assertRaises(MsitException) as cm: + get_dir_size(dir_path, max_dir_depth=1) + self.assertIn("exceeded max depth (1)", str(cm.exception)) + + +class TestMakeDir(unittest.TestCase): + @patch("msit.utils.path.Path") + @patch("msit.utils.path.MsitPath") + def test_make_dir_success(self, mock_msitpath, mock_path): + mock_msit_instance = MagicMock() + mock_msitpath.return_value = mock_msit_instance + mock_msit_instance.check.return_value = "/valid/dir" + mock_path_instance = MagicMock() + mock_path.return_value = mock_path_instance + make_dir("test_dir") + mock_msitpath.assert_called_once_with("test_dir", PathConst.DIR, "w") + mock_msit_instance.check.assert_called_once_with(path_exist=False) + 
mock_path.assert_called_once_with("/valid/dir") + mock_path_instance.mkdir.assert_called_once_with(mode=AUTHORITY_DIR, exist_ok=True, parents=False) + + @patch("msit.utils.path.Path") + @patch("msit.utils.path.MsitPath") + def test_make_dir_oserror_parents(self, mock_msitpath, mock_path): + mock_instance = MagicMock() + mock_msitpath.return_value = mock_instance + mock_instance.check.return_value = "/invalid/parent_dir" + mock_path.return_value.mkdir.side_effect = OSError("Parent missing") + with self.assertRaises(MsitException) as cm: + make_dir("bad_dir") + self.assertIn("Check if the parent directory", str(cm.exception)) + + +class TestChangePermission(unittest.TestCase): + @patch("os.chmod") + @patch("os.path.islink") + @patch("os.path.exists") + def test_change_permission_success(self, mock_exists, mock_islink, mock_chmod): + mock_exists.return_value = True + mock_islink.return_value = False + change_permission("/valid/file", 0o755) + mock_chmod.assert_called_once_with("/valid/file", 0o755) + + @patch("os.chmod") + @patch("os.path.islink") + @patch("os.path.exists") + def test_change_permission_skip_symlink(self, mock_exists, mock_islink, mock_chmod): + mock_exists.return_value = True + mock_islink.return_value = True + change_permission("/symlink", 0o755) + mock_chmod.assert_not_called() + + @patch("os.chmod") + @patch("os.path.islink") + @patch("os.path.exists") + def test_change_permission_permission_error(self, mock_exists, mock_islink, mock_chmod): + mock_exists.return_value = True + mock_islink.return_value = False + mock_chmod.side_effect = PermissionError("Permission denied") + with self.assertRaises(MsitException) as cm: + change_permission("/restricted/file", 0o777) + self.assertIn("Failed to set permissions (511)", str(cm.exception)) + self.assertIn("/restricted/file", str(cm.exception)) + + +class TestDiskSpaceCheck(unittest.TestCase): + @patch("msit.utils.path.disk_usage") + def test_enough_space_returns_true(self, mock_disk_usage: MagicMock): 
+ mock_result = MagicMock() + mock_result.free = 1000 + mock_disk_usage.return_value = mock_result + result = is_enough_disk_space("/test/path", 500) + self.assertTrue(result) + mock_disk_usage.assert_called_once_with("/test/path") + + @patch("msit.utils.path.disk_usage") + def test_not_enough_space_returns_false(self, mock_disk_usage: MagicMock): + mock_result = MagicMock() + mock_result.free = 1000 + mock_disk_usage.return_value = mock_result + result = is_enough_disk_space("/test/path", 1500) + self.assertFalse(result) + + @patch("msit.utils.path.disk_usage") + def test_exact_space_edge_case(self, mock_disk_usage: MagicMock): + mock_result = MagicMock() + mock_result.free = 1000 + mock_disk_usage.return_value = mock_result + result = is_enough_disk_space("/test/path", 1000) + self.assertTrue(result) diff --git a/accuracy_tools/test/UT/utils_ut/test_toolkits.py b/accuracy_tools/test/UT/utils_ut/test_toolkits.py new file mode 100644 index 00000000000..dbfb83292ae --- /dev/null +++ b/accuracy_tools/test/UT/utils_ut/test_toolkits.py @@ -0,0 +1,446 @@ +import re +import unittest +from unittest.mock import MagicMock, Mock, call, patch + +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.toolkits import ( + CHECK_CSV_LEVEL_IGNORE, + CHECK_CSV_LEVEL_REPLACE, + CHECK_CSV_LEVEL_STRICT, + DistBackend, + filter_cmd, + get_net_output_nodes_from_graph_def, + get_rank, + get_valid_name, + is_input_yes, + run_subprocess, + safely_compute, + sanitize_csv_value, + seed_all, + set_ld_preload, + timestamp_sync, +) + + +class TestToolkitsFunctions(unittest.TestCase): + def test_filter_cmd(self): + cmd = ["ls", "-l", "|", "grep", "test", ">", "output.txt"] + with self.assertRaises(MsprobeException) as cm: + filtered_cmd = filter_cmd(cmd) + self.assertIn("The command contains invalid characters.", str(cm.exception)) + + def test_get_valid_name(self): + self.assertEqual(get_valid_name("/test/file.txt"), "test_file_txt") + 
self.assertEqual(get_valid_name("path/to:file"), "path_to_file") + + @patch("msprobe.utils.log.logger.warning") + def test_safely_compute(self, mock_warning): + @safely_compute + def divide(a, b): + return a / b + + self.assertEqual(divide(10, 2), 5) + self.assertIsNone(divide(10, 0)) + mock_warning.assert_called() + + @patch("msprobe.utils.log.logger.info") + @patch("subprocess.Popen") + def test_run_subprocess(self, mock_popen, mock_logger): + mock_process = MagicMock() + mock_process.poll.side_effect = [None, None, 0] + mock_process.communicate.return_value = (b"Output", b"") + mock_process.returncode = 0 + mock_popen.return_value.__enter__.return_value = mock_process + result = run_subprocess(["echo", "hello"], check_interval=0.001, capture_output=True) + self.assertEqual(result, "hello\n") + mock_logger.assert_called() + + @patch("msprobe.utils.log.logger.info") + @patch("subprocess.Popen") + def test_run_subprocess_with_none(self, mock_popen, mock_logger): + mock_process = MagicMock() + mock_process.poll.side_effect = [None, None, 0] + mock_process.communicate.return_value = (b"Output", b"") + mock_process.returncode = 0 + mock_popen.return_value.__enter__.return_value = mock_process + result = run_subprocess(["echo", "hello"], check_interval=0.001, capture_output=False) + self.assertIsNone(result) + mock_logger.assert_called() + + @patch("msprobe.utils.log.logger.error") + @patch("msprobe.utils.toolkits.Popen") + def test_run_subprocess_failure(self, mock_popen, mock_logger): + mock_process = MagicMock() + mock_process.poll.side_effect = [None, None, 0] + mock_process.communicate.return_value = (b"", b"Error") + mock_process.returncode = 1 + mock_popen.return_value.__enter__.return_value = mock_process + with self.assertRaises(MsprobeException) as e: + run_subprocess(["wrong_command"], check_interval=0.001, capture_output=True) + self.assertIn("Failed to execute command:", str(e.exception)) + mock_logger.assert_called() + + 
@patch("msprobe.utils.toolkits.Popen") + def test_run_subprocess_invalid_str_format(self, mock_popen): + mock_process = MagicMock() + mock_process.poll.side_effect = [None, None, 0] + mock_process.communicate.return_value = (b"", b"Error") + mock_process.returncode = 0 + mock_popen.return_value.__enter__.return_value = mock_process + with self.assertRaises(MsprobeException) as e: + run_subprocess("wrong_command", check_interval=0.001, capture_output=True) + self.assertIn("[ERROR] invalid data type. `cmd` must be a list of strings.", str(e.exception)) + + @patch("msprobe.utils.toolkits.Popen") + def test_run_subprocess_invalid_minus_format(self, mock_popen): + mock_process = MagicMock() + mock_process.poll.side_effect = [None, None, 0] + mock_process.communicate.return_value = (b"", b"Error") + mock_process.returncode = 1 + mock_popen.return_value.__enter__.return_value = mock_process + with self.assertRaises(MsprobeException) as e: + run_subprocess(["ls"], check_interval=-1, capture_output=True) + self.assertIn("[ERROR] invalid data type. 
`check_interval` must be a non-negative number.", str(e.exception)) + + +class TestDistBackend(unittest.TestCase): + def setUp(self): + self.patcher = patch("msprobe.utils.env.evars.get") + self.mock_evars_get = self.patcher.start() + self.mock_torch = MagicMock() + DistBackend.torch = self.mock_torch + + def tearDown(self): + self.patcher.stop() + DistBackend.torch = None + + def test_get_visible_device_valid(self): + self.mock_evars_get.return_value = "0,1" + result = DistBackend._get_visible_device("CUDA_VISIBLE_DEVICES") + self.assertEqual(result, 0) + self.mock_evars_get.assert_called_once_with("CUDA_VISIBLE_DEVICES", "0") + + def test_get_visible_device_invalid(self): + self.mock_evars_get.return_value = "invalid" + with self.assertRaises(MsprobeException) as context: + DistBackend._get_visible_device("ASCEND_VISIBLE_DEVICES") + self.assertIn("Please check the value", str(context.exception)) + + def test_is_device_available_npu_available(self): + self.mock_evars_get.return_value = "0" + self.mock_torch.npu.is_available.return_value = True + result = DistBackend._is_device_available("npu", "ASCEND_VISIBLE_DEVICES") + self.assertTrue(result) + self.mock_evars_get.assert_called_with("ASCEND_VISIBLE_DEVICES", "0") + + def test_is_device_available_npu_unavailable(self): + self.mock_evars_get.return_value = "0" + self.mock_torch.npu.is_available.return_value = False + result = DistBackend._is_device_available("npu", "ASCEND_VISIBLE_DEVICES") + self.assertFalse(result) + + def test_is_device_available_cuda_available(self): + self.mock_evars_get.return_value = "0" + self.mock_torch.cuda.is_available.return_value = True + result = DistBackend._is_device_available("cuda", "CUDA_VISIBLE_DEVICES") + self.assertTrue(result) + self.mock_evars_get.assert_called_with("CUDA_VISIBLE_DEVICES", "0") + + def test_is_device_available_cpu(self): + self.assertTrue(DistBackend._is_device_available("cpu", "")) + + @patch.object(DistBackend, "_is_device_available") + def 
test_get_global_device_priority(self, mock_is_available): + mock_is_available.side_effect = lambda device, _: device == "npu" + self.assertEqual(DistBackend._get_global_device(), "npu") + mock_is_available.side_effect = lambda device, _: device == "cuda" + self.assertEqual(DistBackend._get_global_device(), "cuda") + mock_is_available.side_effect = lambda device, _: False + self.assertEqual(DistBackend._get_global_device(), "cpu") + + def test_get_global_device_npu_available(self): + self.mock_evars_get.return_value = "0" + self.mock_torch.npu.is_available.return_value = True + self.mock_torch.cuda.is_available.return_value = False + self.assertEqual(DistBackend._get_global_device(), "npu") + + def test_get_global_device_fallback_to_cpu(self): + self.mock_torch.npu.is_available.return_value = False + self.mock_torch.cuda.is_available.return_value = False + self.assertEqual(DistBackend._get_global_device(), "cpu") + + @patch.object(DistBackend, "_get_global_device") + def test_get_method(self, mock_get_global): + mock_get_global.return_value = "npu" + self.assertEqual(DistBackend.get(), "hccl") + mock_get_global.return_value = "cuda" + self.assertEqual(DistBackend.get(), "nccl") + mock_get_global.return_value = "cpu" + self.assertEqual(DistBackend.get(), "gloo") + mock_get_global.return_value = "unknown" + self.assertEqual(DistBackend.get(), "cpu") + + +class TestTimestampSync(unittest.TestCase): + @patch("msprobe.utils.dependencies.dependent.get") + @patch("msprobe.utils.env.evars.get") + def test_timestamp_sync(self, mock_evars_get, mock_dependent_get): + mock_evars_get.side_effect = lambda key, default, typ=int: typ(default) + mock_torch = MagicMock() + mock_torch.distributed.is_initialized.return_value = False + mock_dependent_get.return_value = mock_torch + result = timestamp_sync(123456) + self.assertEqual(result, 123456) + + @patch("msprobe.utils.toolkits.dependent.get") + @patch("msprobe.utils.toolkits.evars.get") + def 
test_single_process_returns_original(self, mock_evars_get, mock_dependent_get): + mock_evars_get.return_value = 1 + mock_dependent_get.return_value = None + + result = timestamp_sync(12345) + self.assertEqual(result, 12345) + + @patch("msprobe.utils.toolkits.dependent.get") + @patch("msprobe.utils.toolkits.evars.get") + @patch("msprobe.utils.toolkits.DistBackend.get") + def test_distributed_sync_with_init(self, mock_backend_get, mock_evars_get, mock_dependent_get): + mock_evars_get.side_effect = lambda key, default, *_: {"LOCAL_WORLD_SIZE": 4, "LOCAL_RANK": 2}.get(key, default) + mock_torch = MagicMock() + mock_tensor = MagicMock() + mock_tensor.item.return_value = 54321 + mock_torch.tensor.return_value = mock_tensor + + mock_dist = MagicMock() + mock_dist.is_initialized.return_value = False + mock_dist.ReduceOp.MAX = "MAX" + mock_torch.distributed = mock_dist + + mock_dependent_get.return_value = mock_torch + mock_backend_get.return_value = "nccl" + + result = timestamp_sync(12345) + self.assertEqual(result, 54321) + mock_torch.tensor.assert_called_once_with(12345) + mock_dist.init_process_group.assert_called_once_with(backend="nccl", rank=2, world_size=4) + mock_dist.all_reduce.assert_called_once_with(mock_tensor, op="MAX") + mock_tensor.item.assert_called_once() + + @patch("msprobe.utils.toolkits.dependent.get") + @patch("msprobe.utils.toolkits.evars.get") + def test_already_initialized(self, mock_evars_get, mock_dependent_get): + mock_evars_get.side_effect = lambda key, default, *_: {"LOCAL_WORLD_SIZE": 4, "LOCAL_RANK": 2}.get(key, default) + mock_torch = MagicMock() + mock_tensor = MagicMock() + mock_tensor.item.return_value = 54321 + mock_torch.tensor.return_value = mock_tensor + + mock_dist = MagicMock() + mock_dist.is_initialized.return_value = True + mock_dist.ReduceOp.MAX = "MAX" + mock_torch.distributed = mock_dist + + mock_dependent_get.return_value = mock_torch + result = timestamp_sync(12345) + self.assertEqual(result, 54321) + 
mock_dist.init_process_group.assert_not_called() + + @patch("msprobe.utils.toolkits.dependent.get") + @patch("msprobe.utils.toolkits.evars.get") + def test_no_torch_returns_original(self, mock_evars_get, mock_dependent_get): + mock_evars_get.return_value = 4 + mock_dependent_get.return_value = None + result = timestamp_sync(12345) + self.assertEqual(result, 12345) + + +class TestGetRank(unittest.TestCase): + @patch("msprobe.utils.dependencies.dependent.get") + def test_torch_initialized_returns_rank(self, mock_dependent_get): + mock_torch = MagicMock() + mock_torch.distributed.is_initialized.return_value = True + mock_torch.distributed.get_rank.return_value = 2 + mock_dependent_get.return_value = mock_torch + result = get_rank() + self.assertEqual(result, "2") + + @patch("msprobe.utils.dependencies.dependent.get") + def test_torch_not_initialized_returns_empty(self, mock_dependent_get): + mock_torch = MagicMock() + mock_torch.distributed.is_initialized.return_value = False + mock_dependent_get.return_value = mock_torch + result = get_rank() + self.assertEqual(result, "") + + @patch("msprobe.utils.dependencies.dependent.get") + def test_no_torch_returns_empty(self, mock_dependent_get): + mock_dependent_get.return_value = None + result = get_rank() + self.assertEqual(result, "") + + +class TestSeedAll(unittest.TestCase): + @patch("msprobe.utils.toolkits.evars.set") + @patch("msprobe.utils.toolkits.seed") + @patch("numpy.random.seed") + @patch("msprobe.utils.toolkits.dependent.get") + @patch("msprobe.utils.toolkits.logger.info") + def test_seed_all_with_full_deps( + self, mock_logger, mock_dependent_get, mock_np_seed, mock_random_seed, mock_evars_set + ): + mock_torch = MagicMock() + mock_torch.version.cuda = "10.2.100" + mock_torch.cuda = MagicMock() + mock_torch.backends = MagicMock() + mock_torch_npu = MagicMock() + + mock_dependent_get.side_effect = lambda x: {"torch": mock_torch, "torch_npu": mock_torch_npu}.get(x) + + seed_all(666) + expected_evar_calls = [ + 
call("LCCL_DETERMINISTIC", "1"), + call("HCCL_DETERMINISTIC", "true"), + call("PYTHONHASHSEED", "666"), + call("ATB_MATMUL_SHUFFLE_K_ENABLE", "0"), + call("ATB_LLM_LCOC_ENABLE", "0"), + call("CUBLAS_WORKSPACE_CONFIG", ":4096:8"), + ] + mock_evars_set.assert_has_calls(expected_evar_calls, any_order=True) + mock_random_seed.assert_called_once_with(666) + mock_np_seed.assert_called_once_with(666) + + mock_torch.manual_seed.assert_called_once_with(666) + mock_torch.use_deterministic_algorithms.assert_called_once_with(mode=True) + mock_torch.cuda.manual_seed.assert_called_once_with(666) + mock_torch.cuda.manual_seed_all.assert_called_once_with(666) + mock_torch.backends.cudnn.deterministic = True + mock_torch.backends.cudnn.enable = False + mock_torch.backends.cudnn.benchmark = False + + mock_torch_npu.npu.manual_seed.assert_called_once_with(666) + mock_torch_npu.npu.manual_seed_all.assert_called_once_with(666) + + mock_logger.assert_called_once_with("Enable deterministic computation sucess! 
current seed is 666.") + + @patch("msprobe.utils.toolkits.evars.set") + @patch("msprobe.utils.toolkits.dependent.get") + def test_seed_all_without_cuda(self, mock_dependent_get, mock_evars_set): + mock_torch = MagicMock() + del mock_torch.cuda + mock_torch.version.cuda = None + mock_dependent_get.return_value = mock_torch + seed_all(666) + self.assertFalse(hasattr(mock_torch, "cuda")) + cublas_calls = [call.args for call in mock_evars_set.mock_calls if call.args[0] == "CUBLAS_WORKSPACE_CONFIG"] + self.assertEqual(len(cublas_calls), 0) + + +class MockNode: + def __init__(self, name, inputs): + self.name = name + self.input = inputs + + +class TestGetNetOutputNodes(unittest.TestCase): + def test_single_output_node(self): + node_a = MockNode("A", []) + node_b = MockNode("B", ["A"]) + node_c = MockNode("C", ["B"]) + graph_def = Mock() + graph_def.node = [node_a, node_b, node_c] + result = get_net_output_nodes_from_graph_def(graph_def) + self.assertEqual(result, ["C"]) + + def test_multiple_output_nodes(self): + graph_def = MagicMock() + node_a = MockNode("A", []) + node_b = MockNode("B", ["A"]) + node_c = MockNode("C", ["A"]) + graph_def.node = [node_a, node_b, node_c] + result = get_net_output_nodes_from_graph_def(graph_def) + self.assertCountEqual(result, ["B", "C"]) + + def test_empty_graph(self): + graph_def = MagicMock() + graph_def.node = [] + result = get_net_output_nodes_from_graph_def(graph_def) + self.assertEqual(result, []) + + +class TestSanitizeCsvValue(unittest.TestCase): + def test_sanitize_csv_value_ignore(self): + value = "malicious;value" + result = sanitize_csv_value(value, CHECK_CSV_LEVEL_IGNORE) + self.assertEqual(result, value) + + def test_sanitize_csv_value_non_string(self): + value = 123 + result = sanitize_csv_value(value, CHECK_CSV_LEVEL_STRICT) + self.assertEqual(result, value) + + def test_sanitize_csv_value_safe_number(self): + value = "3.14" + result = sanitize_csv_value(value, CHECK_CSV_LEVEL_STRICT) + self.assertEqual(result, value) + + 
@patch("msprobe.utils.toolkits._MALICIOUS_CSV_PATTERN", re.compile(r";")) + def test_sanitize_csv_value_malicious_strict(self): + value = "malicious;value" + with self.assertRaises(MsprobeException) as e: + sanitize_csv_value(value, CHECK_CSV_LEVEL_STRICT) + self.assertIn( + "Malicious value detected: malicious;value, please check the value written to the csv.", str(e.exception) + ) + + @patch("msprobe.utils.toolkits._MALICIOUS_CSV_PATTERN", re.compile(r";")) + def test_sanitize_csv_value_malicious_replace(self): + value = "malicious;value" + result = sanitize_csv_value(value, CHECK_CSV_LEVEL_REPLACE) + self.assertEqual(result, "") + + +class TestIsInputYes(unittest.TestCase): + @patch("builtins.input", return_value="yes") + def test_is_input_yes_positive(self, mock_input): + self.assertTrue(is_input_yes("Prompt: ")) + + @patch("builtins.input", return_value="no") + def test_is_input_yes_negative(self, mock_input): + self.assertFalse(is_input_yes("Prompt: ")) + + @patch("builtins.input", return_value=" YES ") + def test_is_input_yes_whitespace(self, mock_input): + self.assertTrue(is_input_yes("Prompt: ")) + + @patch("builtins.input", side_effect=KeyboardInterrupt) + @patch("msprobe.utils.toolkits.logger.info") + def test_is_input_yes_interrupted(self, mock_logger, mock_input): + self.assertFalse(is_input_yes("Prompt: ")) + mock_logger.assert_called_with('Input interrupted. 
Defaulting to "no".') + + +class TestSetLdPreload(unittest.TestCase): + @patch("msprobe.utils.toolkits.evars") + @patch("msprobe.utils.toolkits.logger") + def test_existing_ld_preload_updates_value(self, mock_logger: MagicMock, mock_evars: MagicMock): + mock_evars.get.return_value = "existing_lib.so" + set_ld_preload("new_lib.so") + mock_evars.get.assert_called_once_with("LD_PRELOAD", required=False) + mock_evars.set.assert_called_once_with("LD_PRELOAD", "new_lib.so:existing_lib.so") + mock_logger.info.assert_called_once_with("Environment updated with .so library new_lib.so.") + + @patch("msprobe.utils.toolkits.evars") + @patch("msprobe.utils.toolkits.logger") + def test_no_existing_ld_preload_sets_new_value(self, mock_logger: MagicMock, mock_evars: MagicMock): + mock_evars.get.return_value = None + set_ld_preload("new_lib.so") + mock_evars.get.assert_called_once_with("LD_PRELOAD", required=False) + mock_evars.set.assert_called_once_with("LD_PRELOAD", "new_lib.so") + mock_logger.info.assert_called_once_with("Environment updated with .so library new_lib.so.") + + @patch("msprobe.utils.toolkits.evars") + @patch("msprobe.utils.toolkits.logger") + def test_empty_ld_preload_sets_new_value(self, mock_logger: MagicMock, mock_evars: MagicMock): + mock_evars.get.return_value = "" + set_ld_preload("new_lib.so") + mock_evars.set.assert_called_once_with("LD_PRELOAD", "new_lib.so") diff --git a/accuracy_tools/third_party/.keep b/accuracy_tools/third_party/.keep new file mode 100644 index 00000000000..e69de29bb2d -- Gitee